
FlaxUNet2DConditionModel is not initialized with correct dtypes #11144

@wittenator

Description

Describe the bug

The FlaxUNet2DConditionModel allows specifying the dtype of the weights. However, supplying a dtype other than float32 does not seem to propagate to the actual model: after init, every parameter is still float32 (see the logs). In my opinion this is different from #2068, since as far as I can tell the code initializes the dtypes correctly, yet the result is still incorrect. So this is not related to loading FP32 weights or anything similar.

Reproduction

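import jax
import diffusers
from jax import random, numpy as jnp

# Dummy latents (batch, channels, height, width) and timesteps in bfloat16
dummy_input = jnp.zeros((2, 4, 32, 32), dtype=jnp.bfloat16)
dummy_t = jnp.zeros(2, dtype=jnp.bfloat16)
# Request bfloat16 for the model
model = diffusers.FlaxUNet2DConditionModel(dtype=jnp.bfloat16)
key1, key2 = random.split(random.key(0))
# Third positional argument (encoder_hidden_states) is None
params = model.init(key1, dummy_input, dummy_t, None)
# Print the dtype of every parameter leaf
print(jax.tree_util.tree_map(jnp.dtype, params))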
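The full tree printed by the reproduction is hard to scan. A small helper (hypothetical, not part of diffusers or JAX) that counts parameter leaves per dtype makes the result easier to check at a glance:

```python
from collections import Counter

import jax
import jax.numpy as jnp


def dtype_counts(tree):
    """Count the leaves of a pytree by dtype name, e.g. Counter({'float32': 3})."""
    return Counter(str(leaf.dtype) for leaf in jax.tree_util.tree_leaves(tree))


# Tiny stand-in pytree (hypothetical shapes); in the reproduction, pass `params` instead.
example = {
    "conv_in": {"kernel": jnp.zeros((3, 3, 4, 8)), "bias": jnp.zeros((8,))},
    "time_embedding": {"bias": jnp.zeros((4,), dtype=jnp.bfloat16)},
}
print(dtype_counts(example))
```

Called on the `params` from the reproduction, the log in this issue suggests it would report only float32 leaves.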

Logs

{'params': {'conv_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_norm_out': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'conv_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'down_blocks_0': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': 
dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'downsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'down_blocks_1': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': 
dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'downsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'down_blocks_2': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 
'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'downsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 
'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'down_blocks_3': {'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'mid_block': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': 
dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'time_embedding': {'linear_1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'linear_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'up_blocks_0': {'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_2': 
{'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'upsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'up_blocks_1': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 
'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_2': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': 
{'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_2': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'upsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'up_blocks_2': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': 
{'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_2': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'resnets_0': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': 
dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_2': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'upsamplers_0': {'conv': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}, 'up_blocks_3': {'attentions_0': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 
'scale': dtype('float32')}}}, 'attentions_1': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'attentions_2': {'norm': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'proj_in': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'proj_out': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'transformer_blocks_0': {'attn1': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'attn2': {'to_k': {'kernel': dtype('float32')}, 'to_out_0': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'to_q': {'kernel': dtype('float32')}, 'to_v': {'kernel': dtype('float32')}}, 'ff': {'net_0': {'proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'net_2': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm3': {'bias': dtype('float32'), 'scale': dtype('float32')}}}, 'resnets_0': {'conv1': {'bias': 
dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_1': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}, 'resnets_2': {'conv1': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv2': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'conv_shortcut': {'bias': dtype('float32'), 'kernel': dtype('float32')}, 'norm1': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'norm2': {'bias': dtype('float32'), 'scale': dtype('float32')}, 'time_emb_proj': {'bias': dtype('float32'), 'kernel': dtype('float32')}}}}}
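If bfloat16 parameters are needed before this is resolved, one workaround sketch is to cast the floating-point leaves of the initialized tree yourself. `cast_floating_params` below is a hypothetical helper; diffusers' `FlaxModelMixin` appears to provide similar `to_bf16`/`to_fp16` methods, which may be preferable if your version has them. Storing weights in bfloat16 reduces precision, which matters more for training than for inference.

```python
import jax
import jax.numpy as jnp


def cast_floating_params(params, dtype=jnp.bfloat16):
    """Cast every floating-point leaf of a pytree to `dtype`.

    Non-floating leaves (e.g. integer counters) are returned unchanged.
    """
    def _cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(_cast, params)


# Small stand-in for the UNet variables (hypothetical shapes and collections).
variables = {
    "params": {
        "conv_in": {"kernel": jnp.zeros((3, 3, 4, 8)), "bias": jnp.zeros((8,))},
    },
    "counters": {"step": jnp.array(0, dtype=jnp.int32)},  # hypothetical int leaf
}
variables_bf16 = cast_floating_params(variables)
print(jax.tree_util.tree_map(lambda x: x.dtype, variables_bf16))
```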

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Linux-6.1.85+-x86_64-with-glibc2.35
  • Running on Google Colab?: Yes
  • Python version: 3.11.11
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): 0.10.4 (gpu)
  • Jax version: 0.5.2
  • JaxLib version: 0.5.1
  • Huggingface_hub version: 0.29.3
  • Transformers version: 4.49.0
  • Accelerate version: 1.5.2
  • PEFT version: 0.14.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: Tesla T4, 15360 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

No response

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), stale (Issues that haven't received updates)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
