
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import Attention
+from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-class WanAttnProcessor2_0:
+class WanAttnProcessor:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
+
+    def get_qkv_projections(
+        self, attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
+    ):
+        # encoder_hidden_states is only passed for cross-attention
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        if attn.fused_projections:
+            if attn.cross_attention_dim_head is None:
+                # In self-attention layers, we can fuse the entire QKV projection into a single linear
+                query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+            else:
+                # In cross-attention layers, we can only fuse the KV projections into a single linear
+                query = attn.to_q(hidden_states)
+                key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+        else:
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(encoder_hidden_states)
+            value = attn.to_v(encoder_hidden_states)
+        return query, key, value
+
+    def get_added_kv_projections(self, attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
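+        # Projects the image-conditioning (I2V) context through the added K/V layers, using the fused to_added_kv linear when available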
+        if attn.fused_projections:
+            key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+        else:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+        return key_img, value_img

    def __call__(
        self,
-        attn: Attention,
+        attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
-        rotary_emb: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states

-        query = attn.to_q(hidden_states)
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
+        query, key, value = self.get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
@@ -92,9 +116,8 @@ def apply_rotary_emb(
        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
-            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img, value_img = self.get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
-            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
@@ -119,6 +142,119 @@ def apply_rotary_emb(
        return hidden_states


+class WanAttnProcessor2_0:
+    def __new__(cls, *args, **kwargs):
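+        # Returning a WanAttnProcessor instance keeps existing WanAttnProcessor2_0() call sites working while emitting the deprecation warning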
+        deprecation_message = (
+            "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+            "Please use WanAttnProcessor instead."
+        )
+        deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+        return WanAttnProcessor(*args, **kwargs)
+
+
+class WanAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = WanAttnProcessor
+    _available_processors = [WanAttnProcessor]
+
+    def __init__(
+        self,
+        dim: int,
+        heads: int = 8,
+        dim_head: int = 64,
+        eps: float = 1e-5,
+        dropout: float = 0.0,
+        added_kv_proj_dim: Optional[int] = None,
+        cross_attention_dim_head: Optional[int] = None,
+        processor=None,
+    ):
+        super().__init__()
+
+        self.inner_dim = dim_head * heads
+        self.heads = heads
+        self.added_kv_proj_dim = added_kv_proj_dim
+        self.cross_attention_dim_head = cross_attention_dim_head
+        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
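+        # The K/V projection width defaults to the query width; cross-attention layers may use a different per-head dim for K/V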
+
+        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+        self.to_out = torch.nn.ModuleList(
+            [
+                torch.nn.Linear(self.inner_dim, dim, bias=True),
+                torch.nn.Dropout(dropout),
+            ]
+        )
+        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+        self.add_k_proj = self.add_v_proj = None
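+        # The extra K/V projections for the image context only exist in I2V cross-attention layers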
+        if added_kv_proj_dim is not None:
+            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+        self.set_processor(processor)
+
+    def fuse_projections(self):
+        if getattr(self, "fused_projections", False):
+            return
+
+        if self.cross_attention_dim_head is None:
+            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
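+            # Creating the fused linear on the "meta" device allocates no parameter storage;
+            # load_state_dict(..., assign=True) then attaches the concatenated tensors directly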
+            with torch.device("meta"):
+                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+            self.to_qkv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+        else:
+            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        if self.added_kv_proj_dim is not None:
+            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+            out_features, in_features = concatenated_weights.shape
+            with torch.device("meta"):
+                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+            self.to_added_kv.load_state_dict(
+                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+            )
+
+        self.fused_projections = True
+
+    @torch.no_grad()
+    def unfuse_projections(self):
+        if not getattr(self, "fused_projections", False):
+            return
+
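+        # Delete the fused linears and clear the flag so the processor falls back to the separate to_q/to_k/to_v (and add_k_proj/add_v_proj) layers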
+        if hasattr(self, "to_qkv"):
+            delattr(self, "to_qkv")
+        if hasattr(self, "to_kv"):
+            delattr(self, "to_kv")
+        if hasattr(self, "to_added_kv"):
+            delattr(self, "to_added_kv")
+
+        self.fused_projections = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
class WanImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()
@@ -266,33 +402,24 @@ def __init__(

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-        self.attn1 = Attention(
-            query_dim=dim,
+        self.attn1 = WanAttention(
+            dim=dim,
            heads=num_heads,
-            kv_heads=num_heads,
            dim_head=dim // num_heads,
-            qk_norm=qk_norm,
            eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=None,
+            processor=WanAttnProcessor(),
        )

        # 2. Cross-attention
-        self.attn2 = Attention(
-            query_dim=dim,
+        self.attn2 = WanAttention(
+            dim=dim,
            heads=num_heads,
-            kv_heads=num_heads,
            dim_head=dim // num_heads,
-            qk_norm=qk_norm,
            eps=eps,
-            bias=True,
-            cross_attention_dim=None,
-            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
-            added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            cross_attention_dim_head=dim // num_heads,
+            processor=WanAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

@@ -315,12 +442,12 @@ def forward(

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
@@ -333,7 +460,9 @@ def forward(
        return hidden_states


-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class WanTransformer3DModel(
+    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+):
    r"""
    A Transformer model for video-like data used in the Wan model.
