 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..cache_utils import CacheMixin
 
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Parameters:
-        dim (`int`): The number of channels in the input and output.
-        num_attention_heads (`int`): The number of heads to use for multi-head attention.
-        attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
-    """
-
-    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
+    def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
         super().__init__()
         self.mlp_hidden_dim = int(dim * mlp_ratio)
 
@@ -68,9 +54,15 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
         if is_torch_npu_available():
+            deprecation_message = (
+                "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
+                "should be set explicitly using the `set_attn_processor` method."
+            )
+            deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
             processor = FluxAttnProcessor2_0()
+
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -113,39 +105,14 @@ def forward(
 
 @maybe_allow_in_graph
 class FluxTransformerBlock(nn.Module):
-    r"""
-    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
-
-    Reference: https://arxiv.org/abs/2403.03206
-
-    Args:
-        dim (`int`):
-            The embedding dimension of the block.
-        num_attention_heads (`int`):
-            The number of attention heads to use.
-        attention_head_dim (`int`):
-            The number of dimensions to use for each attention head.
-        qk_norm (`str`, defaults to `"rms_norm"`):
-            The normalization to use for the query and key tensors.
-        eps (`float`, defaults to `1e-6`):
-            The epsilon value to use for the normalization.
-    """
-
     def __init__(
         self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
     ):
         super().__init__()
 
         self.norm1 = AdaLayerNormZero(dim)
-
         self.norm1_context = AdaLayerNormZero(dim)
 
-        if hasattr(F, "scaled_dot_product_attention"):
-            processor = FluxAttnProcessor2_0()
-        else:
-            raise ValueError(
-                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
-            )
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -155,7 +122,7 @@ def __init__(
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=processor,
+            processor=FluxAttnProcessor2_0(),
             qk_norm=qk_norm,
             eps=eps,
         )
@@ -166,10 +133,6 @@ def __init__(
         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
         self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
 
-        # let chunk size default to None
-        self._chunk_size = None
-        self._chunk_dim = 0
-
     def forward(
         self,
         hidden_states: torch.Tensor,
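The user-facing effect of the deprecation added above is that callers on NPU should opt into `FluxAttnProcessor2_0_NPU` themselves instead of relying on the block constructor's device-based default. A minimal sketch of that explicit opt-in follows; the `FluxTransformer2DModel` loading call, the checkpoint id, and the processor import paths are assumptions for illustration and are not part of this diff.

```python
# Sketch only: explicit processor selection instead of the deprecated NPU default.
# Assumes a recent diffusers release where these classes live at these import paths.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxAttnProcessor2_0, FluxAttnProcessor2_0_NPU
from diffusers.utils.import_utils import is_torch_npu_available

# Example checkpoint id, shown only to make the sketch end-to-end.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Pick the attention processor explicitly, as the deprecation message suggests,
# rather than letting the block constructor decide based on the device.
processor = FluxAttnProcessor2_0_NPU() if is_torch_npu_available() else FluxAttnProcessor2_0()
transformer.set_attn_processor(processor)
```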