@@ -69,6 +69,7 @@ def __init__(
         norm_num_groups=32,
         norm_eps=1e-5,
         cross_attention_dim=1280,
+        transformer_layers_per_block=1,
         attention_head_dim=8,
         use_linear_projection=False,
         upcast_attention=False,
@@ -129,6 +130,9 @@ def __init__(
         if isinstance(attention_head_dim, int):
             attention_head_dim = (attention_head_dim,) * len(down_block_types)
 
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
         # down
         output_channel = block_out_channels[0]
 
@@ -142,6 +146,7 @@ def __init__(
 
             down_block = get_down_block(
                 down_block_type,
+                transformer_layers_per_block=transformer_layers_per_block[i],
                 num_layers=layers_per_block,
                 in_channels=input_channel,
                 out_channels=output_channel,
@@ -151,6 +156,7 @@ def __init__(
                 cross_attention_dim=cross_attention_dim,
                 attn_num_head_channels=attention_head_dim[i],
                 downsample_padding=downsample_padding,
+                add_downsample=not is_final_block,
             )
             self.down_blocks.append(down_block)
 
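The new transformer_layers_per_block argument follows the pattern already used for attention_head_dim: a config value that may be either an int or a per-block sequence is broadcast to one entry per item in down_block_types, then indexed with i inside the block-construction loop. Below is a minimal standalone sketch of that normalization pattern; the helper normalize_per_block and the sample values are illustrative assumptions, not code from this diff.

from typing import Sequence, Union


def normalize_per_block(value: Union[int, Sequence[int]], num_blocks: int) -> list:
    # Broadcast a scalar to one entry per block; pass sequences through after a length check.
    if isinstance(value, int):
        return [value] * num_blocks
    if len(value) != num_blocks:
        raise ValueError(f"expected {num_blocks} entries, got {len(value)}")
    return list(value)


down_block_types = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")
transformer_layers_per_block = normalize_per_block(1, len(down_block_types))
attention_head_dim = normalize_per_block((5, 10, 20), len(down_block_types))

for i, block_type in enumerate(down_block_types):
    # Each block receives its own transformer depth and head count,
    # mirroring transformer_layers_per_block[i] / attention_head_dim[i] in the diff above.
    print(block_type, transformer_layers_per_block[i], attention_head_dim[i])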