diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index dc2eb26f9d30..f5e92700b2f3 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -85,11 +85,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
@@ -117,15 +117,22 @@ class FluxTransformerBlock(nn.Module):
 
     Reference: https://arxiv.org/abs/2403.03206
 
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
+    Args:
+        dim (`int`):
+            The embedding dimension of the block.
+        num_attention_heads (`int`):
+            The number of attention heads to use.
+        attention_head_dim (`int`):
+            The number of dimensions to use for each attention head.
+        qk_norm (`str`, defaults to `"rms_norm"`):
+            The normalization to use for the query and key tensors.
+        eps (`float`, defaults to `1e-6`):
+            The epsilon value to use for the normalization.
     """
 
-    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
+    def __init__(
+        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
+    ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
@@ -164,12 +171,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_no
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        image_rotary_emb=None,
-        joint_attention_kwargs=None,
-    ):
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
 
         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
@@ -227,16 +234,30 @@ class FluxTransformer2DModel(
 
     Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
 
-    Parameters:
-        patch_size (`int`): Patch size to turn the input data into small patches.
-        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
-        num_layers (`int`, *optional*, defaults to 18): The number of layers of MMDiT blocks to use.
-        num_single_layers (`int`, *optional*, defaults to 18): The number of layers of single DiT blocks to use.
-        attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
-        num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
-        joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
-        pooled_projection_dim (`int`): Number of dimensions to use when projecting the `pooled_projections`.
-        guidance_embeds (`bool`, defaults to False): Whether to use guidance embeddings.
+    Args:
+        patch_size (`int`, defaults to `1`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `64`):
+            The number of channels in the input.
+        out_channels (`int`, *optional*, defaults to `None`):
+            The number of channels in the output. If not specified, it defaults to `in_channels`.
+        num_layers (`int`, defaults to `19`):
+            The number of layers of dual stream DiT blocks to use.
+        num_single_layers (`int`, defaults to `38`):
+            The number of layers of single stream DiT blocks to use.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of dimensions to use for each attention head.
+        num_attention_heads (`int`, defaults to `24`):
+            The number of attention heads to use.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The number of dimensions to use for the joint attention (embedding/channel dimension of
+            `encoder_hidden_states`).
+        pooled_projection_dim (`int`, defaults to `768`):
+            The number of dimensions to use for the pooled projection.
+        guidance_embeds (`bool`, defaults to `False`):
+            Whether to use guidance embeddings for guidance-distilled variant of the model.
+        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
+            The dimensions to use for the rotary positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
@@ -259,7 +280,7 @@ def __init__(
     ):
         super().__init__()
         self.out_channels = out_channels or in_channels
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
 
         self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
 
@@ -267,20 +288,20 @@ def __init__(
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
         self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )
 
-        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
-        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
+        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
+        self.x_embedder = nn.Linear(in_channels, self.inner_dim)
 
         self.transformer_blocks = nn.ModuleList(
             [
                 FluxTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )
 
@@ -288,10 +309,10 @@ def __init__(
             [
                 FluxSingleTransformerBlock(
                     dim=self.inner_dim,
-                    num_attention_heads=self.config.num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
                 )
-                for i in range(self.config.num_single_layers)
+                for _ in range(num_single_layers)
             ]
         )
 
@@ -418,16 +439,16 @@ def forward(
         controlnet_single_block_samples=None,
         return_dict: bool = True,
         controlnet_blocks_repeat: bool = False,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                 Input `hidden_states`.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
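As a quick sanity check of the interface documented above, here is a minimal sketch (not part of the patch) that builds a `FluxTransformer2DModel` and runs one forward pass with the tensor shapes from the new docstrings. The tiny hyperparameters are made up purely to keep the module small; the documented defaults describe the full-size Flux model. Two sizing assumptions are carried over from the code rather than the docstrings: `axes_dims_rope` should sum to `attention_head_dim`, and `img_ids`/`txt_ids` use the `(sequence_length, 3)` layout consumed by `FluxPosEmbed`.

```python
import torch

from diffusers import FluxTransformer2DModel

# Deliberately tiny, made-up configuration for a smoke test; the documented defaults
# (num_layers=19, num_single_layers=38, attention_head_dim=128, ...) are far too
# large to instantiate casually.
model = FluxTransformer2DModel(
    patch_size=1,
    in_channels=4,
    num_layers=1,
    num_single_layers=1,
    attention_head_dim=8,
    num_attention_heads=2,
    joint_attention_dim=16,
    pooled_projection_dim=8,
    guidance_embeds=False,
    axes_dims_rope=(2, 2, 4),  # assumption: the three RoPE axis dims sum to attention_head_dim
)

batch_size, image_seq_len, text_seq_len = 1, 16, 8
hidden_states = torch.randn(batch_size, image_seq_len, 4)          # (batch_size, image_sequence_length, in_channels)
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16)  # (batch_size, text_sequence_length, joint_attention_dim)
pooled_projections = torch.randn(batch_size, 8)                    # (batch_size, pooled_projection_dim)
timestep = torch.tensor([1.0])                                     # denoising step, already scaled by the pipeline
img_ids = torch.zeros(image_seq_len, 3)                            # per-token positional ids for the three RoPE axes
txt_ids = torch.zeros(text_seq_len, 3)

with torch.no_grad():
    sample = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        pooled_projections=pooled_projections,
        timestep=timestep,
        img_ids=img_ids,
        txt_ids=txt_ids,
    ).sample

print(sample.shape)  # torch.Size([1, 16, 4]): same sequence layout as the input latents
```

The shape comments mirror the updated docstrings; everything else here (the tiny sizes, the zero ids, the fixed timestep) is illustrative only and does not replace the slow numerical-parity tests for real checkpoints.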