from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
- from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
- from ..attention import FeedForward
- from ..attention_processor import Attention
+ from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
+ from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


- class WanAttnProcessor2_0:
+ def _get_qkv_projections(attn: "WanAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
+     # encoder_hidden_states is only passed for cross-attention
+     if encoder_hidden_states is None:
+         encoder_hidden_states = hidden_states
+
+     if attn.fused_projections:
+         if attn.cross_attention_dim_head is None:
+             # In self-attention layers, we can fuse the entire QKV projection into a single linear
+             query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
+         else:
+             # In cross-attention layers, we can only fuse the KV projections into a single linear
+             query = attn.to_q(hidden_states)
+             key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
+     else:
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+     return query, key, value
+
+
+ def _get_added_kv_projections(attn: "WanAttention", encoder_hidden_states_img: torch.Tensor):
+     if attn.fused_projections:
+         key_img, value_img = attn.to_added_kv(encoder_hidden_states_img).chunk(2, dim=-1)
+     else:
+         key_img = attn.add_k_proj(encoder_hidden_states_img)
+         value_img = attn.add_v_proj(encoder_hidden_states_img)
+     return key_img, value_img
+
+
+ class WanAttnProcessor:
+     _attention_backend = None
+
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+             raise ImportError(
+                 "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+             )

    def __call__(
        self,
-         attn: Attention,
+         attn: "WanAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
-         rotary_emb: Optional[torch.Tensor] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
-         if encoder_hidden_states is None:
-             encoder_hidden_states = hidden_states

-         query = attn.to_q(hidden_states)
-         key = attn.to_k(encoder_hidden_states)
-         value = attn.to_v(encoder_hidden_states)
+         query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)

-         if attn.norm_q is not None:
-             query = attn.norm_q(query)
-         if attn.norm_k is not None:
-             key = attn.norm_k(key)
+         query = attn.norm_q(query)
+         key = attn.norm_k(key)

-         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         query = query.unflatten(2, (attn.heads, -1))
+         key = key.unflatten(2, (attn.heads, -1))
+         value = value.unflatten(2, (attn.heads, -1))

        if rotary_emb is not None:

@@ -77,8 +104,7 @@ def apply_rotary_emb(
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
-                 x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
-                 x1, x2 = x[..., 0], x[..., 1]
+                 x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
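The change above is a pure reshape refactor: `unflatten(-1, (-1, 2))` followed by `unbind(-1)` yields the same even/odd split as the old `view` plus indexing. A minimal standalone check (illustrative only, not part of this diff):

```python
import torch

x = torch.randn(2, 16, 8, 64)  # e.g. (batch, seq_len, heads, head_dim)

# old pattern
x_pairs = x.view(*x.shape[:-1], -1, 2)
x1_old, x2_old = x_pairs[..., 0], x_pairs[..., 1]

# new pattern
x1_new, x2_new = x.unflatten(-1, (-1, 2)).unbind(-1)

assert torch.equal(x1_old, x1_new) and torch.equal(x2_old, x2_new)
```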
@@ -92,23 +118,34 @@ def apply_rotary_emb(
        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
-             key_img = attn.add_k_proj(encoder_hidden_states_img)
+             key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
-             value_img = attn.add_v_proj(encoder_hidden_states_img)
-
-             key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

-             hidden_states_img = F.scaled_dot_product_attention(
-                 query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+             key_img = key_img.unflatten(2, (attn.heads, -1))
+             value_img = value_img.unflatten(2, (attn.heads, -1))
+
+             hidden_states_img = dispatch_attention_fn(
+                 query,
+                 key_img,
+                 value_img,
+                 attn_mask=None,
+                 dropout_p=0.0,
+                 is_causal=False,
+                 backend=self._attention_backend,
            )
-             hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+             hidden_states_img = hidden_states_img.flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

-         hidden_states = F.scaled_dot_product_attention(
-             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         hidden_states = dispatch_attention_fn(
+             query,
+             key,
+             value,
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             is_causal=False,
+             backend=self._attention_backend,
        )
-         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+         hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
@@ -119,6 +156,119 @@ def apply_rotary_emb(
        return hidden_states


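Note that the q/k/v tensors handed to `dispatch_attention_fn` above stay in `(batch, seq_len, heads, head_dim)` layout: the explicit `transpose(1, 2)` calls are gone, and the rotary frequencies are reshaped accordingly further down. For reference, this sketch reproduces only what the removed code path computed with `F.scaled_dot_product_attention`; that the dispatcher's default backend matches it exactly is an assumption here:

```python
import torch
import torch.nn.functional as F

b, s, h, d = 1, 16, 8, 64
q, k, v = (torch.randn(b, s, h, d) for _ in range(3))

# removed path: move heads next to batch for SDPA, then undo it and merge heads
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    attn_mask=None, dropout_p=0.0, is_causal=False,
)
out = out.transpose(1, 2).flatten(2, 3)  # (batch, seq_len, heads * head_dim)
```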
+ class WanAttnProcessor2_0:
+     def __new__(cls, *args, **kwargs):
+         deprecation_message = (
+             "The WanAttnProcessor2_0 class is deprecated and will be removed in a future version. "
+             "Please use WanAttnProcessor instead. "
+         )
+         deprecate("WanAttnProcessor2_0", "1.0.0", deprecation_message, standard_warn=False)
+         return WanAttnProcessor(*args, **kwargs)
+
+
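The shim above keeps the old class name importable while always constructing the new processor. A generic, self-contained illustration of this `__new__`-based aliasing pattern (using `warnings` instead of the diffusers `deprecate` helper):

```python
import warnings


class NewThing:
    pass


class OldThing:
    def __new__(cls, *args, **kwargs):
        warnings.warn("OldThing is deprecated; use NewThing instead.", FutureWarning)
        # Returning an instance of another class skips OldThing.__init__ entirely.
        return NewThing(*args, **kwargs)


assert isinstance(OldThing(), NewThing)
```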
+ class WanAttention(torch.nn.Module, AttentionModuleMixin):
+     _default_processor_cls = WanAttnProcessor
+     _available_processors = [WanAttnProcessor]
+
+     def __init__(
+         self,
+         dim: int,
+         heads: int = 8,
+         dim_head: int = 64,
+         eps: float = 1e-5,
+         dropout: float = 0.0,
+         added_kv_proj_dim: Optional[int] = None,
+         cross_attention_dim_head: Optional[int] = None,
+         processor=None,
+     ):
+         super().__init__()
+
+         self.inner_dim = dim_head * heads
+         self.heads = heads
+         self.added_kv_proj_dim = added_kv_proj_dim
+         self.cross_attention_dim_head = cross_attention_dim_head
+         self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
+
+         self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
+         self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+         self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
+         self.to_out = torch.nn.ModuleList(
+             [
+                 torch.nn.Linear(self.inner_dim, dim, bias=True),
+                 torch.nn.Dropout(dropout),
+             ]
+         )
+         self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+         self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
+
+         self.add_k_proj = self.add_v_proj = None
+         if added_kv_proj_dim is not None:
+             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
+             self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
+
+         self.set_processor(processor)
+
+     def fuse_projections(self):
+         if getattr(self, "fused_projections", False):
+             return
+
+         if self.cross_attention_dim_head is None:
+             concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
+             concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_qkv = nn.Linear(in_features, out_features, bias=True)
+             self.to_qkv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+         else:
+             concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
+             concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_kv = nn.Linear(in_features, out_features, bias=True)
+             self.to_kv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+
+         if self.added_kv_proj_dim is not None:
+             concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
+             concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
+             out_features, in_features = concatenated_weights.shape
+             with torch.device("meta"):
+                 self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
+             self.to_added_kv.load_state_dict(
+                 {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
+             )
+
+         self.fused_projections = True
+
+     @torch.no_grad()
+     def unfuse_projections(self):
+         if not getattr(self, "fused_projections", False):
+             return
+
+         if hasattr(self, "to_qkv"):
+             delattr(self, "to_qkv")
+         if hasattr(self, "to_kv"):
+             delattr(self, "to_kv")
+         if hasattr(self, "to_added_kv"):
+             delattr(self, "to_added_kv")
+
+         self.fused_projections = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, rotary_emb, **kwargs)
+
+
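A minimal, self-contained sketch of the weight-concatenation idea behind `fuse_projections` above: stacking the q/k/v weights and biases along the output dimension and chunking the result reproduces the separate projections (local toy names, not the diffusers API):

```python
import torch
import torch.nn as nn

dim, inner_dim = 64, 128
to_q, to_k, to_v = (nn.Linear(dim, inner_dim, bias=True) for _ in range(3))

# fuse: concatenate weights and biases along the output dimension
to_qkv = nn.Linear(dim, 3 * inner_dim, bias=True)
with torch.no_grad():
    to_qkv.weight.copy_(torch.cat([to_q.weight, to_k.weight, to_v.weight]))
    to_qkv.bias.copy_(torch.cat([to_q.bias, to_k.bias, to_v.bias]))

x = torch.randn(2, 10, dim)
q, k, v = to_qkv(x).chunk(3, dim=-1)

# the fused projection reproduces the separate ones
assert torch.allclose(q, to_q(x), atol=1e-6)
assert torch.allclose(k, to_k(x), atol=1e-6)
assert torch.allclose(v, to_v(x), atol=1e-6)
```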
class WanImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()
@@ -247,8 +397,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

-         freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
-         freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+         freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)
+         freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, ppf * pph * ppw, 1, -1)

        return freqs_cos, freqs_sin

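The reshape above matches the new attention layout: with queries kept as `(batch, seq_len, heads, head_dim)`, frequency tensors shaped `(1, seq_len, 1, head_dim)` broadcast over the heads axis. A quick illustrative shape check:

```python
import torch

b, s, h, d = 1, 24, 8, 64
q = torch.randn(b, s, h, d)          # (batch, seq_len, heads, head_dim)
freqs_cos = torch.randn(1, s, 1, d)  # new shape produced by the reshape above

assert (q * freqs_cos).shape == (b, s, h, d)
```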
@@ -269,33 +419,24 @@ def __init__(

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
-         self.attn1 = Attention(
-             query_dim=dim,
+         self.attn1 = WanAttention(
+             dim=dim,
            heads=num_heads,
-             kv_heads=num_heads,
            dim_head=dim // num_heads,
-             qk_norm=qk_norm,
            eps=eps,
-             bias=True,
-             cross_attention_dim=None,
-             out_bias=True,
-             processor=WanAttnProcessor2_0(),
+             cross_attention_dim_head=None,
+             processor=WanAttnProcessor(),
        )

        # 2. Cross-attention
-         self.attn2 = Attention(
-             query_dim=dim,
+         self.attn2 = WanAttention(
+             dim=dim,
            heads=num_heads,
-             kv_heads=num_heads,
            dim_head=dim // num_heads,
-             qk_norm=qk_norm,
            eps=eps,
-             bias=True,
-             cross_attention_dim=None,
-             out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
-             added_proj_bias=True,
-             processor=WanAttnProcessor2_0(),
+             cross_attention_dim_head=dim // num_heads,
+             processor=WanAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

@@ -332,12 +473,12 @@ def forward(

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+         attn_output = self.attn1(norm_hidden_states, None, None, rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
-         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+         attn_output = self.attn2(norm_hidden_states, encoder_hidden_states, None, None)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
@@ -350,7 +491,9 @@ def forward(
        return hidden_states


- class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+ class WanTransformer3DModel(
+     ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
+ ):
    r"""
    A Transformer model for video-like data used in the Wan model.

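As a usage illustration, a rough smoke test for the new attention module introduced in this diff; the import path and the behavior of the elided processor lines are assumptions based on the hunks shown here, not verified against a released diffusers version:

```python
import torch

# Hypothetical import path for the module added in this diff.
from diffusers.models.transformers.transformer_wan import WanAttention, WanAttnProcessor

attn = WanAttention(dim=1536, heads=12, dim_head=128, processor=WanAttnProcessor())
x = torch.randn(1, 77, 1536)  # (batch, seq_len, dim)

with torch.no_grad():
    baseline = attn(x)
    attn.fuse_projections()  # swaps to_q/to_k/to_v for a single fused to_qkv
    fused = attn(x)

# fused and unfused projections should agree up to float precision
assert torch.allclose(baseline, fused, atol=1e-5)
```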