14 changes: 7 additions & 7 deletions flaxdiff/data/sources/tfds.py
@@ -50,13 +50,12 @@ def tfds_augmenters(image_scale, method):
else:
interpolation = cv2.INTER_AREA

augments = augmax.Chain(
augmax.HorizontalFlip(0.5),
augmax.RandomContrast((-0.05, 0.05), 1.),
augmax.RandomBrightness((-0.2, 0.2), 1.)
)
from torchvision.transforms import v2

augments = jax.jit(augments, backend="cpu")
augments = v2.Compose([
v2.RandomHorizontalFlip(p=0.5),
v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
])

class augmenters(pygrain.MapTransform):
def __init__(self, *args, **kwargs):
@@ -67,8 +66,9 @@ def map(self, element) -> Dict[str, jnp.array]:
image = element['image']
image = cv2.resize(image, (image_scale, image_scale),
interpolation=interpolation)
# image = augments(image)
image = augments(image)
# image = (image - 127.5) / 127.5

caption = labelizer(element)
results = self.tokenize(caption)
return {
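Note: this hunk replaces the augmax chain (previously compiled with `jax.jit(..., backend="cpu")`) with a torchvision `transforms.v2` pipeline and re-enables the `augments(image)` call inside the pygrain `MapTransform`. A minimal sketch of the new pipeline in isolation, assuming a `uint8` torch tensor in CHW layout as input; the tensor conversion below is illustrative and not part of this PR (in the hunk itself, `cv2.resize` returns an HWC numpy array):

```python
import numpy as np
import torch
from torchvision.transforms import v2

# Same transform list as the new tfds_augmenters pipeline.
augments = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2),
])

# Illustrative input: torchvision v2 transforms operate on PIL images or
# torch tensors in (..., C, H, W) layout, so the numpy image is converted
# and transposed first.
image_hwc = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
image_chw = torch.from_numpy(image_hwc).permute(2, 0, 1)
augmented = augments(image_chw)  # uint8 tensor, randomly flipped and jittered
```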
38 changes: 22 additions & 16 deletions flaxdiff/models/attention.py
@@ -23,7 +23,7 @@ class EfficientAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = kernel_init(1.0)
# kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
@@ -34,15 +34,21 @@ def setup(self):
self.heads * self.dim_head,
precision=self.precision,
use_bias=self.use_bias,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
self.key = dense(name="to_k")
self.value = dense(name="to_v")

self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
self.proj_attn = nn.DenseGeneral(
self.query_dim,
use_bias=False,
precision=self.precision,
# kernel_init=self.kernel_init,
dtype=self.dtype,
name="to_out_0"
)
# self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

def _reshape_tensor_to_head_dim(self, tensor):
@@ -115,7 +121,7 @@ class NormalAttention(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = kernel_init(1.0)
# kernel_init: Callable = kernel_init(1.0)
force_fp32_for_softmax: bool = True

def setup(self):
@@ -126,7 +132,7 @@ def setup(self):
axis=-1,
precision=self.precision,
use_bias=self.use_bias,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
dtype=self.dtype
)
self.query = dense(name="to_q")
@@ -140,7 +146,7 @@ def setup(self):
use_bias=self.use_bias,
dtype=self.dtype,
name="to_out_0",
kernel_init=self.kernel_init
# kernel_init=self.kernel_init
# kernel_init=jax.nn.initializers.xavier_uniform()
)

@@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module):
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
use_bias: bool = True
kernel_init: Callable = kernel_init(1.0)
# kernel_init: Callable = kernel_init(1.0)
use_flash_attention:bool = False
use_cross_only:bool = False
only_pure_attention:bool = False
@@ -256,7 +262,7 @@ def setup(self):
precision=self.precision,
use_bias=self.use_bias,
dtype=self.dtype,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
force_fp32_for_softmax=self.force_fp32_for_softmax
)
self.attention2 = attenBlock(
@@ -267,7 +273,7 @@ def setup(self):
precision=self.precision,
use_bias=self.use_bias,
dtype=self.dtype,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
force_fp32_for_softmax=self.force_fp32_for_softmax
)

@@ -303,7 +309,7 @@ class TransformerBlock(nn.Module):
use_self_and_cross:bool = True
only_pure_attention:bool = False
force_fp32_for_softmax: bool = True
kernel_init: Callable = kernel_init(1.0)
# kernel_init: Callable = kernel_init(1.0)
norm_inputs: bool = True
explicitly_add_residual: bool = True

@@ -317,12 +323,12 @@ def __call__(self, x, context=None):
if self.use_linear_attention:
projected_x = nn.Dense(features=inner_dim,
use_bias=False, precision=self.precision,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
dtype=self.dtype, name=f'project_in')(x)
else:
projected_x = nn.Conv(
features=inner_dim, kernel_size=(1, 1),
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_in_conv',
)(x)
@@ -344,19 +350,19 @@ def __call__(self, x, context=None):
use_cross_only=(not self.use_self_and_cross),
only_pure_attention=self.only_pure_attention,
force_fp32_for_softmax=self.force_fp32_for_softmax,
kernel_init=self.kernel_init
# kernel_init=self.kernel_init
)(projected_x, context)

if self.use_projection == True:
if self.use_linear_attention:
projected_x = nn.Dense(features=C, precision=self.precision,
dtype=self.dtype, use_bias=False,
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
name=f'project_out')(projected_x)
else:
projected_x = nn.Conv(
features=C, kernel_size=(1, 1),
kernel_init=self.kernel_init,
# kernel_init=self.kernel_init,
strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
precision=self.precision, name=f'project_out_conv',
)(projected_x)
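Note: across `EfficientAttention`, `NormalAttention`, `BasicTransformerBlock`, and `TransformerBlock`, the `kernel_init: Callable = kernel_init(1.0)` field and every `kernel_init=self.kernel_init` argument are commented out, so the affected `nn.DenseGeneral`/`nn.Dense`/`nn.Conv` layers fall back to Flax's default kernel initializer (`nn.initializers.lecun_normal()`). A minimal sketch of that fallback, assuming Flax linen; the toy module and shapes below are illustrative, not code from this PR:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToQ(nn.Module):
    heads: int = 4
    dim_head: int = 64

    @nn.compact
    def __call__(self, x):
        # No kernel_init passed: DenseGeneral uses Flax's default,
        # nn.initializers.lecun_normal().
        return nn.DenseGeneral(self.heads * self.dim_head,
                               use_bias=True, name="to_q")(x)

x = jnp.ones((1, 16, 128))
params = ToQ().init(jax.random.PRNGKey(0), x)
# params['params']['to_q']['kernel'] has shape (128, 256), drawn from lecun_normal.
```

Since the dataclass field itself is commented out, any caller still passing `kernel_init=...` when constructing these modules should now fail with a `TypeError` from the generated `__init__`.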
26 changes: 8 additions & 18 deletions flaxdiff/models/common.py
@@ -108,13 +108,16 @@ def __call__(self, x):
class TimeProjection(nn.Module):
features:int
activation:Callable=jax.nn.gelu
kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x):
x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
x = nn.DenseGeneral(
self.features,
)(x)
x = self.activation(x)
x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
x = nn.DenseGeneral(
self.features,
)(x)
x = self.activation(x)
return x

@@ -123,7 +126,6 @@ class SeparableConv(nn.Module):
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
use_bias:bool=False
kernel_init:Callable=kernel_init(1.0)
padding:str="SAME"
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
@@ -133,15 +135,15 @@ def __call__(self, x):
in_features = x.shape[-1]
depthwise = nn.Conv(
features=in_features, kernel_size=self.kernel_size,
strides=self.strides, kernel_init=self.kernel_init,
strides=self.strides,
feature_group_count=in_features, use_bias=self.use_bias,
padding=self.padding,
dtype=self.dtype,
precision=self.precision
)(x)
pointwise = nn.Conv(
features=self.features, kernel_size=(1, 1),
strides=(1, 1), kernel_init=self.kernel_init,
strides=(1, 1),
use_bias=self.use_bias,
dtype=self.dtype,
precision=self.precision
@@ -153,7 +155,6 @@ class ConvLayer(nn.Module):
features:int
kernel_size:tuple=(3, 3)
strides:tuple=(1, 1)
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None

@@ -164,7 +165,6 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -183,7 +183,6 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -192,7 +191,6 @@ def setup(self):
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
dtype=self.dtype,
precision=self.precision
)
@@ -206,7 +204,6 @@ class Upsample(nn.Module):
activation:Callable=jax.nn.swish
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x, residual=None):
@@ -221,7 +218,6 @@ def __call__(self, x, residual=None):
strides=(1, 1),
dtype=self.dtype,
precision=self.precision,
kernel_init=self.kernel_init
)(out)
if residual is not None:
out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +229,6 @@ class Downsample(nn.Module):
activation:Callable=jax.nn.swish
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
kernel_init:Callable=kernel_init(1.0)

@nn.compact
def __call__(self, x, residual=None):
@@ -244,7 +239,6 @@ def __call__(self, x, residual=None):
strides=(2, 2),
dtype=self.dtype,
precision=self.precision,
kernel_init=self.kernel_init
)(x)
if residual is not None:
if residual.shape[1] > out.shape[1]:
@@ -269,7 +263,6 @@ class ResidualBlock(nn.Module):
direction:str=None
res:int=2
norm_groups:int=8
kernel_init:Callable=kernel_init(1.0)
dtype: Optional[Dtype] = None
precision: PrecisionLike = None
named_norms:bool=False
@@ -296,7 +289,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv1",
dtype=self.dtype,
precision=self.precision
@@ -321,7 +313,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=self.kernel_size,
strides=self.strides,
kernel_init=self.kernel_init,
name="conv2",
dtype=self.dtype,
precision=self.precision
@@ -333,7 +324,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe
features=self.features,
kernel_size=(1, 1),
strides=1,
kernel_init=self.kernel_init,
name="residual_conv",
dtype=self.dtype,
precision=self.precision
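Note: `common.py` makes the same change at the field level, dropping `kernel_init` from `TimeProjection`, `SeparableConv`, `ConvLayer`, `Upsample`, `Downsample`, and `ResidualBlock`, so their `nn.Conv`/`nn.DenseGeneral` sublayers likewise use the library default initializer. For reference, `TimeProjection` as it reads after this hunk (reassembled from the diff; the imports are assumptions):

```python
import jax
import flax.linen as nn
from typing import Callable

class TimeProjection(nn.Module):
    features: int
    activation: Callable = jax.nn.gelu

    @nn.compact
    def __call__(self, x):
        # Both projections now use Flax's default (lecun_normal) kernel init.
        x = nn.DenseGeneral(self.features)(x)
        x = self.activation(x)
        x = nn.DenseGeneral(self.features)(x)
        x = self.activation(x)
        return x
```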