@@ -249,6 +249,81 @@ def get_down_block(
249249    raise  ValueError (f"{ down_block_type }   does not exist." )
250250
251251
def get_mid_block(
    mid_block_type: Optional[str],
    temb_channels: int,
    in_channels: int,
    resnet_eps: float,
    resnet_act_fn: str,
    resnet_groups: int,
    output_scale_factor: float = 1.0,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = None,
    cross_attention_dim: Optional[int] = None,
    dual_cross_attention: bool = False,
    use_linear_projection: bool = False,
    mid_block_only_cross_attention: bool = False,
    upcast_attention: bool = False,
    resnet_time_scale_shift: str = "default",
    attention_type: str = "default",
    resnet_skip_time_act: bool = False,
    cross_attention_norm: Optional[str] = None,
    attention_head_dim: Optional[int] = 1,
    dropout: float = 0.0,
):
    """Build the UNet mid block selected by ``mid_block_type``.

    Args:
        mid_block_type: One of ``"UNetMidBlock2DCrossAttn"``,
            ``"UNetMidBlock2DSimpleCrossAttn"``, ``"UNetMidBlock2D"``, or
            ``None`` to request no mid block at all.
        temb_channels, in_channels, resnet_eps, resnet_act_fn, resnet_groups,
        output_scale_factor, resnet_time_scale_shift, dropout: Forwarded
            under the same keyword names to whichever block is built.
        transformer_layers_per_block, num_attention_heads,
        dual_cross_attention, use_linear_projection, upcast_attention,
        attention_type: Forwarded only to ``UNetMidBlock2DCrossAttn``.
        attention_head_dim, cross_attention_norm: Forwarded only to
            ``UNetMidBlock2DSimpleCrossAttn``.
        mid_block_only_cross_attention: Passed to
            ``UNetMidBlock2DSimpleCrossAttn`` as ``only_cross_attention``.
        resnet_skip_time_act: Passed to ``UNetMidBlock2DSimpleCrossAttn`` as
            ``skip_time_act``.
        cross_attention_dim: Forwarded to both cross-attention variants; not
            used for ``UNetMidBlock2D``.

    Returns:
        The constructed block instance, or ``None`` when ``mid_block_type``
        is ``None``.

    Raises:
        ValueError: If ``mid_block_type`` is neither ``None`` nor one of the
            supported names.
    """
    if mid_block_type is None:
        return None

    # Keyword arguments that every supported mid block receives unchanged.
    shared_kwargs = {
        "in_channels": in_channels,
        "temb_channels": temb_channels,
        "dropout": dropout,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "output_scale_factor": output_scale_factor,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
    }

    if mid_block_type == "UNetMidBlock2DCrossAttn":
        return UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=num_attention_heads,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
            attention_type=attention_type,
            **shared_kwargs,
        )
    elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
        return UNetMidBlock2DSimpleCrossAttn(
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            # The caller-facing names differ from the constructor's keywords.
            skip_time_act=resnet_skip_time_act,
            only_cross_attention=mid_block_only_cross_attention,
            cross_attention_norm=cross_attention_norm,
            **shared_kwargs,
        )
    elif mid_block_type == "UNetMidBlock2D":
        # This variant is always built with num_layers=0 and attention
        # disabled; neither is configurable through this factory.
        return UNetMidBlock2D(
            num_layers=0,
            add_attention=False,
            **shared_kwargs,
        )
    else:
        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
325+ 
326+ 
252327def  get_up_block (
253328    up_block_type : str ,
254329    num_layers : int ,
0 commit comments