diff --git a/flaxdiff/data/sources/tfds.py b/flaxdiff/data/sources/tfds.py index 65f480f..f29e75e 100644 --- a/flaxdiff/data/sources/tfds.py +++ b/flaxdiff/data/sources/tfds.py @@ -50,13 +50,12 @@ def tfds_augmenters(image_scale, method): else: interpolation = cv2.INTER_AREA - augments = augmax.Chain( - augmax.HorizontalFlip(0.5), - augmax.RandomContrast((-0.05, 0.05), 1.), - augmax.RandomBrightness((-0.2, 0.2), 1.) - ) + from torchvision.transforms import v2 - augments = jax.jit(augments, backend="cpu") + augments = v2.Compose([ + v2.RandomHorizontalFlip(p=0.5), + v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2) + ]) class augmenters(pygrain.MapTransform): def __init__(self, *args, **kwargs): @@ -67,8 +66,9 @@ def map(self, element) -> Dict[str, jnp.array]: image = element['image'] image = cv2.resize(image, (image_scale, image_scale), interpolation=interpolation) - # image = augments(image) + image = augments(image) # image = (image - 127.5) / 127.5 + caption = labelizer(element) results = self.tokenize(caption) return { diff --git a/flaxdiff/models/attention.py b/flaxdiff/models/attention.py index d90eafc..6776245 100644 --- a/flaxdiff/models/attention.py +++ b/flaxdiff/models/attention.py @@ -23,7 +23,7 @@ class EfficientAttention(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = kernel_init(1.0) + # kernel_init: Callable = kernel_init(1.0) force_fp32_for_softmax: bool = True def setup(self): @@ -34,15 +34,21 @@ def setup(self): self.heads * self.dim_head, precision=self.precision, use_bias=self.use_bias, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, dtype=self.dtype ) self.query = dense(name="to_q") self.key = dense(name="to_k") self.value = dense(name="to_v") - self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision, - kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0") + self.proj_attn = nn.DenseGeneral( + self.query_dim, + use_bias=False, + precision=self.precision, + # kernel_init=self.kernel_init, + dtype=self.dtype, + name="to_out_0" + ) # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16) def _reshape_tensor_to_head_dim(self, tensor): @@ -115,7 +121,7 @@ class NormalAttention(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = kernel_init(1.0) + # kernel_init: Callable = kernel_init(1.0) force_fp32_for_softmax: bool = True def setup(self): @@ -126,7 +132,7 @@ def setup(self): axis=-1, precision=self.precision, use_bias=self.use_bias, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, dtype=self.dtype ) self.query = dense(name="to_q") @@ -140,7 +146,7 @@ def setup(self): use_bias=self.use_bias, dtype=self.dtype, name="to_out_0", - kernel_init=self.kernel_init + # kernel_init=self.kernel_init # kernel_init=jax.nn.initializers.xavier_uniform() ) @@ -236,7 +242,7 @@ class BasicTransformerBlock(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None use_bias: bool = True - kernel_init: Callable = kernel_init(1.0) + # kernel_init: Callable = kernel_init(1.0) use_flash_attention:bool = False use_cross_only:bool = False only_pure_attention:bool = False @@ -256,7 +262,7 @@ def setup(self): precision=self.precision, use_bias=self.use_bias, dtype=self.dtype, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, force_fp32_for_softmax=self.force_fp32_for_softmax ) self.attention2 = attenBlock( @@ -267,7 +273,7 @@ def setup(self): precision=self.precision, use_bias=self.use_bias, dtype=self.dtype, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, force_fp32_for_softmax=self.force_fp32_for_softmax ) @@ -303,7 +309,7 @@ class TransformerBlock(nn.Module): use_self_and_cross:bool = True only_pure_attention:bool = False force_fp32_for_softmax: bool = True - kernel_init: Callable = kernel_init(1.0) + # kernel_init: Callable = kernel_init(1.0) norm_inputs: bool = True explicitly_add_residual: bool = True @@ -317,12 +323,12 @@ def __call__(self, x, context=None): if self.use_linear_attention: projected_x = nn.Dense(features=inner_dim, use_bias=False, precision=self.precision, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, dtype=self.dtype, name=f'project_in')(x) else: projected_x = nn.Conv( features=inner_dim, kernel_size=(1, 1), - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_in_conv', )(x) @@ -344,19 +350,19 @@ def __call__(self, x, context=None): use_cross_only=(not self.use_self_and_cross), only_pure_attention=self.only_pure_attention, force_fp32_for_softmax=self.force_fp32_for_softmax, - kernel_init=self.kernel_init + # kernel_init=self.kernel_init )(projected_x, context) if self.use_projection == True: if self.use_linear_attention: projected_x = nn.Dense(features=C, precision=self.precision, dtype=self.dtype, use_bias=False, - kernel_init=self.kernel_init, + # kernel_init=self.kernel_init, name=f'project_out')(projected_x) else: projected_x = nn.Conv( features=C, kernel_size=(1, 1), - kernel_init=self.kernel_init, + # kernel_init=self.kernel_i nit, strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype, precision=self.precision, name=f'project_out_conv', )(projected_x) diff --git a/flaxdiff/models/common.py b/flaxdiff/models/common.py index d91eba7..c04f1d6 100644 --- a/flaxdiff/models/common.py +++ b/flaxdiff/models/common.py @@ -108,13 +108,16 @@ def __call__(self, x): class TimeProjection(nn.Module): features:int activation:Callable=jax.nn.gelu - kernel_init:Callable=kernel_init(1.0) @nn.compact def __call__(self, x): - x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x) + x = nn.DenseGeneral( + self.features, + )(x) x = self.activation(x) - x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x) + x = nn.DenseGeneral( + self.features, + )(x) x = self.activation(x) return x @@ -123,7 +126,6 @@ class SeparableConv(nn.Module): kernel_size:tuple=(3, 3) strides:tuple=(1, 1) use_bias:bool=False - kernel_init:Callable=kernel_init(1.0) padding:str="SAME" dtype: Optional[Dtype] = None precision: PrecisionLike = None @@ -133,7 +135,7 @@ def __call__(self, x): in_features = x.shape[-1] depthwise = nn.Conv( features=in_features, kernel_size=self.kernel_size, - strides=self.strides, kernel_init=self.kernel_init, + strides=self.strides, feature_group_count=in_features, use_bias=self.use_bias, padding=self.padding, dtype=self.dtype, @@ -141,7 +143,7 @@ def __call__(self, x): )(x) pointwise = nn.Conv( features=self.features, kernel_size=(1, 1), - strides=(1, 1), kernel_init=self.kernel_init, + strides=(1, 1), use_bias=self.use_bias, dtype=self.dtype, precision=self.precision @@ -153,7 +155,6 @@ class ConvLayer(nn.Module): features:int kernel_size:tuple=(3, 3) strides:tuple=(1, 1) - kernel_init:Callable=kernel_init(1.0) dtype: Optional[Dtype] = None precision: PrecisionLike = None @@ -164,7 +165,6 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, dtype=self.dtype, precision=self.precision ) @@ -183,7 +183,6 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, dtype=self.dtype, precision=self.precision ) @@ -192,7 +191,6 @@ def setup(self): features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, dtype=self.dtype, precision=self.precision ) @@ -206,7 +204,6 @@ class Upsample(nn.Module): activation:Callable=jax.nn.swish dtype: Optional[Dtype] = None precision: PrecisionLike = None - kernel_init:Callable=kernel_init(1.0) @nn.compact def __call__(self, x, residual=None): @@ -221,7 +218,6 @@ def __call__(self, x, residual=None): strides=(1, 1), dtype=self.dtype, precision=self.precision, - kernel_init=self.kernel_init )(out) if residual is not None: out = jnp.concatenate([out, residual], axis=-1) @@ -233,7 +229,6 @@ class Downsample(nn.Module): activation:Callable=jax.nn.swish dtype: Optional[Dtype] = None precision: PrecisionLike = None - kernel_init:Callable=kernel_init(1.0) @nn.compact def __call__(self, x, residual=None): @@ -244,7 +239,6 @@ def __call__(self, x, residual=None): strides=(2, 2), dtype=self.dtype, precision=self.precision, - kernel_init=self.kernel_init )(x) if residual is not None: if residual.shape[1] > out.shape[1]: @@ -269,7 +263,6 @@ class ResidualBlock(nn.Module): direction:str=None res:int=2 norm_groups:int=8 - kernel_init:Callable=kernel_init(1.0) dtype: Optional[Dtype] = None precision: PrecisionLike = None named_norms:bool=False @@ -296,7 +289,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, name="conv1", dtype=self.dtype, precision=self.precision @@ -321,7 +313,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=self.kernel_size, strides=self.strides, - kernel_init=self.kernel_init, name="conv2", dtype=self.dtype, precision=self.precision @@ -333,7 +324,6 @@ def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_fe features=self.features, kernel_size=(1, 1), strides=1, - kernel_init=self.kernel_init, name="residual_conv", dtype=self.dtype, precision=self.precision diff --git a/flaxdiff/models/simple_unet.py b/flaxdiff/models/simple_unet.py index 70692ab..06693cf 100644 --- a/flaxdiff/models/simple_unet.py +++ b/flaxdiff/models/simple_unet.py @@ -20,7 +20,6 @@ class Unet(nn.Module): dtype: Optional[Dtype] = None precision: PrecisionLike = None named_norms: bool = False # This is for backward compatibility reasons; older checkpoints have named norms - kernel_init: Callable = partial(kernel_init, dtype=jnp.float32) def setup(self): if self.norm_groups > 0: @@ -50,7 +49,6 @@ def __call__(self, x, temb, textcontext): features=self.feature_depths[0], kernel_size=(3, 3), strides=(1, 1), - kernel_init=self.kernel_init(scale=1.0), dtype=self.dtype, precision=self.precision )(x) @@ -65,7 +63,6 @@ def __call__(self, x, temb, textcontext): down_conv_type, name=f"down_{i}_residual_{j}", features=dim_in, - kernel_init=self.kernel_init(scale=1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, @@ -85,7 +82,6 @@ def __call__(self, x, temb, textcontext): force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False), norm_inputs=attention_config.get("norm_inputs", True), explicitly_add_residual=attention_config.get("explicitly_add_residual", True), - kernel_init=self.kernel_init(scale=1.0), name=f"down_{i}_attention_{j}")(x, textcontext) # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in) downs.append(x) @@ -108,7 +104,6 @@ def __call__(self, x, temb, textcontext): middle_conv_type, name=f"middle_res1_{j}", features=middle_dim_out, - kernel_init=self.kernel_init(scale=1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, @@ -129,13 +124,11 @@ def __call__(self, x, temb, textcontext): force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False), norm_inputs=middle_attention.get("norm_inputs", True), explicitly_add_residual=middle_attention.get("explicitly_add_residual", True), - kernel_init=self.kernel_init(scale=1.0), name=f"middle_attention_{j}")(x, textcontext) x = ResidualBlock( middle_conv_type, name=f"middle_res2_{j}", features=middle_dim_out, - kernel_init=self.kernel_init(scale=1.0), kernel_size=(3, 3), strides=(1, 1), activation=self.activation, @@ -157,7 +150,6 @@ def __call__(self, x, temb, textcontext): up_conv_type,# if j == 0 else "separable", name=f"up_{i}_residual_{j}", features=dim_out, - kernel_init=self.kernel_init(scale=1.0), kernel_size=kernel_size, strides=(1, 1), activation=self.activation, @@ -177,7 +169,6 @@ def __call__(self, x, temb, textcontext): force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False), norm_inputs=attention_config.get("norm_inputs", True), explicitly_add_residual=attention_config.get("explicitly_add_residual", True), - kernel_init=self.kernel_init(scale=1.0), name=f"up_{i}_attention_{j}")(x, textcontext) # print("Upscaling ", i, x.shape) if i != len(feature_depths) - 1: @@ -196,7 +187,6 @@ def __call__(self, x, temb, textcontext): features=self.feature_depths[0], kernel_size=(3, 3), strides=(1, 1), - kernel_init=self.kernel_init(scale=1.0), dtype=self.dtype, precision=self.precision )(x) @@ -207,7 +197,6 @@ def __call__(self, x, temb, textcontext): conv_type, name="final_residual", features=self.feature_depths[0], - kernel_init=self.kernel_init(scale=1.0), kernel_size=(3,3), strides=(1, 1), activation=self.activation, @@ -226,7 +215,7 @@ def __call__(self, x, temb, textcontext): kernel_size=(3, 3), strides=(1, 1), # activation=jax.nn.mish - kernel_init=self.kernel_init(scale=0.0), + # kernel_init=self.kernel_init(scale=0.0), dtype=self.dtype, precision=self.precision )(x) diff --git a/flaxdiff/models/simple_vit.py b/flaxdiff/models/simple_vit.py index 530b89d..0745e1e 100644 --- a/flaxdiff/models/simple_vit.py +++ b/flaxdiff/models/simple_vit.py @@ -23,7 +23,6 @@ class PatchEmbedding(nn.Module): embedding_dim: int dtype: Any = jnp.float32 precision: Any = jax.lax.Precision.HIGH - kernel_init: Callable = partial(kernel_init, 1.0) @nn.compact def __call__(self, x): @@ -34,7 +33,6 @@ def __call__(self, x): kernel_size=(self.patch_size, self.patch_size), strides=(self.patch_size, self.patch_size), dtype=self.dtype, - kernel_init=self.kernel_init(), precision=self.precision)(x) x = jnp.reshape(x, (batch, -1, self.embedding_dim)) return x @@ -67,7 +65,7 @@ class UViT(nn.Module): norm_groups:int=8 dtype: Optional[Dtype] = None precision: PrecisionLike = None - kernel_init: Callable = partial(kernel_init, scale=1.0) + # kernel_init: Callable = partial(kernel_init, scale=1.0) add_residualblock_output: bool = False norm_inputs: bool = False explicitly_add_residual: bool = True @@ -88,10 +86,10 @@ def __call__(self, x, temb, textcontext=None): # Patch embedding x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features, - dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x) + dtype=self.dtype, precision=self.precision)(x) num_patches = x.shape[1] - context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), + context_emb = nn.DenseGeneral(features=self.emb_features, dtype=self.dtype, precision=self.precision)(textcontext) num_text_tokens = textcontext.shape[1] @@ -116,7 +114,7 @@ def __call__(self, x, temb, textcontext=None): only_pure_attention=False, norm_inputs=self.norm_inputs, explicitly_add_residual=self.explicitly_add_residual, - kernel_init=self.kernel_init())(x) + )(x) skips.append(x) # Middle block @@ -126,12 +124,12 @@ def __call__(self, x, temb, textcontext=None): only_pure_attention=False, norm_inputs=self.norm_inputs, explicitly_add_residual=self.explicitly_add_residual, - kernel_init=self.kernel_init())(x) + )(x) # # Out blocks for i in range(self.num_layers // 2): x = jnp.concatenate([x, skips.pop()], axis=-1) - x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(), + x = nn.DenseGeneral(features=self.emb_features, dtype=self.dtype, precision=self.precision)(x) x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads, dtype=self.dtype, precision=self.precision, use_projection=self.use_projection, @@ -139,13 +137,13 @@ def __call__(self, x, temb, textcontext=None): only_pure_attention=False, norm_inputs=self.norm_inputs, explicitly_add_residual=self.explicitly_add_residual, - kernel_init=self.kernel_init())(x) + )(x) # print(f'Shape of x after transformer blocks: {x.shape}') x = self.norm()(x) patch_dim = self.patch_size ** 2 * self.output_channels - x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x) + x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x) x = x[:, 1 + num_text_tokens:, :] x = unpatchify(x, channels=self.output_channels) @@ -159,7 +157,6 @@ def __call__(self, x, temb, textcontext=None): kernel_size=(3, 3), strides=(1, 1), # activation=jax.nn.mish - kernel_init=self.kernel_init(scale=0.0), dtype=self.dtype, precision=self.precision )(x) @@ -173,7 +170,6 @@ def __call__(self, x, temb, textcontext=None): kernel_size=(3, 3), strides=(1, 1), # activation=jax.nn.mish - kernel_init=self.kernel_init(scale=0.0), dtype=self.dtype, precision=self.precision )(x) diff --git a/flaxdiff/trainer/diffusion_trainer.py b/flaxdiff/trainer/diffusion_trainer.py index 97bd80d..53f7978 100644 --- a/flaxdiff/trainer/diffusion_trainer.py +++ b/flaxdiff/trainer/diffusion_trainer.py @@ -231,11 +231,11 @@ def model_loss(params): ), ) - train_state = new_state.apply_ema(self.ema_decay) + new_state = new_state.apply_ema(self.ema_decay) if distributed_training: loss = jax.lax.pmean(loss, "data") - return train_state, loss, rng_state + return new_state, loss, rng_state if distributed_training: train_step = shard_map( diff --git a/flaxdiff/trainer/simple_trainer.py b/flaxdiff/trainer/simple_trainer.py index 4408775..a748444 100644 --- a/flaxdiff/trainer/simple_trainer.py +++ b/flaxdiff/trainer/simple_trainer.py @@ -159,7 +159,7 @@ def __init__(self, self.best_loss = 1e9 def get_input_ones(self): - return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()} + return {k: jnp.ones((1, *v), dtype=self.model.dtype) for k, v in self.input_shapes.items()} def generate_states( self, @@ -437,12 +437,30 @@ def train_loop( # If the loss is too low, we can assume the model has diverged print(colored(f"Loss too low at step {current_step} => {loss}", 'red')) # Reset the model to the old state - if self.best_state is not None: - print(colored(f"Resetting model to best state", 'red')) - train_state = self.best_state - loss = self.best_loss + # if self.best_state is not None: + # print(colored(f"Resetting model to best state", 'red')) + # train_state = self.best_state + # loss = self.best_loss + # else: + # exit(1) + + # Check if there are any NaN/inf values in the train_state.params + params = train_state.params + if isinstance(params, dict): + for key, value in params.items(): + if isinstance(value, jnp.ndarray): + if jnp.isnan(value).any() or jnp.isinf(value).any(): + print(colored(f"NaN/inf values found in params at step {current_step}", 'red')) + # Reset the model to the old state + # train_state = self.best_state + # loss = self.best_loss + # break + else: + print(colored(f"Params are fine at step {current_step}", 'green')) else: - exit(1) + print(colored(f"Params are not a dict at step {current_step}", 'red')) + + exit(1) epoch_loss += loss current_step += 1 diff --git a/prototype_pipeline.ipynb b/prototype_pipeline.ipynb index 00c0bbb..4b376fc 100644 --- a/prototype_pipeline.ipynb +++ b/prototype_pipeline.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -45,22 +45,22 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2025-04-10 06:23:43.248339: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2025-04-10 15:23:13.709672: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", - "E0000 00:00:1744266223.273050 2055796 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", - "E0000 00:00:1744266223.280744 2055796 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", - "W0000 00:00:1744266223.298347 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1744266223.298373 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1744266223.298376 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "W0000 00:00:1744266223.298378 2055796 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", - "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'bias'), ('logit_scale',), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'bias'), ('vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('visual_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'post_layernorm', 'scale'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('text_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'bias'), ('vision_model', 'pre_layrnorm', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'post_layernorm', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'kernel')}\n", + "E0000 00:00:1744298593.733614 2309744 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1744298593.741021 2309744 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "W0000 00:00:1744298593.758653 2309744 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1744298593.758673 2309744 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1744298593.758675 2309744 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "W0000 00:00:1744298593.758677 2309744 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\n", + "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing FlaxCLIPTextModel: {('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'bias'), ('visual_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'bias'), ('logit_scale',), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'kernel'), ('vision_model', 'embeddings', 'position_embedding', 'embedding'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'bias'), ('vision_model', 'embeddings', 'patch_embedding', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'post_layernorm', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc2', 'kernel'), ('vision_model', 'pre_layrnorm', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '5', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'post_layernorm', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'pre_layrnorm', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '3', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '16', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc1', 'bias'), ('text_projection', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '13', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '6', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '20', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '21', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '21', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '10', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '3', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '14', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '2', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '22', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '5', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '21', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '13', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '7', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm2', 'bias'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'encoder', 'layers', '20', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '5', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '23', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '22', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm2', 'scale'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '10', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '1', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '11', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '14', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '6', 'self_attn', 'v_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '8', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '9', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '8', 'layer_norm1', 'bias'), ('vision_model', 'encoder', 'layers', '0', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '19', 'self_attn', 'k_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '23', 'mlp', 'fc1', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '22', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '15', 'self_attn', 'v_proj', 'bias'), ('vision_model', 'embeddings', 'class_embedding'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'q_proj', 'bias'), ('vision_model', 'encoder', 'layers', '18', 'layer_norm1', 'scale'), ('vision_model', 'encoder', 'layers', '17', 'mlp', 'fc1', 'bias'), ('vision_model', 'encoder', 'layers', '17', 'self_attn', 'out_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '12', 'mlp', 'fc2', 'bias'), ('vision_model', 'encoder', 'layers', '9', 'self_attn', 'k_proj', 'bias'), ('vision_model', 'encoder', 'layers', '16', 'mlp', 'fc2', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'q_proj', 'kernel'), ('vision_model', 'encoder', 'layers', '4', 'self_attn', 'out_proj', 'bias'), ('vision_model', 'encoder', 'layers', '13', 'mlp', 'fc2', 'kernel')}\n", "- This IS expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing FlaxCLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -91,9 +91,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/torch_xla/__init__.py:251: UserWarning: `tensorflow` can conflict with `torch-xla`. Prefer `tensorflow-cpu` when using PyTorch/XLA. To silence this warning, `pip uninstall -y tensorflow && pip install tensorflow-cpu`. If you are in a notebook environment such as Colab or Kaggle, restart your notebook runtime afterwards.\n", + " warnings.warn(\n", + "WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.\n" + ] + } + ], "source": [ "from flax import linen as nn\n", "from diffusers import FlaxUNet2DConditionModel\n", @@ -144,68 +154,7 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# model = BCHWModelWrapper(unet_model)\n", - "params = unet.init(jax.random.PRNGKey(0), jnp.ones((1, IMAGE_SIZE, IMAGE_SIZE, 3)), jnp.ones((1,)), jnp.ones((1, 77, 768)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [ - { - "ename": "XlaRuntimeError", - "evalue": "RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 16.00G. That was not possible. There are 13.93G free.; (0x0x0_HBM0)", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mXlaRuntimeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[8]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m out = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mIMAGE_SIZE\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mIMAGE_SIZE\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mjnp\u001b[49m\u001b[43m.\u001b[49m\u001b[43mones\u001b[49m\u001b[43m(\u001b[49m\u001b[43m(\u001b[49m\u001b[32;43m4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m77\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[32;43m768\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - " \u001b[31m[... skipping hidden 6 frame]\u001b[39m\n", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[5]\u001b[39m\u001b[32m, line 35\u001b[39m, in \u001b[36mBCHWModelWrapper.__call__\u001b[39m\u001b[34m(self, x, temb, textcontext)\u001b[39m\n\u001b[32m 33\u001b[39m x = jnp.transpose(x, (\u001b[32m0\u001b[39m, \u001b[32m3\u001b[39m, \u001b[32m1\u001b[39m, \u001b[32m2\u001b[39m))\n\u001b[32m 34\u001b[39m \u001b[38;5;66;03m# Pass the input through the UNet model\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m35\u001b[39m out = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 36\u001b[39m \u001b[43m \u001b[49m\u001b[43msample\u001b[49m\u001b[43m=\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 37\u001b[39m \u001b[43m \u001b[49m\u001b[43mtimesteps\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtemb\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 38\u001b[39m \u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtextcontext\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 39\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 40\u001b[39m \u001b[38;5;66;03m# Reshape the output back to BHWC format\u001b[39;00m\n\u001b[32m 41\u001b[39m out = jnp.transpose(out.sample, (\u001b[32m0\u001b[39m, \u001b[32m2\u001b[39m, \u001b[32m3\u001b[39m, \u001b[32m1\u001b[39m))\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/unets/unet_2d_condition_flax.py:407\u001b[39m, in \u001b[36mFlaxUNet2DConditionModel.__call__\u001b[39m\u001b[34m(self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, return_dict, train)\u001b[39m\n\u001b[32m 405\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m down_block \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.down_blocks:\n\u001b[32m 406\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(down_block, FlaxCrossAttnDownBlock2D):\n\u001b[32m--> \u001b[39m\u001b[32m407\u001b[39m sample, res_samples = \u001b[43mdown_block\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt_emb\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtrain\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 408\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 409\u001b[39m sample, res_samples = down_block(sample, t_emb, deterministic=\u001b[38;5;129;01mnot\u001b[39;00m train)\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/unets/unet_2d_blocks_flax.py:101\u001b[39m, in \u001b[36mFlaxCrossAttnDownBlock2D.__call__\u001b[39m\u001b[34m(self, hidden_states, temb, encoder_hidden_states, deterministic)\u001b[39m\n\u001b[32m 99\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m resnet, attn \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(\u001b[38;5;28mself\u001b[39m.resnets, \u001b[38;5;28mself\u001b[39m.attentions):\n\u001b[32m 100\u001b[39m hidden_states = resnet(hidden_states, temb, deterministic=deterministic)\n\u001b[32m--> \u001b[39m\u001b[32m101\u001b[39m hidden_states = \u001b[43mattn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoder_hidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 102\u001b[39m output_states += (hidden_states,)\n\u001b[32m 104\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.add_downsample:\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/attention_flax.py:421\u001b[39m, in \u001b[36mFlaxTransformer2DModel.__call__\u001b[39m\u001b[34m(self, hidden_states, context, deterministic)\u001b[39m\n\u001b[32m 418\u001b[39m hidden_states = hidden_states.reshape(batch, height * width, channels)\n\u001b[32m 420\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m transformer_block \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m.transformer_blocks:\n\u001b[32m--> \u001b[39m\u001b[32m421\u001b[39m hidden_states = \u001b[43mtransformer_block\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 423\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.use_linear_projection:\n\u001b[32m 424\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.proj_out(hidden_states)\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/attention_flax.py:312\u001b[39m, in \u001b[36mFlaxBasicTransformerBlock.__call__\u001b[39m\u001b[34m(self, hidden_states, context, deterministic)\u001b[39m\n\u001b[32m 310\u001b[39m hidden_states = \u001b[38;5;28mself\u001b[39m.attn1(\u001b[38;5;28mself\u001b[39m.norm1(hidden_states), context, deterministic=deterministic)\n\u001b[32m 311\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m--> \u001b[39m\u001b[32m312\u001b[39m hidden_states = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mattn1\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnorm1\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhidden_states\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdeterministic\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 313\u001b[39m hidden_states = hidden_states + residual\n\u001b[32m 315\u001b[39m \u001b[38;5;66;03m# cross attention\u001b[39;00m\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/attention_flax.py:228\u001b[39m, in \u001b[36mFlaxAttention.__call__\u001b[39m\u001b[34m(self, hidden_states, context, deterministic)\u001b[39m\n\u001b[32m 225\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 226\u001b[39m attention_scores = jnp.einsum(\u001b[33m\"\u001b[39m\u001b[33mb i d, b j d->b i j\u001b[39m\u001b[33m\"\u001b[39m, query_states, key_states)\n\u001b[32m--> \u001b[39m\u001b[32m228\u001b[39m attention_scores = \u001b[43mattention_scores\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mscale\u001b[49m\n\u001b[32m 229\u001b[39m attention_probs = nn.softmax(attention_scores, axis=-\u001b[32m1\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.split_head_dim \u001b[38;5;28;01melse\u001b[39;00m \u001b[32m2\u001b[39m)\n\u001b[32m 231\u001b[39m \u001b[38;5;66;03m# attend to values\u001b[39;00m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:579\u001b[39m, in \u001b[36m_defer_to_unrecognized_arg..deferring_binary_op\u001b[39m\u001b[34m(self, other)\u001b[39m\n\u001b[32m 577\u001b[39m args = (other, \u001b[38;5;28mself\u001b[39m) \u001b[38;5;28;01mif\u001b[39;00m swap \u001b[38;5;28;01melse\u001b[39;00m (\u001b[38;5;28mself\u001b[39m, other)\n\u001b[32m 578\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(other, _accepted_binop_types):\n\u001b[32m--> \u001b[39m\u001b[32m579\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbinary_op\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 580\u001b[39m \u001b[38;5;66;03m# Note: don't use isinstance here, because we don't want to raise for\u001b[39;00m\n\u001b[32m 581\u001b[39m \u001b[38;5;66;03m# subclasses, e.g. NamedTuple objects that may override operators.\u001b[39;00m\n\u001b[32m 582\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mtype\u001b[39m(other) \u001b[38;5;129;01min\u001b[39;00m _rejected_binop_types:\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/jax/_src/numpy/ufunc_api.py:180\u001b[39m, in \u001b[36mufunc.__call__\u001b[39m\u001b[34m(self, out, where, *args)\u001b[39m\n\u001b[32m 178\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mwhere argument of \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n\u001b[32m 179\u001b[39m call = \u001b[38;5;28mself\u001b[39m.__static_props[\u001b[33m'\u001b[39m\u001b[33mcall\u001b[39m\u001b[33m'\u001b[39m] \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._call_vectorized\n\u001b[32m--> \u001b[39m\u001b[32m180\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n", - " \u001b[31m[... skipping hidden 5 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py:1297\u001b[39m, in \u001b[36mExecuteReplicated.__call__\u001b[39m\u001b[34m(self, *args)\u001b[39m\n\u001b[32m 1295\u001b[39m \u001b[38;5;28mself\u001b[39m._handle_token_bufs(result_token_bufs, sharded_runtime_token)\n\u001b[32m 1296\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1297\u001b[39m results = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mxla_executable\u001b[49m\u001b[43m.\u001b[49m\u001b[43mexecute_sharded\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_bufs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1299\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m dispatch.needs_check_special():\n\u001b[32m 1300\u001b[39m out_arrays = results.disassemble_into_single_device_arrays()\n", - "\u001b[31mXlaRuntimeError\u001b[39m: RESOURCE_EXHAUSTED: Error allocating device buffer: Attempting to allocate 16.00G. That was not possible. There are 13.93G free.; (0x0x0_HBM0)" - ] - } - ], - "source": [ - "out = unet.apply(params, jnp.ones((4,IMAGE_SIZE, IMAGE_SIZE, 3)), jnp.ones((4,)), jnp.ones((4, 77, 768)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -225,7 +174,9 @@ " {\"heads\":8, \"dtype\":jnp.float32, \"flash_attention\":False, \"use_projection\":False, \"use_self_and_cross\":False}\n", " ],\n", " num_res_blocks=2,\n", - " num_middle_res_blocks=1\n", + " num_middle_res_blocks=1,\n", + " dtype=jnp.bfloat16,\n", + " \n", ")" ] }, @@ -257,7 +208,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /home/mrwhite0racle/persist/FlaxDiff/wandb/run-20250410_062234-6p13k0ip" + "Run data is saved locally in /home/mrwhite0racle/persist/FlaxDiff/wandb/run-20250410_152327-lqhfkv5j" ], "text/plain": [ "" @@ -269,7 +220,7 @@ { "data": { "text/html": [ - "Syncing run prototype-2025-04-10_06:22:32 to Weights & Biases (docs)
" + "Syncing run prototype-2025-04-10_15:23:26 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -293,7 +244,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/ashishkumar4/mlops-msml605-project/runs/6p13k0ip" + " View run at https://wandb.ai/ashishkumar4/mlops-msml605-project/runs/lqhfkv5j" ], "text/plain": [ "" @@ -343,5471 +294,10 @@ ")\n" ] }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
                                                                                        BCHWModelWrapper Summary                                                                                        \n",
-       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃ path                                                       module                       inputs                                      outputs                        params                        ┃\n",
-       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│                                                           │ BCHWModelWrapper            │ temb: float32[1]                           │ bfloat16[1,128,128,3]         │                               │\n",
-       "│                                                           │                             │ textcontext: float32[1,77,768]             │                               │                               │\n",
-       "│                                                           │                             │ x: float32[1,128,128,3]                    │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model                                                     │ FlaxUNet2DConditionModel    │ encoder_hidden_states: float32[1,77,768]   │ sample: bfloat16[1,3,128,128] │                               │\n",
-       "│                                                           │                             │ sample: float32[1,3,128,128]               │                               │                               │\n",
-       "│                                                           │                             │ timesteps: float32[1]                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/time_proj                                           │ FlaxTimesteps               │ float32[1]                                 │ float32[1,64]                 │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/time_embedding                                      │ FlaxTimestepEmbedding       │ float32[1,64]                              │ bfloat16[1,256]               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/time_embedding/linear_1                             │ Dense                       │ float32[1,64]                              │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,256]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,640 (66.6 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/time_embedding/linear_2                             │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/conv_in                                             │ Conv                        │ float32[1,128,128,3]                       │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,3,64]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,792 (7.2 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0                                       │ FlaxCrossAttnDownBlock2D    │ - bfloat16[1,128,128,64]                   │ - bfloat16[1,64,64,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │ - - bfloat16[1,128,128,64]    │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │   - bfloat16[1,128,128,64]    │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │   - bfloat16[1,64,64,64]      │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0                             │ FlaxResnetBlock2D           │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/norm1                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/conv1                       │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,64]                │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/norm2                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/dropout                     │ Dropout                     │ - float32[1,128,128,64]                    │ float32[1,128,128,64]         │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_0/conv2                       │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0                          │ FlaxTransformer2DModel      │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/norm                     │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/proj_in                  │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,512]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,280 (133.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,16384,256]                    │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,16384,256]                      │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/proj_out                 │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_0/dropout_layer            │ Dropout                     │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1                             │ FlaxResnetBlock2D           │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/norm1                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/conv1                       │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,64]                │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/norm2                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/dropout                     │ Dropout                     │ - float32[1,128,128,64]                    │ float32[1,128,128,64]         │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/resnets_1/conv2                       │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1                          │ FlaxTransformer2DModel      │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/norm                     │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/proj_in                  │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,512]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,280 (133.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,16384,256]                    │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,16384,256]                      │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/proj_out                 │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/attentions_1/dropout_layer            │ Dropout                     │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/downsamplers_0                        │ FlaxDownsample2D            │ bfloat16[1,128,128,64]                     │ bfloat16[1,64,64,64]          │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_0/downsamplers_0/conv                   │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,64,64,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1                                       │ FlaxCrossAttnDownBlock2D    │ - bfloat16[1,64,64,64]                     │ - bfloat16[1,32,32,128]       │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │ - - bfloat16[1,64,64,128]     │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │   - bfloat16[1,64,64,128]     │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │   - bfloat16[1,32,32,128]     │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0                             │ FlaxResnetBlock2D           │ - bfloat16[1,64,64,64]                     │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/norm1                       │ GroupNorm                   │ bfloat16[1,64,64,64]                       │ float32[1,64,64,64]           │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/conv1                       │ Conv                        │ float32[1,64,64,64]                        │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,128]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 73,856 (295.4 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,128]               │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/norm2                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/dropout                     │ Dropout                     │ - float32[1,64,64,128]                     │ float32[1,64,64,128]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/conv2                       │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_0/conv_shortcut               │ Conv                        │ bfloat16[1,64,64,64]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,128]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 8,320 (33.3 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0                          │ FlaxTransformer2DModel      │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/norm                     │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/proj_in                  │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,1024]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 132,096 (528.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,4096,512]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,4096,512]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,664 (262.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/proj_out                 │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_0/dropout_layer            │ Dropout                     │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1                             │ FlaxResnetBlock2D           │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/norm1                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/conv1                       │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,128]               │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/norm2                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/dropout                     │ Dropout                     │ - float32[1,64,64,128]                     │ float32[1,64,64,128]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/resnets_1/conv2                       │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1                          │ FlaxTransformer2DModel      │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/norm                     │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/proj_in                  │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,1024]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 132,096 (528.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,4096,512]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,4096,512]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,664 (262.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/proj_out                 │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/attentions_1/dropout_layer            │ Dropout                     │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/downsamplers_0                        │ FlaxDownsample2D            │ bfloat16[1,64,64,128]                      │ bfloat16[1,32,32,128]         │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_1/downsamplers_0/conv                   │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,32,32,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2                                       │ FlaxCrossAttnDownBlock2D    │ - bfloat16[1,32,32,128]                    │ - bfloat16[1,16,16,256]       │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │ - - bfloat16[1,32,32,256]     │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │   - bfloat16[1,32,32,256]     │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │   - bfloat16[1,16,16,256]     │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0                             │ FlaxResnetBlock2D           │ - bfloat16[1,32,32,128]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/norm1                       │ GroupNorm                   │ bfloat16[1,32,32,128]                      │ float32[1,32,32,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/conv1                       │ Conv                        │ float32[1,32,32,128]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 295,168 (1.2 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/norm2                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/dropout                     │ Dropout                     │ - float32[1,32,32,256]                     │ float32[1,32,32,256]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/conv2                       │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_0/conv_shortcut               │ Conv                        │ bfloat16[1,32,32,128]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,024 (132.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0                          │ FlaxTransformer2DModel      │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/norm                     │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/proj_in                  │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,2048]         │ bias: float32[2048]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,2048]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 526,336 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,1024,1024]                    │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,1024,1024]                      │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1024,256]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,400 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/proj_out                 │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_0/dropout_layer            │ Dropout                     │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1                             │ FlaxResnetBlock2D           │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/norm1                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/conv1                       │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/norm2                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/dropout                     │ Dropout                     │ - float32[1,32,32,256]                     │ float32[1,32,32,256]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/resnets_1/conv2                       │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1                          │ FlaxTransformer2DModel      │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/norm                     │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/proj_in                  │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0     │ FlaxBasicTransformerBlock   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff  │ FlaxFeedForward             │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU                   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,2048]         │ bias: float32[2048]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,2048]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 526,336 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dropout                     │ - bfloat16[1,1024,1024]                    │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dense                       │ bfloat16[1,1024,1024]                      │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1024,256]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,400 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/transformer_blocks_0/dr… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/proj_out                 │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/attentions_1/dropout_layer            │ Dropout                     │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/downsamplers_0                        │ FlaxDownsample2D            │ bfloat16[1,32,32,256]                      │ bfloat16[1,16,16,256]         │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_2/downsamplers_0/conv                   │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,16,16,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3                                       │ FlaxDownBlock2D             │ - bfloat16[1,16,16,256]                    │ - bfloat16[1,16,16,512]       │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │ - - bfloat16[1,16,16,512]     │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │   - bfloat16[1,16,16,512]     │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0                             │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,256]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/norm1                       │ GroupNorm                   │ bfloat16[1,16,16,256]                      │ float32[1,16,16,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/conv1                       │ Conv                        │ float32[1,16,16,256]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,180,160 (4.7 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/norm2                       │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/dropout                     │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/conv2                       │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_0/conv_shortcut               │ Conv                        │ bfloat16[1,16,16,256]                      │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1                             │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/norm1                       │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/conv1                       │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/time_emb_proj               │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/norm2                       │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/dropout                     │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/down_blocks_3/resnets_1/conv2                       │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block                                           │ FlaxUNetMidBlock2DCrossAttn │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0                                 │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/norm1                           │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/conv1                           │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/time_emb_proj                   │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/norm2                           │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/dropout                         │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_0/conv2                           │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0                              │ FlaxTransformer2DModel      │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/norm                         │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/proj_in                      │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,656 (1.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0         │ FlaxBasicTransformerBlock   │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/norm1   │ LayerNorm                   │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1   │ FlaxAttention               │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,144 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,144 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,144 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,656 (1.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dropout                     │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/norm2   │ LayerNorm                   │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2   │ FlaxAttention               │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,144 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,512]            │ kernel: float32[768,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 393,216 (1.6 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,512]            │ kernel: float32[768,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 393,216 (1.6 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,656 (1.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dropout                     │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/norm3   │ LayerNorm                   │ bfloat16[1,256,512]                        │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/ff      │ FlaxFeedForward             │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ FlaxGEGLU                   │ - bfloat16[1,256,512]                      │ bfloat16[1,256,2048]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dense                       │ bfloat16[1,256,512]                        │ bfloat16[1,256,4096]          │ bias: float32[4096]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,4096]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,101,248 (8.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dropout                     │ - bfloat16[1,256,2048]                     │ bfloat16[1,256,2048]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dense                       │ bfloat16[1,256,2048]                       │ bfloat16[1,256,512]           │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[2048,512]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,049,088 (4.2 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/transformer_blocks_0/dropou… │ Dropout                     │ - bfloat16[1,256,512]                      │ bfloat16[1,256,512]           │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/proj_out                     │ Conv                        │ bfloat16[1,16,16,512]                      │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,656 (1.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/attentions_0/dropout_layer                │ Dropout                     │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1                                 │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,512]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/norm1                           │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/conv1                           │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/time_emb_proj                   │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/norm2                           │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/dropout                         │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/mid_block/resnets_1/conv2                           │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0                                         │ FlaxUpBlock2D               │ - bfloat16[1,16,16,512]                    │ bfloat16[1,32,32,512]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "│                                                           │                             │   res_hidden_states_tuple:                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,16,16,256]                  │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,16,16,512]                  │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,16,16,512]                  │                               │                               │\n",
-       "│                                                           │                             │   temb: bfloat16[1,256]                    │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0                               │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,1024]                   │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/norm1                         │ GroupNorm                   │ bfloat16[1,16,16,1024]                     │ float32[1,16,16,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[1024]          │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,048 (8.2 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/conv1                         │ Conv                        │ float32[1,16,16,1024]                      │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,1024,512] │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,719,104 (18.9 MB)           │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/norm2                         │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/dropout                       │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/conv2                         │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_0/conv_shortcut                 │ Conv                        │ bfloat16[1,16,16,1024]                     │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,1024,512] │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 524,800 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1                               │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,1024]                   │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/norm1                         │ GroupNorm                   │ bfloat16[1,16,16,1024]                     │ float32[1,16,16,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[1024]          │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,048 (8.2 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/conv1                         │ Conv                        │ float32[1,16,16,1024]                      │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,1024,512] │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,719,104 (18.9 MB)           │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/norm2                         │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/dropout                       │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/conv2                         │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_1/conv_shortcut                 │ Conv                        │ bfloat16[1,16,16,1024]                     │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,1024,512] │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 524,800 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2                               │ FlaxResnetBlock2D           │ - bfloat16[1,16,16,768]                    │ bfloat16[1,16,16,512]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/norm1                         │ GroupNorm                   │ bfloat16[1,16,16,768]                      │ float32[1,16,16,768]          │ bias: float32[768]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[768]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,536 (6.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/conv1                         │ Conv                        │ float32[1,16,16,768]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,768,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 3,539,456 (14.2 MB)           │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,512]               │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,512]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,584 (526.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/norm2                         │ GroupNorm                   │ bfloat16[1,16,16,512]                      │ float32[1,16,16,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/dropout                       │ Dropout                     │ - float32[1,16,16,512]                     │ float32[1,16,16,512]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/conv2                         │ Conv                        │ float32[1,16,16,512]                       │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/resnets_2/conv_shortcut                 │ Conv                        │ bfloat16[1,16,16,768]                      │ bfloat16[1,16,16,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,768,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 393,728 (1.6 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/upsamplers_0                            │ FlaxUpsample2D              │ bfloat16[1,16,16,512]                      │ bfloat16[1,32,32,512]         │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_0/upsamplers_0/conv                       │ Conv                        │ bfloat16[1,32,32,512]                      │ bfloat16[1,32,32,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,512]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 2,359,808 (9.4 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1                                         │ FlaxCrossAttnUpBlock2D      │ - bfloat16[1,32,32,512]                    │ bfloat16[1,64,64,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "│                                                           │                             │   encoder_hidden_states: float32[1,77,768] │                               │                               │\n",
-       "│                                                           │                             │   res_hidden_states_tuple:                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,32,32,128]                  │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,32,32,256]                  │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,32,32,256]                  │                               │                               │\n",
-       "│                                                           │                             │   temb: bfloat16[1,256]                    │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0                               │ FlaxResnetBlock2D           │ - bfloat16[1,32,32,768]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/norm1                         │ GroupNorm                   │ bfloat16[1,32,32,768]                      │ float32[1,32,32,768]          │ bias: float32[768]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[768]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,536 (6.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/conv1                         │ Conv                        │ float32[1,32,32,768]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,768,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,769,728 (7.1 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/norm2                         │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/dropout                       │ Dropout                     │ - float32[1,32,32,256]                     │ float32[1,32,32,256]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/conv2                         │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_0/conv_shortcut                 │ Conv                        │ bfloat16[1,32,32,768]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,768,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,864 (787.5 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0                            │ FlaxTransformer2DModel      │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/norm                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/proj_in                    │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,2048]         │ bias: float32[2048]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,2048]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 526,336 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,1024,1024]                    │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,1024]                      │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1024,256]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,400 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/proj_out                   │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_0/dropout_layer              │ Dropout                     │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1                               │ FlaxResnetBlock2D           │ - bfloat16[1,32,32,512]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/norm1                         │ GroupNorm                   │ bfloat16[1,32,32,512]                      │ float32[1,32,32,512]          │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[512]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,024 (4.1 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/conv1                         │ Conv                        │ float32[1,32,32,512]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,512,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,179,904 (4.7 MB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/norm2                         │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/dropout                       │ Dropout                     │ - float32[1,32,32,256]                     │ float32[1,32,32,256]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/conv2                         │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_1/conv_shortcut                 │ Conv                        │ bfloat16[1,32,32,512]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,512,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 131,328 (525.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1                            │ FlaxTransformer2DModel      │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/norm                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/proj_in                    │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,2048]         │ bias: float32[2048]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,2048]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 526,336 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,1024,1024]                    │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,1024]                      │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1024,256]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,400 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/proj_out                   │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_1/dropout_layer              │ Dropout                     │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2                               │ FlaxResnetBlock2D           │ - bfloat16[1,32,32,384]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/norm1                         │ GroupNorm                   │ bfloat16[1,32,32,384]                      │ float32[1,32,32,384]          │ bias: float32[384]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[384]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 768 (3.1 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/conv1                         │ Conv                        │ float32[1,32,32,384]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,384,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 884,992 (3.5 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,256]               │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/norm2                         │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/dropout                       │ Dropout                     │ - float32[1,32,32,256]                     │ float32[1,32,32,256]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/conv2                         │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/resnets_2/conv_shortcut                 │ Conv                        │ bfloat16[1,32,32,384]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,384,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,560 (394.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2                            │ FlaxTransformer2DModel      │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/norm                       │ GroupNorm                   │ bfloat16[1,32,32,256]                      │ float32[1,32,32,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/proj_in                    │ Conv                        │ float32[1,32,32,256]                       │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,536 (262.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,256]            │ kernel: float32[768,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 196,608 (786.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,256]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,256]                       │ bfloat16[1,1024,2048]         │ bias: float32[2048]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,2048]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 526,336 (2.1 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,1024,1024]                    │ bfloat16[1,1024,1024]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,1024,1024]                      │ bfloat16[1,1024,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1024,256]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 262,400 (1.0 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,1024,256]                     │ bfloat16[1,1024,256]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/proj_out                   │ Conv                        │ bfloat16[1,32,32,256]                      │ bfloat16[1,32,32,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,792 (263.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/attentions_2/dropout_layer              │ Dropout                     │ - bfloat16[1,32,32,256]                    │ bfloat16[1,32,32,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/upsamplers_0                            │ FlaxUpsample2D              │ bfloat16[1,32,32,256]                      │ bfloat16[1,64,64,256]         │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_1/upsamplers_0/conv                       │ Conv                        │ bfloat16[1,64,64,256]                      │ bfloat16[1,64,64,256]         │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,256]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 590,080 (2.4 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2                                         │ FlaxCrossAttnUpBlock2D      │ - bfloat16[1,64,64,256]                    │ bfloat16[1,128,128,128]       │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "│                                                           │                             │   encoder_hidden_states: float32[1,77,768] │                               │                               │\n",
-       "│                                                           │                             │   res_hidden_states_tuple:                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,64,64,64]                   │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,64,64,128]                  │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,64,64,128]                  │                               │                               │\n",
-       "│                                                           │                             │   temb: bfloat16[1,256]                    │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0                               │ FlaxResnetBlock2D           │ - bfloat16[1,64,64,384]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/norm1                         │ GroupNorm                   │ bfloat16[1,64,64,384]                      │ float32[1,64,64,384]          │ bias: float32[384]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[384]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 768 (3.1 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/conv1                         │ Conv                        │ float32[1,64,64,384]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,384,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 442,496 (1.8 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,128]               │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/norm2                         │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/dropout                       │ Dropout                     │ - float32[1,64,64,128]                     │ float32[1,64,64,128]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/conv2                         │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_0/conv_shortcut                 │ Conv                        │ bfloat16[1,64,64,384]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,384,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,280 (197.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0                            │ FlaxTransformer2DModel      │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/norm                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/proj_in                    │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,1024]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 132,096 (528.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,4096,512]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,512]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,664 (262.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/proj_out                   │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_0/dropout_layer              │ Dropout                     │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1                               │ FlaxResnetBlock2D           │ - bfloat16[1,64,64,256]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/norm1                         │ GroupNorm                   │ bfloat16[1,64,64,256]                      │ float32[1,64,64,256]          │ bias: float32[256]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[256]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 512 (2.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/conv1                         │ Conv                        │ float32[1,64,64,256]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,256,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 295,040 (1.2 MB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,128]               │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/norm2                         │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/dropout                       │ Dropout                     │ - float32[1,64,64,128]                     │ float32[1,64,64,128]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/conv2                         │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_1/conv_shortcut                 │ Conv                        │ bfloat16[1,64,64,256]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,256,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1                            │ FlaxTransformer2DModel      │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/norm                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/proj_in                    │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,1024]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 132,096 (528.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,4096,512]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,512]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,664 (262.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/proj_out                   │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_1/dropout_layer              │ Dropout                     │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2                               │ FlaxResnetBlock2D           │ - bfloat16[1,64,64,192]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/norm1                         │ GroupNorm                   │ bfloat16[1,64,64,192]                      │ float32[1,64,64,192]          │ bias: float32[192]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[192]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 384 (1.5 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/conv1                         │ Conv                        │ float32[1,64,64,192]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,192,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 221,312 (885.2 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,128]               │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 32,896 (131.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/norm2                         │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/dropout                       │ Dropout                     │ - float32[1,64,64,128]                     │ float32[1,64,64,128]          │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/conv2                         │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/resnets_2/conv_shortcut                 │ Conv                        │ bfloat16[1,64,64,192]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,192,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 24,704 (98.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2                            │ FlaxTransformer2DModel      │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/norm                       │ GroupNorm                   │ bfloat16[1,64,64,128]                      │ float32[1,64,64,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/proj_in                    │ Conv                        │ float32[1,64,64,128]                       │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,384 (65.5 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,128]            │ kernel: float32[768,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 98,304 (393.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,128]                       │ bfloat16[1,4096,1024]         │ bias: float32[1024]           │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[128,1024]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 132,096 (528.4 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,4096,512]                     │ bfloat16[1,4096,512]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,4096,512]                       │ bfloat16[1,4096,128]          │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[512,128]      │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 65,664 (262.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,4096,128]                     │ bfloat16[1,4096,128]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/proj_out                   │ Conv                        │ bfloat16[1,64,64,128]                      │ bfloat16[1,64,64,128]         │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,512 (66.0 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/attentions_2/dropout_layer              │ Dropout                     │ - bfloat16[1,64,64,128]                    │ bfloat16[1,64,64,128]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/upsamplers_0                            │ FlaxUpsample2D              │ bfloat16[1,64,64,128]                      │ bfloat16[1,128,128,128]       │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_2/upsamplers_0/conv                       │ Conv                        │ bfloat16[1,128,128,128]                    │ bfloat16[1,128,128,128]       │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,128]  │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 147,584 (590.3 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3                                         │ FlaxCrossAttnUpBlock2D      │ - bfloat16[1,128,128,128]                  │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "│                                                           │                             │   encoder_hidden_states: float32[1,77,768] │                               │                               │\n",
-       "│                                                           │                             │   res_hidden_states_tuple:                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,128,128,64]                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,128,128,64]                 │                               │                               │\n",
-       "│                                                           │                             │   - bfloat16[1,128,128,64]                 │                               │                               │\n",
-       "│                                                           │                             │   temb: bfloat16[1,256]                    │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0                               │ FlaxResnetBlock2D           │ - bfloat16[1,128,128,192]                  │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/norm1                         │ GroupNorm                   │ bfloat16[1,128,128,192]                    │ float32[1,128,128,192]        │ bias: float32[192]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[192]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 384 (1.5 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/conv1                         │ Conv                        │ float32[1,128,128,192]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,192,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 110,656 (442.6 KB)            │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,64]                │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/norm2                         │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/dropout                       │ Dropout                     │ - float32[1,128,128,64]                    │ float32[1,128,128,64]         │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/conv2                         │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_0/conv_shortcut                 │ Conv                        │ bfloat16[1,128,128,192]                    │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,192,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 12,352 (49.4 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0                            │ FlaxTransformer2DModel      │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/norm                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/proj_in                    │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,512]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,280 (133.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,16384,256]                    │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,256]                      │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/proj_out                   │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_0/dropout_layer              │ Dropout                     │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1                               │ FlaxResnetBlock2D           │ - bfloat16[1,128,128,128]                  │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/norm1                         │ GroupNorm                   │ bfloat16[1,128,128,128]                    │ float32[1,128,128,128]        │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/conv1                         │ Conv                        │ float32[1,128,128,128]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 73,792 (295.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,64]                │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/norm2                         │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/dropout                       │ Dropout                     │ - float32[1,128,128,64]                    │ float32[1,128,128,64]         │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/conv2                         │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_1/conv_shortcut                 │ Conv                        │ bfloat16[1,128,128,128]                    │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 8,256 (33.0 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1                            │ FlaxTransformer2DModel      │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/norm                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/proj_in                    │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,512]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,280 (133.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,16384,256]                    │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,256]                      │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/proj_out                   │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_1/dropout_layer              │ Dropout                     │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2                               │ FlaxResnetBlock2D           │ - bfloat16[1,128,128,128]                  │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - bfloat16[1,256]                          │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/norm1                         │ GroupNorm                   │ bfloat16[1,128,128,128]                    │ float32[1,128,128,128]        │ bias: float32[128]            │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[128]           │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 256 (1.0 KB)                  │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/conv1                         │ Conv                        │ float32[1,128,128,128]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,128,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 73,792 (295.2 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/time_emb_proj                 │ Dense                       │ bfloat16[1,256]                            │ bfloat16[1,64]                │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/norm2                         │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/dropout                       │ Dropout                     │ - float32[1,128,128,64]                    │ float32[1,128,128,64]         │                               │\n",
-       "│                                                           │                             │ - True                                     │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/conv2                         │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 36,928 (147.7 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/resnets_2/conv_shortcut                 │ Conv                        │ bfloat16[1,128,128,128]                    │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,128,64]   │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 8,256 (33.0 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2                            │ FlaxTransformer2DModel      │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/norm                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/proj_in                    │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0       │ FlaxBasicTransformerBlock   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm1 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm2 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention               │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - float32[1,77,768]                        │                               │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,096 (16.4 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ float32[1,77,768]                          │ bfloat16[1,77,64]             │ kernel: float32[768,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 49,152 (196.6 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,64]        │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm3 │ LayerNorm                   │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff    │ FlaxFeedForward             │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU                   │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,64]                       │ bfloat16[1,16384,512]         │ bias: float32[512]            │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[64,512]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 33,280 (133.1 KB)             │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dropout                     │ - bfloat16[1,16384,256]                    │ bfloat16[1,16384,256]         │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dense                       │ bfloat16[1,16384,256]                      │ bfloat16[1,16384,64]          │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[256,64]       │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 16,448 (65.8 KB)              │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/transformer_blocks_0/drop… │ Dropout                     │ - bfloat16[1,16384,64]                     │ bfloat16[1,16384,64]          │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/proj_out                   │ Conv                        │ bfloat16[1,128,128,64]                     │ bfloat16[1,128,128,64]        │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[1,1,64,64]    │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 4,160 (16.6 KB)               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/up_blocks_3/attentions_2/dropout_layer              │ Dropout                     │ - bfloat16[1,128,128,64]                   │ bfloat16[1,128,128,64]        │                               │\n",
-       "│                                                           │                             │ - deterministic: True                      │                               │                               │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/conv_norm_out                                       │ GroupNorm                   │ bfloat16[1,128,128,64]                     │ float32[1,128,128,64]         │ bias: float32[64]             │\n",
-       "│                                                           │                             │                                            │                               │ scale: float32[64]            │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 128 (512 B)                   │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│ model/conv_out                                            │ Conv                        │ float32[1,128,128,64]                      │ bfloat16[1,128,128,3]         │ bias: float32[3]              │\n",
-       "│                                                           │                             │                                            │                               │ kernel: float32[3,3,64,3]     │\n",
-       "│                                                           │                             │                                            │                               │                               │\n",
-       "│                                                           │                             │                                            │                               │ 1,731 (6.9 KB)                │\n",
-       "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n",
-       "│                                                                                                                                                             Total  73,652,291 (294.6 MB)         │\n",
-       "└───────────────────────────────────────────────────────────┴─────────────────────────────┴────────────────────────────────────────────┴───────────────────────────────┴───────────────────────────────┘\n",
-       "                                                                                                                                                                                                        \n",
-       "                                                                                Total Parameters: 73,652,291 (294.6 MB)                                                                                 \n",
-       "
\n" - ], - "text/plain": [ - "\u001b[3m BCHWModelWrapper Summary \u001b[0m\n", - "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mpath \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmodule \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1minputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1moutputs \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mparams \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", - "│ │ BCHWModelWrapper │ temb: \u001b[2mfloat32\u001b[0m[1] │ \u001b[2mbfloat16\u001b[0m[1,128,128,3] │ │\n", - "│ │ │ textcontext: \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ x: \u001b[2mfloat32\u001b[0m[1,128,128,3] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model │ FlaxUNet2DConditionModel │ encoder_hidden_states: \u001b[2mfloat32\u001b[0m[1,77,768] │ sample: \u001b[2mbfloat16\u001b[0m[1,3,128,128] │ │\n", - "│ │ │ sample: \u001b[2mfloat32\u001b[0m[1,3,128,128] │ │ │\n", - "│ │ │ timesteps: \u001b[2mfloat32\u001b[0m[1] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/time_proj │ FlaxTimesteps │ \u001b[2mfloat32\u001b[0m[1] │ \u001b[2mfloat32\u001b[0m[1,64] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/time_embedding │ FlaxTimestepEmbedding │ \u001b[2mfloat32\u001b[0m[1,64] │ \u001b[2mbfloat16\u001b[0m[1,256] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/time_embedding/linear_1 │ Dense │ \u001b[2mfloat32\u001b[0m[1,64] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,640 \u001b[0m\u001b[1;2m(66.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/time_embedding/linear_2 │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/conv_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,3] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,3,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,792 \u001b[0m\u001b[1;2m(7.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0 │ FlaxCrossAttnDownBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ - \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ - - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ - \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,280 \u001b[0m\u001b[1;2m(133.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,280 \u001b[0m\u001b[1;2m(133.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/downsamplers_0 │ FlaxDownsample2D │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_0/downsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1 │ FlaxCrossAttnDownBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ - \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ - - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ - \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ \u001b[2mfloat32\u001b[0m[1,64,64,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,64] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m73,856 \u001b[0m\u001b[1;2m(295.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m8,320 \u001b[0m\u001b[1;2m(33.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m132,096 \u001b[0m\u001b[1;2m(528.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,664 \u001b[0m\u001b[1;2m(262.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m132,096 \u001b[0m\u001b[1;2m(528.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,664 \u001b[0m\u001b[1;2m(262.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/downsamplers_0 │ FlaxDownsample2D │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_1/downsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2 │ FlaxCrossAttnDownBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ - \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ - - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ - \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ \u001b[2mfloat32\u001b[0m[1,32,32,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,128] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m295,168 \u001b[0m\u001b[1;2m(1.2 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,024 \u001b[0m\u001b[1;2m(132.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,2048] │ bias: \u001b[2mfloat32\u001b[0m[2048] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,2048] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m526,336 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1024,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,400 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/at… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/no… │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,2048] │ bias: \u001b[2mfloat32\u001b[0m[2048] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,2048] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m526,336 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/ff… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1024,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,400 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/transformer_blocks_0/dr… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/downsamplers_0 │ FlaxDownsample2D │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_2/downsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3 │ FlaxDownBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ - - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - deterministic: True │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ \u001b[2mfloat32\u001b[0m[1,16,16,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,256] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,180,160 \u001b[0m\u001b[1;2m(4.7 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/down_blocks_3/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block │ FlaxUNetMidBlock2DCrossAttn │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,656 \u001b[0m\u001b[1;2m(1.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,144 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,144 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,144 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,656 \u001b[0m\u001b[1;2m(1.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn1/… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,144 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,512] │ kernel: \u001b[2mfloat32\u001b[0m[768,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m393,216 \u001b[0m\u001b[1;2m(1.6 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,512] │ kernel: \u001b[2mfloat32\u001b[0m[768,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m393,216 \u001b[0m\u001b[1;2m(1.6 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,656 \u001b[0m\u001b[1;2m(1.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/attn2/… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,2048] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,4096] │ bias: \u001b[2mfloat32\u001b[0m[4096] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,4096] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,101,248 \u001b[0m\u001b[1;2m(8.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,256,2048] │ \u001b[2mbfloat16\u001b[0m[1,256,2048] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/ff/net… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256,2048] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[2048,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,049,088 \u001b[0m\u001b[1;2m(4.2 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/transformer_blocks_0/dropou… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,256,512] │ \u001b[2mbfloat16\u001b[0m[1,256,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,656 \u001b[0m\u001b[1;2m(1.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/mid_block/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0 │ FlaxUpBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "│ │ │ res_hidden_states_tuple: │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,16,16,256] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │ │\n", - "│ │ │ temb: \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mfloat32\u001b[0m[1,16,16,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,048 \u001b[0m\u001b[1;2m(8.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,1024,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,719,104 \u001b[0m\u001b[1;2m(18.9 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,1024,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m524,800 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mfloat32\u001b[0m[1,16,16,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,048 \u001b[0m\u001b[1;2m(8.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,1024,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,719,104 \u001b[0m\u001b[1;2m(18.9 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_1/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,16,16,1024] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,1024,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m524,800 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,16,16,768] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,768] │ \u001b[2mfloat32\u001b[0m[1,16,16,768] │ bias: \u001b[2mfloat32\u001b[0m[768] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[768] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,536 \u001b[0m\u001b[1;2m(6.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,768] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,768,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m3,539,456 \u001b[0m\u001b[1;2m(14.2 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,584 \u001b[0m\u001b[1;2m(526.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/resnets_2/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,16,16,768] │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,768,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m393,728 \u001b[0m\u001b[1;2m(1.6 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/upsamplers_0 │ FlaxUpsample2D │ \u001b[2mbfloat16\u001b[0m[1,16,16,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_0/upsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m2,359,808 \u001b[0m\u001b[1;2m(9.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1 │ FlaxCrossAttnUpBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "│ │ │ encoder_hidden_states: \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ res_hidden_states_tuple: │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,32,32,128] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │ │\n", - "│ │ │ temb: \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,768] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,768] │ \u001b[2mfloat32\u001b[0m[1,32,32,768] │ bias: \u001b[2mfloat32\u001b[0m[768] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[768] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,536 \u001b[0m\u001b[1;2m(6.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,768] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,769,728 \u001b[0m\u001b[1;2m(7.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,768] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,864 \u001b[0m\u001b[1;2m(787.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,2048] │ bias: \u001b[2mfloat32\u001b[0m[2048] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,2048] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m526,336 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1024,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,400 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ \u001b[2mfloat32\u001b[0m[1,32,32,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,024 \u001b[0m\u001b[1;2m(4.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,512,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,179,904 \u001b[0m\u001b[1;2m(4.7 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_1/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,512] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,512,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m131,328 \u001b[0m\u001b[1;2m(525.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,2048] │ bias: \u001b[2mfloat32\u001b[0m[2048] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,2048] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m526,336 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1024,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,400 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,32,32,384] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,384] │ \u001b[2mfloat32\u001b[0m[1,32,32,384] │ bias: \u001b[2mfloat32\u001b[0m[384] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[384] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m768 \u001b[0m\u001b[1;2m(3.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,384] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,384,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m884,992 \u001b[0m\u001b[1;2m(3.5 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/resnets_2/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,384] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,384,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,560 \u001b[0m\u001b[1;2m(394.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,536 \u001b[0m\u001b[1;2m(262.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,256] │ kernel: \u001b[2mfloat32\u001b[0m[768,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m196,608 \u001b[0m\u001b[1;2m(786.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,2048] │ bias: \u001b[2mfloat32\u001b[0m[2048] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,2048] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m526,336 \u001b[0m\u001b[1;2m(2.1 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,1024,1024] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1024,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m262,400 \u001b[0m\u001b[1;2m(1.0 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,1024,256] │ \u001b[2mbfloat16\u001b[0m[1,1024,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,792 \u001b[0m\u001b[1;2m(263.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/attentions_2/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/upsamplers_0 │ FlaxUpsample2D │ \u001b[2mbfloat16\u001b[0m[1,32,32,256] │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_1/upsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m590,080 \u001b[0m\u001b[1;2m(2.4 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2 │ FlaxCrossAttnUpBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "│ │ │ encoder_hidden_states: \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ res_hidden_states_tuple: │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,64,64,64] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │ │\n", - "│ │ │ temb: \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,384] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,384] │ \u001b[2mfloat32\u001b[0m[1,64,64,384] │ bias: \u001b[2mfloat32\u001b[0m[384] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[384] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m768 \u001b[0m\u001b[1;2m(3.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,384] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,384,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m442,496 \u001b[0m\u001b[1;2m(1.8 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,384] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,384,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,280 \u001b[0m\u001b[1;2m(197.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m132,096 \u001b[0m\u001b[1;2m(528.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,664 \u001b[0m\u001b[1;2m(262.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ \u001b[2mfloat32\u001b[0m[1,64,64,256] │ bias: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[256] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m512 \u001b[0m\u001b[1;2m(2.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,256] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m295,040 \u001b[0m\u001b[1;2m(1.2 MB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_1/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,256] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m132,096 \u001b[0m\u001b[1;2m(528.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,664 \u001b[0m\u001b[1;2m(262.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,64,64,192] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,192] │ \u001b[2mfloat32\u001b[0m[1,64,64,192] │ bias: \u001b[2mfloat32\u001b[0m[192] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[192] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m384 \u001b[0m\u001b[1;2m(1.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,192] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,192,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m221,312 \u001b[0m\u001b[1;2m(885.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m32,896 \u001b[0m\u001b[1;2m(131.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/resnets_2/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,192] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,192,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m24,704 \u001b[0m\u001b[1;2m(98.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,384 \u001b[0m\u001b[1;2m(65.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,128] │ kernel: \u001b[2mfloat32\u001b[0m[768,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m98,304 \u001b[0m\u001b[1;2m(393.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,1024] │ bias: \u001b[2mfloat32\u001b[0m[1024] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[128,1024] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m132,096 \u001b[0m\u001b[1;2m(528.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,4096,512] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[512,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m65,664 \u001b[0m\u001b[1;2m(262.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,4096,128] │ \u001b[2mbfloat16\u001b[0m[1,4096,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,512 \u001b[0m\u001b[1;2m(66.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/attentions_2/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/upsamplers_0 │ FlaxUpsample2D │ \u001b[2mbfloat16\u001b[0m[1,64,64,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_2/upsamplers_0/conv │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m147,584 \u001b[0m\u001b[1;2m(590.3 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3 │ FlaxCrossAttnUpBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "│ │ │ encoder_hidden_states: \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ res_hidden_states_tuple: │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │ │\n", - "│ │ │ temb: \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,192] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,192] │ \u001b[2mfloat32\u001b[0m[1,128,128,192] │ bias: \u001b[2mfloat32\u001b[0m[192] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[192] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m384 \u001b[0m\u001b[1;2m(1.5 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,192] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,192,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m110,656 \u001b[0m\u001b[1;2m(442.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_0/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,192] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,192,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m12,352 \u001b[0m\u001b[1;2m(49.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,280 \u001b[0m\u001b[1;2m(133.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_0/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mfloat32\u001b[0m[1,128,128,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m73,792 \u001b[0m\u001b[1;2m(295.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_1/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m8,256 \u001b[0m\u001b[1;2m(33.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,280 \u001b[0m\u001b[1;2m(133.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_1/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2 │ FlaxResnetBlock2D │ - \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mbfloat16\u001b[0m[1,256] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/norm1 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mfloat32\u001b[0m[1,128,128,128] │ bias: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[128] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m256 \u001b[0m\u001b[1;2m(1.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/conv1 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,128,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m73,792 \u001b[0m\u001b[1;2m(295.2 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/time_emb_proj │ Dense │ \u001b[2mbfloat16\u001b[0m[1,256] │ \u001b[2mbfloat16\u001b[0m[1,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/norm2 │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/dropout │ Dropout │ - \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/conv2 │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m36,928 \u001b[0m\u001b[1;2m(147.7 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/resnets_2/conv_shortcut │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,128] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,128,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m8,256 \u001b[0m\u001b[1;2m(33.0 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2 │ FlaxTransformer2DModel │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/norm │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/proj_in │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0 │ FlaxBasicTransformerBlock │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm1 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn1 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm2 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn2 │ FlaxAttention │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - \u001b[2mfloat32\u001b[0m[1,77,768] │ │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,096 \u001b[0m\u001b[1;2m(16.4 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mfloat32\u001b[0m[1,77,768] │ \u001b[2mbfloat16\u001b[0m[1,77,64] │ kernel: \u001b[2mfloat32\u001b[0m[768,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m49,152 \u001b[0m\u001b[1;2m(196.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/attn… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/norm3 │ LayerNorm │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff │ FlaxFeedForward │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ FlaxGEGLU │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,512] │ bias: \u001b[2mfloat32\u001b[0m[512] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[64,512] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m33,280 \u001b[0m\u001b[1;2m(133.1 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/ff/n… │ Dense │ \u001b[2mbfloat16\u001b[0m[1,16384,256] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[256,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m16,448 \u001b[0m\u001b[1;2m(65.8 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/transformer_blocks_0/drop… │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,16384,64] │ \u001b[2mbfloat16\u001b[0m[1,16384,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/proj_out │ Conv │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[1,1,64,64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m4,160 \u001b[0m\u001b[1;2m(16.6 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/up_blocks_3/attentions_2/dropout_layer │ Dropout │ - \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ │\n", - "│ │ │ - deterministic: True │ │ │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/conv_norm_out │ GroupNorm │ \u001b[2mbfloat16\u001b[0m[1,128,128,64] │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ bias: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ scale: \u001b[2mfloat32\u001b[0m[64] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m128 \u001b[0m\u001b[1;2m(512 B)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│ model/conv_out │ Conv │ \u001b[2mfloat32\u001b[0m[1,128,128,64] │ \u001b[2mbfloat16\u001b[0m[1,128,128,3] │ bias: \u001b[2mfloat32\u001b[0m[3] │\n", - "│ │ │ │ │ kernel: \u001b[2mfloat32\u001b[0m[3,3,64,3] │\n", - "│ │ │ │ │ │\n", - "│ │ │ │ │ \u001b[1m1,731 \u001b[0m\u001b[1;2m(6.9 KB)\u001b[0m │\n", - "├───────────────────────────────────────────────────────────┼─────────────────────────────┼────────────────────────────────────────────┼───────────────────────────────┼───────────────────────────────┤\n", - "│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m Total\u001b[0m\u001b[1m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m73,652,291 \u001b[0m\u001b[1;2m(294.6 MB)\u001b[0m\u001b[1m \u001b[0m\u001b[1m \u001b[0m│\n", - "└───────────────────────────────────────────────────────────┴─────────────────────────────┴────────────────────────────────────────────┴───────────────────────────────┴───────────────────────────────┘\n", - "\u001b[1m \u001b[0m\n", - "\u001b[1m Total Parameters: 73,652,291 \u001b[0m\u001b[1;2m(294.6 MB)\u001b[0m\u001b[1m \u001b[0m\n" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n", - "\n" - ] - } - ], - "source": [ - "trainer.summary()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ones = trainer.get_input_ones()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "out = trainer.model.apply(\n", - " trainer.state.params,\n", - " **ones,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'FlaxUNet2DConditionOutput' object has no attribute 'shape'", - "output_type": "error", - "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", - "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[16]\u001b[39m\u001b[32m, line 1\u001b[39m\n\u001b[32m----> \u001b[39m\u001b[32m1\u001b[39m \u001b[43mout\u001b[49m\u001b[43m.\u001b[49m\u001b[43mshape\u001b[49m\n", - "\u001b[31mAttributeError\u001b[39m: 'FlaxUNet2DConditionOutput' object has no attribute 'shape'" - ] - } - ], - "source": [ - "out." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, "outputs": [ { "name": "stdout", @@ -5821,114 +311,13 @@ "name": "stderr", "output_type": "stream", "text": [ - " 0%| | 0/200 [00:00 \u001b[39m\u001b[32m2\u001b[39m final_state = \u001b[43mtrainer\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatches\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m2\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msampler_class\u001b[49m\u001b[43m=\u001b[49m\u001b[43mEulerAncestralSampler\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msampling_noise_schedule\u001b[49m\u001b[43m=\u001b[49m\u001b[43mkaras_ve_schedule\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/persist/FlaxDiff/flaxdiff/trainer/diffusion_trainer.py:354\u001b[39m, in \u001b[36mDiffusionTrainer.fit\u001b[39m\u001b[34m(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch, sampler_class, sampling_noise_schedule)\u001b[39m\n\u001b[32m 349\u001b[39m local_batch_size = data[\u001b[33m'\u001b[39m\u001b[33mlocal_batch_size\u001b[39m\u001b[33m'\u001b[39m]\n\u001b[32m 350\u001b[39m validation_step_args = {\n\u001b[32m 351\u001b[39m \u001b[33m\"\u001b[39m\u001b[33msampler_class\u001b[39m\u001b[33m\"\u001b[39m: sampler_class,\n\u001b[32m 352\u001b[39m \u001b[33m\"\u001b[39m\u001b[33msampling_noise_schedule\u001b[39m\u001b[33m\"\u001b[39m: sampling_noise_schedule,\n\u001b[32m 353\u001b[39m }\n\u001b[32m--> \u001b[39m\u001b[32m354\u001b[39m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m.\u001b[49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 355\u001b[39m \u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 356\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_steps_per_epoch\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtraining_steps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 357\u001b[39m \u001b[43m \u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 358\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_step_args\u001b[49m\u001b[43m=\u001b[49m\u001b[43m{\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mbatch_size\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mlocal_batch_size\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\n\u001b[32m 359\u001b[39m \u001b[43m \u001b[49m\u001b[43mval_steps_per_epoch\u001b[49m\u001b[43m=\u001b[49m\u001b[43mval_steps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 360\u001b[39m \u001b[43m \u001b[49m\u001b[43mvalidation_step_args\u001b[49m\u001b[43m=\u001b[49m\u001b[43mvalidation_step_args\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 361\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/persist/FlaxDiff/flaxdiff/trainer/simple_trainer.py:498\u001b[39m, in \u001b[36mSimpleTrainer.fit\u001b[39m\u001b[34m(self, data, train_steps_per_epoch, epochs, train_step_args, val_steps_per_epoch, validation_step_args)\u001b[39m\n\u001b[32m 495\u001b[39m start_time = time.time()\n\u001b[32m 496\u001b[39m epoch_loss = \u001b[32m0\u001b[39m\n\u001b[32m--> \u001b[39m\u001b[32m498\u001b[39m epoch_loss, current_step, train_state, rng_state = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtrain_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 499\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 500\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_step\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 501\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_ds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 502\u001b[39m \u001b[43m \u001b[49m\u001b[43mtrain_steps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 503\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mlatest_step\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 504\u001b[39m \u001b[43m \u001b[49m\u001b[43mrng_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 505\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 506\u001b[39m \u001b[38;5;28mprint\u001b[39m(colored(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mEpoch done on process index \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprocess_index\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m, PROCESS_COLOR_MAP[process_index]))\n\u001b[32m 508\u001b[39m \u001b[38;5;28mself\u001b[39m.latest_step = current_step\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/persist/FlaxDiff/flaxdiff/trainer/simple_trainer.py:427\u001b[39m, in \u001b[36mSimpleTrainer.train_loop\u001b[39m\u001b[34m(self, train_state, train_step_fn, train_ds, train_steps_per_epoch, current_step, rng_state)\u001b[39m\n\u001b[32m 424\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.distributed_training \u001b[38;5;129;01mand\u001b[39;00m global_device_count > \u001b[32m1\u001b[39m:\n\u001b[32m 425\u001b[39m \u001b[38;5;66;03m# # Convert the local device batches to a unified global jax.Array \u001b[39;00m\n\u001b[32m 426\u001b[39m batch = convert_to_global_tree(\u001b[38;5;28mself\u001b[39m.mesh, batch)\n\u001b[32m--> \u001b[39m\u001b[32m427\u001b[39m train_state, loss, rng_state = \u001b[43mtrain_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrng_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mglobal_device_indexes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 429\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m i == \u001b[32m0\u001b[39m:\n\u001b[32m 430\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mTraining started for process index \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mprocess_index\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m at step \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcurrent_step\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m)\n", - " \u001b[31m[... skipping hidden 28 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/persist/FlaxDiff/flaxdiff/trainer/diffusion_trainer.py:200\u001b[39m, in \u001b[36mDiffusionTrainer._define_train_step..train_step\u001b[39m\u001b[34m(train_state, rng_state, batch, local_device_index)\u001b[39m\n\u001b[32m 198\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 199\u001b[39m grad_fn = jax.value_and_grad(model_loss)\n\u001b[32m--> \u001b[39m\u001b[32m200\u001b[39m loss, grads = \u001b[43mgrad_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_state\u001b[49m\u001b[43m.\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 201\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m distributed_training:\n\u001b[32m 202\u001b[39m grads = jax.lax.pmean(grads, \u001b[33m\"\u001b[39m\u001b[33mdata\u001b[39m\u001b[33m\"\u001b[39m)\n", - " \u001b[31m[... skipping hidden 16 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/persist/FlaxDiff/flaxdiff/trainer/diffusion_trainer.py:181\u001b[39m, in \u001b[36mDiffusionTrainer._define_train_step..train_step..model_loss\u001b[39m\u001b[34m(params)\u001b[39m\n\u001b[32m 180\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mmodel_loss\u001b[39m(params):\n\u001b[32m--> \u001b[39m\u001b[32m181\u001b[39m preds = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mapply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mnoise_schedule\u001b[49m\u001b[43m.\u001b[49m\u001b[43mtransform_inputs\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnoisy_images\u001b[49m\u001b[43m*\u001b[49m\u001b[43mc_in\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnoise_level\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlabel_seq\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 182\u001b[39m preds = model_output_transform.pred_transform(noisy_images, preds, rates)\n\u001b[32m 183\u001b[39m nloss = loss_fn(preds, expected_output)\n", - " \u001b[31m[... skipping hidden 6 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/diffusers/models/unets/unet_2d_condition_flax.py:401\u001b[39m, in \u001b[36mFlaxUNet2DConditionModel.__call__\u001b[39m\u001b[34m(self, sample, timesteps, encoder_hidden_states, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, return_dict, train)\u001b[39m\n\u001b[32m 399\u001b[39m \u001b[38;5;66;03m# 2. pre-process\u001b[39;00m\n\u001b[32m 400\u001b[39m sample = jnp.transpose(sample, (\u001b[32m0\u001b[39m, \u001b[32m2\u001b[39m, \u001b[32m3\u001b[39m, \u001b[32m1\u001b[39m))\n\u001b[32m--> \u001b[39m\u001b[32m401\u001b[39m sample = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mconv_in\u001b[49m\u001b[43m(\u001b[49m\u001b[43msample\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 403\u001b[39m \u001b[38;5;66;03m# 3. down\u001b[39;00m\n\u001b[32m 404\u001b[39m down_block_res_samples = (sample,)\n", - " \u001b[31m[... skipping hidden 2 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/flax/linen/linear.py:662\u001b[39m, in \u001b[36m_Conv.__call__\u001b[39m\u001b[34m(self, inputs)\u001b[39m\n\u001b[32m 656\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m.mask.shape != kernel_shape:\n\u001b[32m 657\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[32m 658\u001b[39m \u001b[33m'\u001b[39m\u001b[33mMask needs to have the same shape as weights. \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 659\u001b[39m \u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[33mShapes are: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m.mask.shape\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m, \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mkernel_shape\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m'\u001b[39m\n\u001b[32m 660\u001b[39m )\n\u001b[32m--> \u001b[39m\u001b[32m662\u001b[39m kernel = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mparam\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 663\u001b[39m \u001b[43m \u001b[49m\u001b[33;43m'\u001b[39;49m\u001b[33;43mkernel\u001b[39;49m\u001b[33;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mkernel_init\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkernel_shape\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mparam_dtype\u001b[49m\n\u001b[32m 664\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 666\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m.mask \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 667\u001b[39m kernel *= \u001b[38;5;28mself\u001b[39m.mask\n", - " \u001b[31m[... skipping hidden 1 frame]\u001b[39m\n", - "\u001b[36mFile \u001b[39m\u001b[32m~/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/flax/core/scope.py:960\u001b[39m, in \u001b[36mScope.param\u001b[39m\u001b[34m(self, name, init_fn, unbox, *init_args, **init_kwargs)\u001b[39m\n\u001b[32m 955\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m val, abs_val \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(value_flat, abs_value_flat):\n\u001b[32m 956\u001b[39m \u001b[38;5;66;03m# NOTE: We could check dtype consistency here as well but it's\u001b[39;00m\n\u001b[32m 957\u001b[39m \u001b[38;5;66;03m# usefuleness is less obvious. We might intentionally change the dtype\u001b[39;00m\n\u001b[32m 958\u001b[39m \u001b[38;5;66;03m# for inference to a half float type for example.\u001b[39;00m\n\u001b[32m 959\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m np.shape(val) != np.shape(abs_val):\n\u001b[32m--> \u001b[39m\u001b[32m960\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m errors.ScopeParamShapeError(\n\u001b[32m 961\u001b[39m name, \u001b[38;5;28mself\u001b[39m.path_text, np.shape(abs_val), np.shape(val)\n\u001b[32m 962\u001b[39m )\n\u001b[32m 963\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 964\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m.is_mutable_collection(\u001b[33m'\u001b[39m\u001b[33mparams\u001b[39m\u001b[33m'\u001b[39m):\n", - "\u001b[31mScopeParamShapeError\u001b[39m: Initializer expected to generate shape (3, 3, 3, 64) but got shape (3, 3, 128, 64) instead for parameter \"kernel\" in \"/conv_in\". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)" + "First batch loaded at step 0\n", + "Model Input with shape (4, 128, 128, 3) has dtype float32\n", + "Model Input with shape (4,) has dtype float32\n", + "Model Input with shape (4, 77, 768) has dtype float32\n", + "Model Output with shape (4, 128, 128, 3) has dtype bfloat16\n", + "Loss with shape (4, 128, 128, 3) has dtype float32\n", + "Final Loss with shape () has dtype float32\n" ] } ], @@ -5978,13 +349,6 @@ "final_state = trainer.fit(data, batches, epochs=2, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, diff --git a/pyproject.toml b/pyproject.toml index 593751d..e347b0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "flaxdiff" -version = "0.1.38" +version = "0.1.38.1" description = "A versatile and easy to understand Diffusion library" readme = "README.md" authors = [ diff --git a/training.py b/training.py index b958161..101bee4 100644 --- a/training.py +++ b/training.py @@ -1,3 +1,6 @@ +import jax +jax.config.update("jax_enable_x64", True) + from typing import Any, Tuple, Mapping, Callable, List, Dict from functools import partial import flax.training.dynamic_scale @@ -13,7 +16,6 @@ import struct as st import flax import tqdm -import jax import jax.numpy as jnp import optax