@@ -28,7 +28,9 @@ def get_down_block(
28
28
resnet_groups = None ,
29
29
cross_attention_dim = None ,
30
30
downsample_padding = None ,
31
- dual_cross_attention = None ,
31
+ dual_cross_attention = False ,
32
+ use_linear_projection = False ,
33
+ only_cross_attention = False ,
32
34
):
33
35
down_block_type = down_block_type [7 :] if down_block_type .startswith ("UNetRes" ) else down_block_type
34
36
if down_block_type == "DownBlockFlat" :
@@ -58,6 +60,9 @@ def get_down_block(
58
60
downsample_padding = downsample_padding ,
59
61
cross_attention_dim = cross_attention_dim ,
60
62
attn_num_head_channels = attn_num_head_channels ,
63
+ dual_cross_attention = dual_cross_attention ,
64
+ use_linear_projection = use_linear_projection ,
65
+ only_cross_attention = only_cross_attention ,
61
66
)
62
67
raise ValueError (f"{ down_block_type } is not supported." )
63
68
@@ -75,7 +80,9 @@ def get_up_block(
75
80
attn_num_head_channels ,
76
81
resnet_groups = None ,
77
82
cross_attention_dim = None ,
78
- dual_cross_attention = None ,
83
+ dual_cross_attention = False ,
84
+ use_linear_projection = False ,
85
+ only_cross_attention = False ,
79
86
):
80
87
up_block_type = up_block_type [7 :] if up_block_type .startswith ("UNetRes" ) else up_block_type
81
88
if up_block_type == "UpBlockFlat" :
@@ -105,6 +112,9 @@ def get_up_block(
105
112
resnet_groups = resnet_groups ,
106
113
cross_attention_dim = cross_attention_dim ,
107
114
attn_num_head_channels = attn_num_head_channels ,
115
+ dual_cross_attention = dual_cross_attention ,
116
+ use_linear_projection = use_linear_projection ,
117
+ only_cross_attention = only_cross_attention ,
108
118
)
109
119
raise ValueError (f"{ up_block_type } is not supported." )
110
120
0 commit comments