
Commit 476795c

Update Flux docstrings (#10423)

1 parent 3cb6686 commit 476795c

File tree

1 file changed: +63 -42 lines changed

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 63 additions & 42 deletions
@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
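
For orientation, here is a minimal sketch of calling the newly annotated `forward`. The sizes below are my own toy values (the production Flux single block uses `dim=3072` with 24 heads of 128 dimensions), and the optional `image_rotary_emb`/`joint_attention_kwargs` are left at their `None` defaults:

# Hedged sketch: toy dimensions, not the real Flux configuration.
import torch
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

block = FluxSingleTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
hidden_states = torch.randn(1, 32, 64)  # (batch_size, sequence_length, dim)
temb = torch.randn(1, 64)               # conditioning embedding from the time/text embedder
out = block(hidden_states, temb)        # rotary embeddings and extra attention kwargs omitted
print(out.shape)                        # torch.Size([1, 32, 64]) -- a single torch.Tensor, per the new annotation
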
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
 
     Reference: https://arxiv.org/abs/2403.03206
 
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
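
As a quick illustration of the newly documented arguments (toy sizes of my choosing; the real Flux block uses dim = 24 * 128 = 3072), the `qk_norm`/`eps` defaults show up as RMS normalization layers on the block's attention module:

# Hedged sketch: inspecting what qk_norm="rms_norm" and eps=1e-6 control.
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
print(block.attn.norm_q)  # RMS norm applied to per-head query features
print(block.attn.norm_k)  # RMS norm applied to per-head key features
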
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
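
And a corresponding sketch of the dual-stream forward with its new tuple return annotation (again toy shapes of my choosing, with the optional rotary embeddings left unset):

# Hedged sketch: both streams must share the block's embedding dimension `dim`.
import torch
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock

block = FluxTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=16)
hidden_states = torch.randn(1, 32, 64)         # image stream: (batch_size, image_seq_len, dim)
encoder_hidden_states = torch.randn(1, 8, 64)  # text stream: (batch_size, text_seq_len, dim)
temb = torch.randn(1, 64)                      # conditioning embedding
encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
# Tuple[torch.Tensor, torch.Tensor]: the updated text stream is returned first, then the image stream.
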
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
 
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
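
The documented defaults describe the released Flux transformer and would build a multi-billion-parameter model, so the sketch below substitutes tiny values of my own to show how the arguments fit together; I keep `axes_dims_rope` summing to `attention_head_dim`, just as the documented defaults do (16 + 56 + 56 = 128):

# Hedged sketch: toy configuration, not the documented Flux defaults.
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    out_channels=None,         # falls back to in_channels, as the new docstring notes
    num_layers=1,              # dual stream blocks
    num_single_layers=1,       # single stream blocks
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=32,
    pooled_projection_dim=16,
    guidance_embeds=False,
    axes_dims_rope=(2, 2, 4),  # sums to attention_head_dim
)
print(model.out_channels)      # 4
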
@@ -259,39 +280,39 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
         )
 
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
         self.single_transformer_blocks = nn.ModuleList(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
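
For readers wondering whether dropping the `self.config.` prefix changes behaviour: as I read it, the constructor is decorated with `@register_to_config`, so each argument is also recorded on `model.config` and the two spellings see the same values inside `__init__`. A small check, with toy sizes again:

# Hedged sketch: inner_dim is the product of the two head arguments either way it is read.
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    in_channels=4, num_layers=1, num_single_layers=1, attention_head_dim=8,
    num_attention_heads=2, joint_attention_dim=32, pooled_projection_dim=16,
    axes_dims_rope=(2, 2, 4),
)
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim == 16
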
@@ -418,16 +439,16 @@ def forward(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
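
To make the new shape conventions concrete, here is a forward-pass sketch on a toy configuration of my own; note that `img_ids` and `txt_ids` are required by the forward signature even though they fall outside the docstring excerpt shown in this hunk:

# Hedged sketch: exercising the documented input shapes on a toy model.
import torch
from diffusers import FluxTransformer2DModel

model = FluxTransformer2DModel(
    in_channels=4, num_layers=1, num_single_layers=1, attention_head_dim=8,
    num_attention_heads=2, joint_attention_dim=32, pooled_projection_dim=16,
    axes_dims_rope=(2, 2, 4),
)
batch_size, image_seq_len, text_seq_len = 1, 16, 8
out = model(
    hidden_states=torch.randn(batch_size, image_seq_len, 4),         # (batch_size, image_sequence_length, in_channels)
    encoder_hidden_states=torch.randn(batch_size, text_seq_len, 32),  # (batch_size, text_sequence_length, joint_attention_dim)
    pooled_projections=torch.randn(batch_size, 16),                   # (batch_size, pooled_projection_dim)
    timestep=torch.tensor([1.0]),
    img_ids=torch.zeros(image_seq_len, 3),                            # per-token position ids for the rotary embedding
    txt_ids=torch.zeros(text_seq_len, 3),
).sample
print(out.shape)  # torch.Size([1, 16, 4]) -- same image sequence length, out_channels == in_channels
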
