@@ -102,6 +102,8 @@ def __init__(
         interpolation_scale: float = None,
     ):
         super().__init__()
+
+        # Validate inputs.
         if patch_size is not None:
             if norm_type not in ["ada_norm", "ada_norm_zero", "ada_norm_single"]:
                 raise NotImplementedError(
@@ -112,10 +114,16 @@ def __init__(
                     f"When using a `patch_size` and this `norm_type` ({norm_type}), `num_embeds_ada_norm` cannot be None."
                 )
 
+        # Set some common variables used across the board.
         self.use_linear_projection = use_linear_projection
+        self.interpolation_scale = interpolation_scale
+        self.caption_channels = caption_channels
         self.num_attention_heads = num_attention_heads
         self.attention_head_dim = attention_head_dim
-        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.in_channels = in_channels
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.gradient_checkpointing = False
 
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
@@ -150,104 +158,167 @@ def __init__(
                 f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
             )
 
-        # 2. Define input layers
+        # 2. Initialize the right blocks.
+        # These functions follow a common structure:
+        # a. Initialize the input blocks. b. Initialize the transformer blocks.
+        # c. Initialize the output blocks and other projection blocks when necessary.
         if self.is_input_continuous:
-            self.in_channels = in_channels
-
-            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
-            if use_linear_projection:
-                self.proj_in = nn.Linear(in_channels, inner_dim)
-            else:
-                self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+            self._init_continuous_input(norm_type=norm_type)
         elif self.is_input_vectorized:
-            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
-            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+            self._init_vectorized_inputs(norm_type=norm_type)
+        elif self.is_input_patches:
+            self._init_patched_inputs(norm_type=norm_type)
 
-            self.height = sample_size
-            self.width = sample_size
-            self.num_vector_embeds = num_vector_embeds
-            self.num_latent_pixels = self.height * self.width
+    def _init_continuous_input(self, norm_type):
+        self.norm = torch.nn.GroupNorm(
+            num_groups=self.config.norm_num_groups, num_channels=self.in_channels, eps=1e-6, affine=True
+        )
+        if self.use_linear_projection:
+            self.proj_in = torch.nn.Linear(self.in_channels, self.inner_dim)
+        else:
+            self.proj_in = torch.nn.Conv2d(self.in_channels, self.inner_dim, kernel_size=1, stride=1, padding=0)
 
-            self.latent_image_embedding = ImagePositionalEmbeddings(
-                num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
-            )
-        elif self.is_input_patches:
-            assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
 
-            self.height = sample_size
-            self.width = sample_size
+        if self.use_linear_projection:
+            self.proj_out = torch.nn.Linear(self.inner_dim, self.out_channels)
+        else:
+            self.proj_out = torch.nn.Conv2d(self.inner_dim, self.out_channels, kernel_size=1, stride=1, padding=0)
 
-            self.patch_size = patch_size
-            interpolation_scale = (
-                interpolation_scale if interpolation_scale is not None else max(self.config.sample_size // 64, 1)
-            )
-            self.pos_embed = PatchEmbed(
-                height=sample_size,
-                width=sample_size,
-                patch_size=patch_size,
-                in_channels=in_channels,
-                embed_dim=inner_dim,
-                interpolation_scale=interpolation_scale,
-            )
+    def _init_vectorized_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+        assert (
+            self.config.num_vector_embeds is not None
+        ), "Transformer2DModel over discrete input must provide num_embed"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+        self.num_latent_pixels = self.height * self.width
+
+        self.latent_image_embedding = ImagePositionalEmbeddings(
+            num_embed=self.config.num_vector_embeds, embed_dim=self.inner_dim, height=self.height, width=self.width
+        )
 
-        # 3. Define transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 BasicTransformerBlock(
-                    inner_dim,
-                    num_attention_heads,
-                    attention_head_dim,
-                    dropout=dropout,
-                    cross_attention_dim=cross_attention_dim,
-                    activation_fn=activation_fn,
-                    num_embeds_ada_norm=num_embeds_ada_norm,
-                    attention_bias=attention_bias,
-                    only_cross_attention=only_cross_attention,
-                    double_self_attention=double_self_attention,
-                    upcast_attention=upcast_attention,
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=norm_elementwise_affine,
-                    norm_eps=norm_eps,
-                    attention_type=attention_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
                 )
-                for d in range(num_layers)
+                for _ in range(self.config.num_layers)
             ]
         )
 
-        # 4. Define output layers
-        self.out_channels = in_channels if out_channels is None else out_channels
-        if self.is_input_continuous:
-            # TODO: should use out_channels for continuous projections
-            if use_linear_projection:
-                self.proj_out = nn.Linear(inner_dim, in_channels)
-            else:
-                self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
-        elif self.is_input_vectorized:
-            self.norm_out = nn.LayerNorm(inner_dim)
-            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
-        elif self.is_input_patches and norm_type != "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
-            self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-        elif self.is_input_patches and norm_type == "ada_norm_single":
-            self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
-            self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-            self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
-
-        # 5. PixArt-Alpha blocks.
+        self.norm_out = nn.LayerNorm(self.inner_dim)
+        self.out = nn.Linear(self.inner_dim, self.config.num_vector_embeds - 1)
+
+    def _init_patched_inputs(self, norm_type):
+        assert self.config.sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+        self.height = self.config.sample_size
+        self.width = self.config.sample_size
+
+        self.patch_size = self.config.patch_size
+        interpolation_scale = (
+            self.config.interpolation_scale
+            if self.config.interpolation_scale is not None
+            else max(self.config.sample_size // 64, 1)
+        )
+        self.pos_embed = PatchEmbed(
+            height=self.config.sample_size,
+            width=self.config.sample_size,
+            patch_size=self.config.patch_size,
+            in_channels=self.in_channels,
+            embed_dim=self.inner_dim,
+            interpolation_scale=interpolation_scale,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    self.inner_dim,
+                    self.config.num_attention_heads,
+                    self.config.attention_head_dim,
+                    dropout=self.config.dropout,
+                    cross_attention_dim=self.config.cross_attention_dim,
+                    activation_fn=self.config.activation_fn,
+                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
+                    attention_bias=self.config.attention_bias,
+                    only_cross_attention=self.config.only_cross_attention,
+                    double_self_attention=self.config.double_self_attention,
+                    upcast_attention=self.config.upcast_attention,
+                    norm_type=norm_type,
+                    norm_elementwise_affine=self.config.norm_elementwise_affine,
+                    norm_eps=self.config.norm_eps,
+                    attention_type=self.config.attention_type,
+                )
+                for _ in range(self.config.num_layers)
+            ]
+        )
+
+        if self.config.norm_type != "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
+            self.proj_out_2 = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+        elif self.config.norm_type == "ada_norm_single":
+            self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
+            self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
+            self.proj_out = nn.Linear(
+                self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
+            )
+
+        # PixArt-Alpha blocks.
         self.adaln_single = None
         self.use_additional_conditions = False
-        if norm_type == "ada_norm_single":
+        if self.config.norm_type == "ada_norm_single":
             self.use_additional_conditions = self.config.sample_size == 128
             # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
             # additional conditions until we find better name
-            self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
+            self.adaln_single = AdaLayerNormSingle(
+                self.inner_dim, use_additional_conditions=self.use_additional_conditions
+            )
 
         self.caption_projection = None
-        if caption_channels is not None:
-            self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
-
-        self.gradient_checkpointing = False
+        if self.caption_channels is not None:
+            self.caption_projection = PixArtAlphaTextProjection(
+                in_features=self.caption_channels, hidden_size=self.inner_dim
+            )
 
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
@@ -361,7 +432,7 @@ def forward(
                 )
 
         # 2. Blocks
-        if self.caption_projection is not None:
+        if self.is_input_patches and self.caption_projection is not None:
             batch_size = hidden_states.shape[0]
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
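
For context (not part of the diff above), here is a minimal instantiation sketch of how the refactored constructor is exercised: a hypothetical small patch-based (PixArt-style) configuration routes `__init__` into the new `_init_patched_inputs` helper and exposes the `self.inner_dim` attribute introduced by this change. The argument values are illustrative only, and the import location may differ across diffusers versions.

from diffusers.models import Transformer2DModel  # import path may vary by diffusers version

# Hypothetical small configuration: `patch_size` is set, so is_input_patches is True
# and the new _init_patched_inputs() builds the patch embedding, blocks, and output layers.
model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    sample_size=8,                # != 128, so no additional PixArt conditions in adaln_single
    patch_size=2,
    norm_type="ada_norm_single",
    num_embeds_ada_norm=1000,
    caption_channels=32,          # triggers the PixArtAlphaTextProjection branch
)
print(model.inner_dim)  # num_attention_heads * attention_head_dim == 16 after this change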