1919
2020from ...configuration_utils import ConfigMixin , register_to_config
2121from ...utils import BaseOutput
22- from ..embeddings import GaussianFourierProjection , TimestepEmbedding , Timesteps
22+ from ..embeddings import GaussianFourierProjection , TimestepEmbedding , Timesteps , TimestepsADM
2323from ..modeling_utils import ModelMixin
2424from .unet_2d_blocks import UNetMidBlock2D , get_down_block , get_up_block
2525
@@ -72,6 +72,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
7272 The upsample type for upsampling layers. Choose between "conv" and "resnet"
7373 dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
7474 act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75+         attention_legacy_order (`bool`, *optional*, defaults to `False`):
76+             If `True`, split the attention heads before splitting qkv, matching the legacy ordering used in https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
7577 attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
7678 norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
7779 attn_norm_num_groups (`int`, *optional*, defaults to `None`):
@@ -109,6 +111,7 @@ def __init__(
109111 upsample_type : str = "conv" ,
110112 dropout : float = 0.0 ,
111113 act_fn : str = "silu" ,
114+ attention_legacy_order : bool = False ,
112115 attention_head_dim : Optional [int ] = 8 ,
113116 norm_num_groups : int = 32 ,
114117 attn_norm_num_groups : Optional [int ] = None ,
@@ -148,7 +151,9 @@ def __init__(
148151 elif time_embedding_type == "learned" :
149152 self .time_proj = nn .Embedding (num_train_timesteps , block_out_channels [0 ])
150153 timestep_input_dim = block_out_channels [0 ]
151-
154+ elif time_embedding_type == "adm" :
155+ self .time_proj = TimestepsADM (block_out_channels [0 ])
156+ timestep_input_dim = block_out_channels [0 ]
152157 self .time_embedding = TimestepEmbedding (timestep_input_dim , time_embed_dim )
153158
154159 # class embedding
@@ -182,6 +187,7 @@ def __init__(
182187 resnet_eps = norm_eps ,
183188 resnet_act_fn = act_fn ,
184189 resnet_groups = norm_num_groups ,
190+ attention_legacy_order = attention_legacy_order ,
185191 attention_head_dim = attention_head_dim if attention_head_dim is not None else output_channel ,
186192 downsample_padding = downsample_padding ,
187193 resnet_time_scale_shift = resnet_time_scale_shift ,
@@ -203,6 +209,7 @@ def __init__(
203209 resnet_groups = norm_num_groups ,
204210 attn_groups = attn_norm_num_groups ,
205211 add_attention = add_attention ,
212+ attention_legacy_order = attention_legacy_order ,
206213 )
207214
208215 # up
@@ -226,6 +233,7 @@ def __init__(
226233 resnet_eps = norm_eps ,
227234 resnet_act_fn = act_fn ,
228235 resnet_groups = norm_num_groups ,
236+ attention_legacy_order = attention_legacy_order ,
229237 attention_head_dim = attention_head_dim if attention_head_dim is not None else output_channel ,
230238 resnet_time_scale_shift = resnet_time_scale_shift ,
231239 upsample_type = upsample_type ,
0 commit comments