@@ -102,6 +102,8 @@ def __init__(
         interpolation_scale: float = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
@@ -112,10 +114,16 @@ def __init__(
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
@@ -150,104 +158,167 @@ def __init__(
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
         if self.is_input_continuous:
-            self.in_channels = in_channels
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
-            else:
-                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            self.height = sample_size
-            self.width = sample_size
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = nn.Linear(inner_dim, in_channels)
-            else:
-                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
         self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
+        if self.config.norm_type == "ada_norm_single":
             self.use_additional_conditions = self.config.sample_size == 128
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
@@ -361,7 +432,7 @@ def forward(
                 )
 
         # 2. Blocks
-        if self.caption_projection is not None:
+        if self.is_input_patches and self.caption_projection is not None:
             batch_size = hidden_states.shape[0]
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
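For orientation, the initialization pattern this diff introduces can be summarized in a short, self-contained sketch: validate the configuration first, set the attributes shared by every code path, then dispatch to exactly one `_init_*` helper depending on whether the input is continuous, vectorized, or patched. The sketch below is an illustrative assumption only; the `TinyTransformer2D` name, its parameters, and the `nn.Identity` placeholders are not part of diffusers' actual `Transformer2DModel`.

```python
# Minimal sketch of the "validate -> set common vars -> dispatch to _init_*" pattern.
# Class name, parameters, and placeholder modules are illustrative assumptions.
import torch.nn as nn


class TinyTransformer2D(nn.Module):
    def __init__(self, in_channels=None, num_vector_embeds=None, patch_size=None,
                 num_attention_heads=8, attention_head_dim=64):
        super().__init__()

        # 1. Validate inputs: exactly one input mode must be selected.
        is_continuous = in_channels is not None and patch_size is None
        is_vectorized = num_vector_embeds is not None
        is_patched = in_channels is not None and patch_size is not None
        if sum([is_continuous, is_vectorized, is_patched]) != 1:
            raise ValueError("Pass exactly one of continuous, vectorized, or patched inputs.")

        # 2. Set some common variables used across the board.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.in_channels = in_channels
        self.gradient_checkpointing = False

        # 3. Dispatch to the matching _init_* helper.
        if is_continuous:
            self._init_continuous_input()
        elif is_vectorized:
            self._init_vectorized_inputs(num_vector_embeds)
        else:
            self._init_patched_inputs(patch_size)

    def _init_continuous_input(self):
        # a. input blocks  b. transformer blocks  c. output blocks
        self.proj_in = nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1)
        self.transformer_blocks = nn.ModuleList([nn.Identity()])  # stand-in for BasicTransformerBlock
        self.proj_out = nn.Conv2d(self.inner_dim, self.in_channels, kernel_size=1)

    def _init_vectorized_inputs(self, num_vector_embeds):
        self.latent_image_embedding = nn.Embedding(num_vector_embeds, self.inner_dim)
        self.transformer_blocks = nn.ModuleList([nn.Identity()])
        self.norm_out = nn.LayerNorm(self.inner_dim)
        self.out = nn.Linear(self.inner_dim, num_vector_embeds - 1)

    def _init_patched_inputs(self, patch_size):
        self.pos_embed = nn.Conv2d(self.in_channels, self.inner_dim,
                                   kernel_size=patch_size, stride=patch_size)
        self.transformer_blocks = nn.ModuleList([nn.Identity()])
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.in_channels)


# Each instantiation builds only the blocks its input mode needs.
continuous = TinyTransformer2D(in_channels=4)
vectorized = TinyTransformer2D(num_vector_embeds=1024)
patched = TinyTransformer2D(in_channels=4, patch_size=2)
```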