@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
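The new annotation pins `image_rotary_emb` down as a `(cos, sin)` tensor pair. As a hedged illustration of what consuming such a pair looks like (a minimal sketch, not the diffusers helper itself; it assumes the cos/sin tables are already expanded to the full head dimension, with each value repeated per channel pair):

```python
from typing import Tuple

import torch


def apply_rotary_emb(x: torch.Tensor, freqs: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    # x: (batch, heads, seq_len, head_dim); cos/sin: (seq_len, head_dim).
    cos, sin = freqs
    # Rotate adjacent channel pairs: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
    a, b = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((-b, a), dim=-1).flatten(-2)
    return x * cos + rotated * sin
```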
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):

     Reference: https://arxiv.org/abs/2403.03206

-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """

-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()

         self.norm1 = AdaLayerNormZero(dim)
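For reference, a hypothetical instantiation matching the newly annotated signature; the values mirror the Flux-scale defaults documented further down, and `dim` is assumed to equal `num_attention_heads * attention_head_dim` so the attention factors cleanly:

```python
# Assumes the block is importable from diffusers' Flux transformer module.
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(
    dim=24 * 128,            # inner dimension = heads * head_dim
    num_attention_heads=24,
    attention_head_dim=128,
    qk_norm="rms_norm",      # default query/key normalization
    eps=1e-6,                # default norm epsilon
)
```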
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no

     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
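The added `Tuple[torch.Tensor, torch.Tensor]` return annotation makes explicit that this dual-stream block hands back both the text and image streams. A small smoke-test sketch (the shapes are made up, and the text-stream-first ordering of the returned pair is an assumption to verify against the block body):

```python
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(dim=128, num_attention_heads=4, attention_head_dim=32)
img = torch.randn(1, 16, 128)  # image stream: (batch, image_seq_len, dim)
txt = torch.randn(1, 8, 128)   # text stream:  (batch, text_seq_len, dim)
temb = torch.randn(1, 128)     # conditioning embedding fed to AdaLayerNormZero

txt, img = block(hidden_states=img, encoder_hidden_states=txt, temb=temb)
```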
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(

     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/

-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """

     _supports_gradient_checkpointing = True
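Since all of these arguments pass through `@register_to_config`, instantiating a shrunken configuration is the usual smoke test. A hypothetical tiny config (the values are illustrative only; `axes_dims_rope` is assumed to need to sum to `attention_head_dim`, matching the documented defaults where 16 + 56 + 56 = 128):

```python
from diffusers import FluxTransformer2DModel

tiny = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,              # dual stream blocks
    num_single_layers=1,       # single stream blocks
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    guidance_embeds=False,
    axes_dims_rope=(2, 2, 4),  # sums to attention_head_dim
)
```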
@@ -259,39 +280,39 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim

         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )

-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)

         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )

         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )

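The switch from `self.config.*` to the plain constructor arguments is behavior-preserving: diffusers' `@register_to_config` decorator records the arguments on `self.config` before the `__init__` body runs, so the local names and the config entries hold identical values at this point, and the locals skip the config lookup. A minimal sketch of that equivalence (the `Toy` class and its values are illustrative only):

```python
from diffusers.configuration_utils import ConfigMixin, register_to_config


class Toy(ConfigMixin):
    config_name = "toy_config.json"

    @register_to_config
    def __init__(self, num_attention_heads: int = 24, attention_head_dim: int = 128):
        # The decorator populates self.config before this body runs,
        # so the local argument and the config entry agree.
        assert num_attention_heads == self.config.num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim


print(Toy().inner_dim)  # 3072
```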
@@ -418,16 +439,16 @@ def forward(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.

         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
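The corrected shapes document that this forward consumes packed latent sequences, not `(batch, channel, height, width)` images. Hypothetical inputs at the documented default widths (the sequence lengths are made up for illustration; a real call also passes the positional id tensors consumed by `FluxPosEmbed`):

```python
import torch

batch = 1
hidden_states = torch.randn(batch, 4096, 64)           # (batch, image_sequence_length, in_channels)
encoder_hidden_states = torch.randn(batch, 512, 4096)  # (batch, text_sequence_length, joint_attention_dim)
pooled_projections = torch.randn(batch, 768)           # (batch, pooled_projection_dim)
timestep = torch.tensor([1])                           # denoising step, per the docstring a LongTensor
```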