 from ...utils import logging
 from ..attention import FeedForward
 from ..attention_processor import Attention
-from ..embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection
+from ..embeddings import TimestepEmbedding, Timesteps, PixArtAlphaTextProjection, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
@@ -45,14 +45,8 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        grid_sizes: Optional[torch.Tensor] = None,
-        freqs: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        batch_size, _, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        # i2v task
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             encoder_hidden_states_img = encoder_hidden_states[:, :257]
@@ -69,19 +63,20 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
-        query = query.unflatten(2, (attn.heads, -1))
-        key = key.unflatten(2, (attn.heads, -1))
-        value = value.unflatten(2, (attn.heads, -1))
-
-        if grid_sizes is not None and freqs is not None:
-            query = apply_rotary_emb(query, grid_sizes, freqs)
-            key = apply_rotary_emb(key, grid_sizes, freqs)
-
-        query = query.transpose(1, 2)
-        key = key.transpose(1, 2)
-        value = value.transpose(1, 2)
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
 
-        # i2v task
+        # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
             key_img = attn.add_k_proj(encoder_hidden_states_img)
@@ -111,45 +106,6 @@ def __call__(
         return hidden_states
 
 
-@torch.cuda.amp.autocast(enabled=False)
-def rope_params(max_seq_len, dim, theta=10000):
-    assert dim % 2 == 0
-    freqs = torch.outer(
-        torch.arange(max_seq_len),
-        1.0 / torch.pow(theta,
-                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
-    freqs = torch.polar(torch.ones_like(freqs), freqs)
-    return freqs
-
-
-def apply_rotary_emb(hidden_states: torch.Tensor, grid_sizes, freqs):
-    n, c = hidden_states.size(2), hidden_states.size(3) // 2
-
-    # split freqs
-    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-
-    # loop over samples
-    output = []
-    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
-        seq_len = f * h * w
-
-        # precompute multipliers
-        x_i = torch.view_as_complex(hidden_states[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2))
-        freqs_i = torch.cat([
-            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
-            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
-            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
-        ], dim=-1).reshape(seq_len, 1, -1)
-
-        # apply rotary embedding
-        x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
-        x_i = torch.cat([x_i, hidden_states[i, seq_len:]])
-
-        # append to collection
-        output.append(x_i)
-    return torch.stack(output).type_as(hidden_states)
-
-
 class WanImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int):
         super().__init__()
@@ -188,10 +144,8 @@ def __init__(
 
     def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, encoder_hidden_states_image: Optional[torch.Tensor] = None):
         timestep = self.timesteps_proj(timestep)
-        with torch.amp.autocast(str(encoder_hidden_states.device), dtype=torch.float32):
-            temb = self.time_embedder(timestep)
-            timestep_proj = self.time_proj(self.act_fn(temb))
-            assert temb.dtype == torch.float32 and timestep_proj.dtype == torch.float32
+        temb = self.time_embedder(timestep.type_as(encoder_hidden_states))
+        timestep_proj = self.time_proj(self.act_fn(temb))
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
         if encoder_hidden_states_image is not None:
@@ -200,15 +154,49 @@ def forward(self, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor, e
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 
+class WanRotaryPosEmbed(nn.Module):
+    def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0):
+        super().__init__()
+
+        self.attention_head_dim = attention_head_dim
+        self.patch_size = patch_size
+        self.max_seq_len = max_seq_len
+
+        h_dim = w_dim = 2 * (attention_head_dim // 6)
+        t_dim = attention_head_dim - h_dim - w_dim
+
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64)
+            freqs.append(freq)
+        self.freqs = torch.cat(freqs, dim=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.patch_size
+        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+        self.freqs = self.freqs.to(hidden_states.device)
+        freqs = self.freqs.split_with_sizes(
+            [self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6), self.attention_head_dim // 6, self.attention_head_dim // 6], dim=1
+        )
+
+        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        return freqs
+
+
 class WanTransformerBlock(nn.Module):
     def __init__(self,
-        dim,
-        ffn_dim,
-        num_heads,
-        qk_norm=True,
-        cross_attn_norm=False,
-        eps=1e-6,
-        added_kv_proj_dim=None
+        dim: int,
+        ffn_dim: int,
+        num_heads: int,
+        qk_norm: str = "rms_norm_across_heads",
+        cross_attn_norm: bool = False,
+        eps: float = 1e-6,
+        added_kv_proj_dim: Optional[int] = None
     ):
         super().__init__()
         self.dim = dim
@@ -248,54 +236,37 @@ def __init__(self,
             added_proj_bias=True,
             processor=WanAttnProcessor2_0(),
         )
-
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
-
-        self.ffn = nn.Sequential(
-            nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
-            nn.Linear(ffn_dim, dim)
-        )
+
+        # 3. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
 
         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
-        temb: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        grid_sizes,
-        freqs,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
     ) -> torch.Tensor:
-        assert temb.dtype == torch.float32
-        with torch.amp.autocast(str(temb.device), dtype=torch.float32):
-            temb = (self.scale_shift_table + temb).chunk(6, dim=1)
-        assert temb[0].dtype == torch.float32
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (self.scale_shift_table + temb).chunk(6, dim=1)
 
         # 1. Self-attention
-        attn_hidden_states = (self.norm1(hidden_states.float()) * (1 + temb[1]) + temb[0]).type_as(hidden_states)
-
-        attn_hidden_states = self.attn1(
-            hidden_states=attn_hidden_states,
-            grid_sizes=grid_sizes,
-            freqs=freqs,
-        )
-        hidden_states = (hidden_states.float() + attn_hidden_states.float() * temb[2]).type_as(hidden_states)
+        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        hidden_states = hidden_states + attn_output * gate_msa
 
         # 2. Cross-attention
-        attn_hidden_states = self.norm2(hidden_states)
-        attn_hidden_states = self.attn2(
-            hidden_states=attn_hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            grid_sizes=None,
-            freqs=None,
-        )
-        hidden_states = hidden_states + attn_hidden_states
+        norm_hidden_states = self.norm2(hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
-        ffn_hidden_states = (self.norm3(hidden_states).float() * (1 + temb[4]) + temb[3]).type_as(hidden_states)
-        ffn_hidden_states = self.ffn(ffn_hidden_states)
-        hidden_states = (hidden_states.float() + ffn_hidden_states.float() * temb[5]).type_as(hidden_states)
+        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = hidden_states + ff_output * c_gate_msa
 
         return hidden_states
 
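# --- Illustrative sketch, not part of the diff: the AdaLN-style 6-way modulation that
# --- WanTransformerBlock.forward performs above. Tensor sizes below are dummy values chosen
# --- for demonstration; `temb` here stands in for the per-sample timestep projection.
import torch

dim = 8
scale_shift_table = torch.randn(1, 6, dim) / dim**0.5        # same initialization as the block
temb = torch.randn(2, 6, dim)                                 # assumed shape (batch, 6, dim)
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (scale_shift_table + temb).chunk(6, dim=1)

hidden_states = torch.randn(2, 16, dim)                       # (batch, seq_len, dim)
norm = torch.nn.LayerNorm(dim, elementwise_affine=False)
# Self-attention branch pattern: modulate the normed input, run attention (identity here),
# then gate the residual. Each chunk broadcasts as (batch, 1, dim) over the sequence.
modulated = norm(hidden_states) * (1 + scale_msa) + shift_msa
hidden_states = hidden_states + modulated * gate_msa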
@@ -338,7 +309,7 @@ class WanTransformer3DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _skip_layerwise_casting_patterns = ["patch_embedding", "text_embedding", "time_embedding", "time_projection", "norm"]
+    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock"]
 
     @register_to_config
@@ -358,16 +329,15 @@ def __init__(
         eps: float = 1e-6,
         image_embedding_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
     ) -> None:
         super().__init__()
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
-        self.out_channels = out_channels
-        self.patch_size = patch_size
-
-        # 1. Patch embedding
+        # 1. Patch & position embedding
+        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
 
         # 2. Condition embeddings
@@ -391,14 +361,6 @@ def __init__(
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
-        # buffers (don't use register_buffer otherwise dtype will be changed in to())
-        assert attention_head_dim % 2 == 0
-        self.freqs = torch.cat([
-            rope_params(1024, attention_head_dim - 4 * (attention_head_dim // 6)),
-            rope_params(1024, 2 * (attention_head_dim // 6)),
-            rope_params(1024, 2 * (attention_head_dim // 6))
-        ], dim=1)
-
         self.gradient_checkpointing = False
 
     def forward(
@@ -409,14 +371,15 @@ def forward(
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        if self.freqs.device != hidden_states.device:
-            self.freqs = self.freqs.to(hidden_states.device)
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.config.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        rotary_emb = self.rope(hidden_states)
 
         hidden_states = self.patch_embedding(hidden_states)
-
-        grid_sizes = torch.stack(
-            [torch.tensor(u.shape[1:], dtype=torch.long) for u in hidden_states]
-        )
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
@@ -428,49 +391,21 @@ def forward(
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
-                hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    timestep_proj,
-                    encoder_hidden_states,
-                    grid_sizes,
-                    self.freqs,
-                )
+                hidden_states = self._gradient_checkpointing_func(block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
         else:
             for block in self.blocks:
-                hidden_states = block(
-                    hidden_states,
-                    timestep_proj,
-                    encoder_hidden_states,
-                    grid_sizes,
-                    self.freqs,
-                )
-
-        # Output projection
-        with torch.amp.autocast(str(hidden_states.device), dtype=torch.float32):
-            temb = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-            hidden_states = self.norm_out(hidden_states) * (1 + temb[1]) + temb[0]
-            hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states.type_as(encoder_hidden_states)
-
-        # 5. Unpatchify
-        hidden_states = self.unpatchify(hidden_states, grid_sizes)
-        hidden_states = torch.stack(hidden_states)
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
 
-        if not return_dict:
-            return (hidden_states,)
+        # 5. Output norm, projection & unpatchify
+        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
 
-        return Transformer2DModelOutput(sample=hidden_states)
+        hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1)
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
 
+        if not return_dict:
+            return (output,)
 
-    def unpatchify(self, hidden_states, grid_sizes):
-        c = self.out_channels
-        out = []
-        for u, v in zip(hidden_states, grid_sizes.tolist()):
-            u = u[:math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum('fhwpqrc->cfphqwr', u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
-
+        return Transformer2DModelOutput(sample=output)
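# --- Illustrative sketch, not part of the diff: the rotary-embedding path introduced above,
# --- i.e. factorized (t, h, w) complex frequencies as in WanRotaryPosEmbed, applied by complex
# --- multiplication as in the new WanAttnProcessor2_0. All sizes below are dummy values, and
# --- `complex_freqs` is a local stand-in mirroring get_1d_rotary_pos_embed(..., use_real=False).
import torch

head_dim = 12                                    # split as t_dim + 2 * (head_dim // 6) + 2 * (head_dim // 6)
t_dim = head_dim - 4 * (head_dim // 6)
h_dim = w_dim = 2 * (head_dim // 6)
ppf, pph, ppw = 2, 3, 4                          # post-patch frames / height / width

def complex_freqs(dim, positions, theta=10000.0):
    # polar form e^{i * pos * freq}, one complex number per rotated pair of channels
    inv_freq = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = torch.outer(torch.arange(positions, dtype=torch.float64), inv_freq)
    return torch.polar(torch.ones_like(angles), angles)          # (positions, dim // 2), complex128

freqs_f = complex_freqs(t_dim, ppf).view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
freqs_h = complex_freqs(h_dim, pph).view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
freqs_w = complex_freqs(w_dim, ppw).view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
rotary_emb = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)

# Apply to a dummy query laid out as (batch, heads, seq_len, head_dim), matching the processor.
query = torch.randn(1, 4, ppf * pph * ppw, head_dim)
q_rotated = torch.view_as_complex(query.to(torch.float64).unflatten(3, (-1, 2)))
q_out = torch.view_as_real(q_rotated * rotary_emb).flatten(3, 4).type_as(query)
assert q_out.shape == query.shape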