@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
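
Note on the hunk above: the tightened annotations make the single-stream block's contract explicit — `image_rotary_emb` is an optional `(cos, sin)` pair and the block returns one tensor. A minimal call sketch follows; the dimensions (24 heads x 128 head dim) mirror the defaults documented further down, and the sequence length is a made-up illustration value.

import torch

from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock  # module path may differ across diffusers versions

# Assumed sizes for illustration only.
block = FluxSingleTransformerBlock(dim=24 * 128, num_attention_heads=24, attention_head_dim=128)
hidden_states = torch.randn(1, 4608, 24 * 128)  # concatenated text + image tokens, hypothetical length
temb = torch.randn(1, 24 * 128)                 # time/text conditioning embedding
out = block(hidden_states, temb)                # a single torch.Tensor under the new return annotation
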
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
 
     Reference: https://arxiv.org/abs/2403.03206
 
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
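
The dual-stream block's forward now advertises a two-tensor return. A sketch of constructing and calling it with assumed shapes; the unpacking order assumes the context stream comes back first, matching how the model's block loop consumes it.

import torch

from diffusers.models.transformers.transformer_flux import FluxTransformerBlock  # module path may differ across diffusers versions

block = FluxTransformerBlock(dim=24 * 128, num_attention_heads=24, attention_head_dim=128, qk_norm="rms_norm", eps=1e-6)
hidden_states = torch.randn(1, 4096, 24 * 128)         # image tokens (assumed length)
encoder_hidden_states = torch.randn(1, 512, 24 * 128)  # projected text tokens (assumed length)
temb = torch.randn(1, 24 * 128)

# Two tensors come back per the new `Tuple[torch.Tensor, torch.Tensor]` annotation.
encoder_hidden_states, hidden_states = block(
    hidden_states=hidden_states,
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
)
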
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
 
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
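
For reference, the defaults the rewritten docstring documents can be restated directly when constructing the model. The sketch below does exactly that (it builds a randomly initialized, full-size network, so treat it as illustrative); the per-value comments are interpretive assumptions, not part of the diff.

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=64,             # packed latents: 16 VAE channels x 2 x 2 patchified
    num_layers=19,              # dual-stream (MMDiT-style) blocks
    num_single_layers=38,       # single-stream blocks
    attention_head_dim=128,
    num_attention_heads=24,     # inner_dim = 24 * 128 = 3072
    joint_attention_dim=4096,   # hidden size of the T5 text-encoder features
    pooled_projection_dim=768,  # CLIP pooled embedding size
    guidance_embeds=False,      # True for the guidance-distilled checkpoints
    axes_dims_rope=(16, 56, 56),
)
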
@@ -259,39 +280,39 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
 
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
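
The `self.config.*` → plain-argument change inside `__init__` is behavior-preserving for users: assuming the usual `@register_to_config` decorator on this `__init__` (standard for diffusers models), the same values are still recorded on `model.config`; the diff merely stops round-tripping through the config object while the module is being built. A small sanity sketch under that assumption, shrinking the depth so construction stays cheap:

from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(num_layers=1, num_single_layers=1)  # tiny depth, other defaults kept

# The constructor arguments and the registered config agree, so the rewrite is equivalent.
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim
assert len(model.transformer_blocks) == model.config.num_layers
assert len(model.single_transformer_blocks) == model.config.num_single_layers
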
@@ -418,16 +439,16 @@ def forward(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep (`torch.LongTensor`):
                 Used to indicate denoising step.
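
The shape corrections above matter: Flux consumes packed latent sequences, not `(B, C, H, W)` feature maps. A sketch of inputs matching the documented shapes, with assumed sequence lengths (a 1024x1024 image packs to 4096 latent tokens in the Flux pipelines, and T5 prompt embeddings are commonly padded to 512 tokens):

import torch

batch_size = 1
hidden_states = torch.randn(batch_size, 4096, 64)           # (batch_size, image_sequence_length, in_channels)
encoder_hidden_states = torch.randn(batch_size, 512, 4096)  # (batch_size, text_sequence_length, joint_attention_dim)
pooled_projections = torch.randn(batch_size, 768)           # (batch_size, projection_dim)

These tensors are handed to `FluxTransformer2DModel.forward` together with `timestep` and the positional id tensors (`img_ids`, `txt_ids`) that feed the rotary embedding.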