@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
5858 down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
5959 Tuple of downsample block types.
6060 mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61- Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D `.
61+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None `.
6262 up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
6363 Tuple of upsample block types.
6464 block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
103103 freq_shift : int = 0 ,
104104 flip_sin_to_cos : bool = True ,
105105 down_block_types : Tuple [str , ...] = ("DownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" ),
106+ mid_block_type : Optional [str ] = "UNetMidBlock2D" ,
106107 up_block_types : Tuple [str , ...] = ("AttnUpBlock2D" , "AttnUpBlock2D" , "AttnUpBlock2D" , "UpBlock2D" ),
107108 block_out_channels : Tuple [int , ...] = (224 , 448 , 672 , 896 ),
108109 layers_per_block : int = 2 ,
@@ -194,19 +195,22 @@ def __init__(
194195 self .down_blocks .append (down_block )
195196
196197 # mid
197- self .mid_block = UNetMidBlock2D (
198- in_channels = block_out_channels [- 1 ],
199- temb_channels = time_embed_dim ,
200- dropout = dropout ,
201- resnet_eps = norm_eps ,
202- resnet_act_fn = act_fn ,
203- output_scale_factor = mid_block_scale_factor ,
204- resnet_time_scale_shift = resnet_time_scale_shift ,
205- attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
206- resnet_groups = norm_num_groups ,
207- attn_groups = attn_norm_num_groups ,
208- add_attention = add_attention ,
209- )
198+ if mid_block_type is None :
199+ self .mid_block = None
200+ else :
201+ self .mid_block = UNetMidBlock2D (
202+ in_channels = block_out_channels [- 1 ],
203+ temb_channels = time_embed_dim ,
204+ dropout = dropout ,
205+ resnet_eps = norm_eps ,
206+ resnet_act_fn = act_fn ,
207+ output_scale_factor = mid_block_scale_factor ,
208+ resnet_time_scale_shift = resnet_time_scale_shift ,
209+ attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
210+ resnet_groups = norm_num_groups ,
211+ attn_groups = attn_norm_num_groups ,
212+ add_attention = add_attention ,
213+ )
210214
211215 # up
212216 reversed_block_out_channels = list (reversed (block_out_channels ))
@@ -322,7 +326,8 @@ def forward(
322326 down_block_res_samples += res_samples
323327
324328 # 4. mid
325- sample = self .mid_block (sample , emb )
329+ if self .mid_block is not None :
330+ sample = self .mid_block (sample , emb )
326331
327332 # 5. up
328333 skip_sample = None
0 commit comments