 import json
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from typing import Tuple, Optional
 from einops import rearrange

 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.models.basic.attention import attention, long_context_attention
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.constants import (
     WAN_DIT_1_3B_T2V_CONFIG_FILE,
     WAN_DIT_14B_I2V_CONFIG_FILE,
     WAN_DIT_14B_T2V_CONFIG_FILE,
 )
-
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.utils.parallel import (
+    get_sp_group,
+    get_sp_world_size,
+    get_sp_rank,
+)


 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
@@ -99,7 +104,21 @@ def forward(self, x, freqs):
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
-        x = attention(q=rope_apply(q, freqs), k=rope_apply(k, freqs), v=v, attn_impl=self.attn_impl).flatten(2)
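+        # When sequence parallelism is enabled, x holds only this rank's slice of the token
+        # sequence; long_context_attention is expected to run attention cooperatively across
+        # the sequence-parallel group so the result still covers the full sequence.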
+        if getattr(self, "use_usp", False):
+            x = long_context_attention(
+                q=rope_apply(q, freqs),
+                k=rope_apply(k, freqs),
+                v=v,
+                attn_impl=self.attn_impl,
+            )
+        else:
+            x = attention(
+                q=rope_apply(q, freqs),
+                k=rope_apply(k, freqs),
+                v=v,
+                attn_impl=self.attn_impl,
+            )
+        x = x.flatten(2)
         return self.o(x)


@@ -259,6 +278,7 @@ def __init__(
         num_layers: int,
         has_image_input: bool,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
         device: str = "cpu",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -301,6 +321,11 @@ def __init__(
         if has_image_input:
             self.img_emb = MLP(1280, dim, device=device, dtype=dtype)  # clip_feature_dim = 1280

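+        # Record the sequence-parallel flag on the model and on every block's self-attention,
+        # so their forward passes select the distributed (long-context) attention path.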
+        if use_usp:
+            setattr(self, "use_usp", True)
+            for block in self.blocks:
+                setattr(block.self_attn, "use_usp", True)
+
     def patchify(self, x: torch.Tensor):
         x = self.patch_embedding(x)  # b c f h w -> b 4c f h/2 w/2
         grid_size = x.shape[2:]
@@ -348,15 +373,34 @@ def forward(
             .reshape(f * h * w, 1, -1)
             .to(x.device)
         )
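+        # Sequence parallelism: split the flattened video tokens and their matching RoPE
+        # frequencies across ranks; the first (s % p) ranks receive one extra token when the
+        # sequence length is not evenly divisible by the sequence-parallel world size.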
+        if getattr(self, "use_usp", False):
+            s, p = x.size(1), get_sp_world_size()  # (sequence_length, parallelism)
+            split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
+            x = torch.split(x, split_size, dim=1)[get_sp_rank()]
+            freqs = torch.split(freqs, split_size, dim=0)[get_sp_rank()]
+
         for block in self.blocks:
             x = block(x, context, t_mod, freqs)
         x = self.head(x, t)
+
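+        # Gather each rank's output shard and restore the full-length sequence before unpatchify.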
+        if getattr(self, "use_usp", False):
+            b, d = x.size(0), x.size(2)  # (batch_size, out_dim)
+            xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
+            dist.all_gather(xs, x, group=get_sp_group())
+            x = torch.concat(xs, dim=1)
         x = self.unpatchify(x, (f, h, w))
         return x

     @classmethod
     def from_state_dict(
-        cls, state_dict, device, dtype, model_type="1.3b-t2v", attn_impl: Optional[str] = None, assign=True
+        cls,
+        state_dict,
+        device,
+        dtype,
+        model_type="1.3b-t2v",
+        attn_impl: Optional[str] = None,
+        use_usp=False,
+        assign=True,
     ):
         if model_type == "1.3b-t2v":
             config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
@@ -367,7 +411,9 @@ def from_state_dict(
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_impl=attn_impl)
+            model = torch.nn.utils.skip_init(
+                cls, **config, device=device, dtype=dtype, attn_impl=attn_impl, use_usp=use_usp
+            )
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=assign)
         model.to(device=device, dtype=dtype)
@@ -377,7 +423,7 @@ def get_tp_plan(self):
         from torch.distributed.tensor.parallel import (
             ColwiseParallel,
             RowwiseParallel,
-            SequenceParallel,
+            PrepareModuleInput,
             PrepareModuleOutput,
         )
         from torch.distributed.tensor import Replicate, Shard
@@ -388,45 +434,64 @@ def get_tp_plan(self):
             "time_embedding.0": ColwiseParallel(),
             "time_embedding.2": RowwiseParallel(),
             "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
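+            # Sequence-parallel boundaries in the TP plan: shard the token sequence on entry to
+            # the first block and gather it back to a replicated tensor after the output head.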
+            "blocks.0": PrepareModuleInput(
+                input_layouts=(Replicate(), None, None, None),
+                desired_input_layouts=(Shard(1), None, None, None),  # sequence parallel
+                use_local_output=True,
+            ),
+            "head": PrepareModuleOutput(
+                output_layouts=Shard(1),
+                desired_output_layouts=Replicate(),
+                use_local_output=True,
+            ),
         }
         for idx in range(len(self.blocks)):
             tp_plan.update(
                 {
-                    f"blocks.{idx}.norm1": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.norm2": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.norm3": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.ffn.0": ColwiseParallel(),
-                    f"blocks.{idx}.ffn.2": RowwiseParallel(),
-                    f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Replicate()),
-                    f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Replicate()),
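+                    # Gather the sequence-sharded hidden states (Shard(1) -> Replicate) before the
+                    # attention projections; the output layouts below re-shard activations on the
+                    # sequence dim so the stream between modules stays sequence-parallel.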
+                    f"blocks.{idx}.self_attn": PrepareModuleInput(
+                        input_layouts=(Shard(1), None),
+                        desired_input_layouts=(Replicate(), None),
+                    ),
+                    f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                    f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.self_attn.v": ColwiseParallel(),
-                    f"blocks.{idx}.self_attn.o": RowwiseParallel(),
+                    f"blocks.{idx}.self_attn.o": RowwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.self_attn.norm_q": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
                     f"blocks.{idx}.self_attn.norm_k": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
-                    f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Replicate()),
-                    f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Replicate()),
+                    f"blocks.{idx}.cross_attn": PrepareModuleInput(
+                        input_layouts=(Shard(1), None),
+                        desired_input_layouts=(Replicate(), None),
+                    ),
+                    f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                    f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.v": ColwiseParallel(),
-                    f"blocks.{idx}.cross_attn.o": RowwiseParallel(),
+                    f"blocks.{idx}.cross_attn.o": RowwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.norm_q": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
                     f"blocks.{idx}.cross_attn.norm_k": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
-                    f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Replicate()),
+                    f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.v_img": ColwiseParallel(),
                     f"blocks.{idx}.cross_attn.norm_k_img": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
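+                    # The FFN likewise gathers the full sequence on entry and re-shards its output
+                    # on the sequence dim, keeping the residual stream sequence-parallel between blocks.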
+                    f"blocks.{idx}.ffn": PrepareModuleInput(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                    ),
+                    f"blocks.{idx}.ffn.0": ColwiseParallel(),
+                    f"blocks.{idx}.ffn.2": RowwiseParallel(output_layouts=Shard(1)),
                 }
             )
         return tp_plan