@@ -19,7 +19,7 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput
-from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps, TimestepsADM
 from ..modeling_utils import ModelMixin
 from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
@@ -72,6 +72,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             The upsample type for upsampling layers. Choose between "conv" and "resnet".
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        attention_legacy_order (`bool`, *optional*, defaults to `False`):
+            If `attention_legacy_order` is `True`, split heads before splitting qkv; see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
         attn_norm_num_groups (`int`, *optional*, defaults to `None`):
@@ -109,6 +111,7 @@ def __init__(
         upsample_type: str = "conv",
         dropout: float = 0.0,
         act_fn: str = "silu",
+        attention_legacy_order: bool = False,
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
         attn_norm_num_groups: Optional[int] = None,
@@ -148,7 +151,9 @@ def __init__(
         elif time_embedding_type == "learned":
             self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
             timestep_input_dim = block_out_channels[0]
-
+        elif time_embedding_type == "adm":
+            self.time_proj = TimestepsADM(block_out_channels[0])
+            timestep_input_dim = block_out_channels[0]
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
 
         # class embedding
@@ -182,6 +187,7 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
+                attention_legacy_order=attention_legacy_order,
                 attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 downsample_padding=downsample_padding,
                 resnet_time_scale_shift=resnet_time_scale_shift,
@@ -191,6 +197,7 @@ def __init__(
             self.down_blocks.append(down_block)
 
         # mid
+        attn_norm_num_groups = norm_num_groups if attention_legacy_order is True else attn_norm_num_groups
         self.mid_block = UNetMidBlock2D(
             in_channels=block_out_channels[-1],
             temb_channels=time_embed_dim,
@@ -203,6 +210,7 @@ def __init__(
             resnet_groups=norm_num_groups,
             attn_groups=attn_norm_num_groups,
             add_attention=add_attention,
+            attention_legacy_order=attention_legacy_order,
         )
 
         # up
@@ -226,6 +234,7 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
+                attention_legacy_order=attention_legacy_order,
                 attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 resnet_time_scale_shift=resnet_time_scale_shift,
                 upsample_type=upsample_type,
0 commit comments