1919
2020from ...configuration_utils import ConfigMixin , register_to_config
2121from ...utils import BaseOutput
22- from ..embeddings import GaussianFourierProjection , TimestepEmbedding , Timesteps
22+ from ..embeddings import GaussianFourierProjection , TimestepEmbedding , Timesteps , TimestepsADM
2323from ..modeling_utils import ModelMixin
24- from .unet_2d_blocks import UNetMidBlock2D , get_down_block , get_up_block
24+ from .unet_2d_blocks import UNetMidBlock2D , UNetMidBlock2DADM , get_down_block , get_up_block
2525
2626
2727@dataclass
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
5858 down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
5959 Tuple of downsample block types.
6060 mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61- Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D `.
61+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UNetMidBlock2DADM `.
6262 up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
6363 Tuple of upsample block types.
6464 block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -72,6 +72,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
7272 The upsample type for upsampling layers. Choose between "conv" and "resnet"
7373 dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
7474 act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75+ attention_type (`str`, *optional*, defaults to `"default"`): The attention type. Choose between `"default"` and `"adm"`.
7576 attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
7677 norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
7778 attn_norm_num_groups (`int`, *optional*, defaults to `None`):
@@ -100,6 +101,7 @@ def __init__(
100101 freq_shift : int = 0 ,
101102 flip_sin_to_cos : bool = True ,
102103 down_block_types : Tuple [str ] = ("DownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" ),
104+ mid_block_type : str = "UNetMidBlock2D" ,
103105 up_block_types : Tuple [str ] = ("AttnUpBlock2D" , "AttnUpBlock2D" , "AttnUpBlock2D" , "UpBlock2D" ),
104106 block_out_channels : Tuple [int ] = (224 , 448 , 672 , 896 ),
105107 layers_per_block : int = 2 ,
@@ -109,6 +111,7 @@ def __init__(
109111 upsample_type : str = "conv" ,
110112 dropout : float = 0.0 ,
111113 act_fn : str = "silu" ,
114+ attention_type : str = "default" ,
112115 attention_head_dim : Optional [int ] = 8 ,
113116 norm_num_groups : int = 32 ,
114117 attn_norm_num_groups : Optional [int ] = None ,
@@ -148,7 +151,9 @@ def __init__(
148151 elif time_embedding_type == "learned" :
149152 self .time_proj = nn .Embedding (num_train_timesteps , block_out_channels [0 ])
150153 timestep_input_dim = block_out_channels [0 ]
151-
154+ elif time_embedding_type == "adm" :
155+ self .time_proj = TimestepsADM (block_out_channels [0 ])
156+ timestep_input_dim = block_out_channels [0 ]
152157 self .time_embedding = TimestepEmbedding (timestep_input_dim , time_embed_dim )
153158
154159 # class embedding
@@ -182,6 +187,7 @@ def __init__(
182187 resnet_eps = norm_eps ,
183188 resnet_act_fn = act_fn ,
184189 resnet_groups = norm_num_groups ,
190+ attention_type = attention_type ,
185191 attention_head_dim = attention_head_dim if attention_head_dim is not None else output_channel ,
186192 downsample_padding = downsample_padding ,
187193 resnet_time_scale_shift = resnet_time_scale_shift ,
@@ -191,20 +197,34 @@ def __init__(
191197 self .down_blocks .append (down_block )
192198
193199 # mid
194- self .mid_block = UNetMidBlock2D (
195- in_channels = block_out_channels [- 1 ],
196- temb_channels = time_embed_dim ,
197- dropout = dropout ,
198- resnet_eps = norm_eps ,
199- resnet_act_fn = act_fn ,
200- output_scale_factor = mid_block_scale_factor ,
201- resnet_time_scale_shift = resnet_time_scale_shift ,
202- attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
203- resnet_groups = norm_num_groups ,
204- attn_groups = attn_norm_num_groups ,
205- add_attention = add_attention ,
206- )
207-
200+ if mid_block_type == "UNetMidBlock2D" :
201+ self .mid_block = UNetMidBlock2D (
202+ in_channels = block_out_channels [- 1 ],
203+ temb_channels = time_embed_dim ,
204+ dropout = dropout ,
205+ resnet_eps = norm_eps ,
206+ resnet_act_fn = act_fn ,
207+ output_scale_factor = mid_block_scale_factor ,
208+ resnet_time_scale_shift = resnet_time_scale_shift ,
209+ attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
210+ resnet_groups = norm_num_groups ,
211+ attn_groups = attn_norm_num_groups ,
212+ add_attention = add_attention ,
213+ )
214+ elif mid_block_type == "UNetMidBlock2DADM" :
215+ self .mid_block = UNetMidBlock2DADM (
216+ in_channels = block_out_channels [- 1 ],
217+ temb_channels = time_embed_dim ,
218+ dropout = dropout ,
219+ resnet_eps = norm_eps ,
220+ resnet_act_fn = act_fn ,
221+ output_scale_factor = mid_block_scale_factor ,
222+ resnet_time_scale_shift = resnet_time_scale_shift ,
223+ attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
224+ resnet_groups = norm_num_groups ,
225+ )
226+ else :
227+ raise ValueError
208228 # up
209229 reversed_block_out_channels = list (reversed (block_out_channels ))
210230 output_channel = reversed_block_out_channels [0 ]
@@ -214,7 +234,6 @@ def __init__(
214234 input_channel = reversed_block_out_channels [min (i + 1 , len (block_out_channels ) - 1 )]
215235
216236 is_final_block = i == len (block_out_channels ) - 1
217-
218237 up_block = get_up_block (
219238 up_block_type ,
220239 num_layers = layers_per_block + 1 ,
@@ -226,6 +245,7 @@ def __init__(
226245 resnet_eps = norm_eps ,
227246 resnet_act_fn = act_fn ,
228247 resnet_groups = norm_num_groups ,
248+ attention_type = attention_type ,
229249 attention_head_dim = attention_head_dim if attention_head_dim is not None else output_channel ,
230250 resnet_time_scale_shift = resnet_time_scale_shift ,
231251 upsample_type = upsample_type ,
0 commit comments