Skip to content

Commit e3cf908

Browse files
authored
Fix circular dependence (#205)
* add update weights for wan pipeline * fix circular dependence when VSA is enabled * fix * fix circular dependence * up
1 parent ae4faeb commit e3cf908

File tree

4 files changed

+212
-111
lines changed

4 files changed

+212
-111
lines changed

diffsynth_engine/models/basic/video_sparse_attention.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
import functools
44

55
from diffsynth_engine.utils.flag import VIDEO_SPARSE_ATTN_AVAILABLE
6-
from diffsynth_engine.utils.parallel import get_sp_ulysses_group, get_sp_ring_world_size
6+
from diffsynth_engine.utils.process_group import get_sp_ulysses_group, get_sp_ring_world_size
77

8+
9+
vsa_core = None
810
if VIDEO_SPARSE_ATTN_AVAILABLE:
9-
from vsa import video_sparse_attn as vsa_core
11+
try:
12+
from vsa import video_sparse_attn as vsa_core
13+
except Exception:
14+
vsa_core = None
1015

1116
VSA_TILE_SIZE = (4, 4, 4)
1217

@@ -171,6 +176,12 @@ def video_sparse_attn(
171176
variable_block_sizes: torch.LongTensor,
172177
non_pad_index: torch.LongTensor,
173178
):
179+
if vsa_core is None:
180+
raise RuntimeError(
181+
"Video sparse attention (VSA) is not available. "
182+
"Please install the 'vsa' package and ensure all its dependencies (including pytest) are installed."
183+
)
184+
174185
q = tile(q, num_tiles, tile_partition_indices, non_pad_index)
175186
k = tile(k, num_tiles, tile_partition_indices, non_pad_index)
176187
v = tile(v, num_tiles, tile_partition_indices, non_pad_index)
@@ -212,7 +223,8 @@ def distributed_video_sparse_attn(
212223
):
213224
from yunchang.comm.all_to_all import SeqAllToAll4D
214225

215-
assert get_sp_ring_world_size() == 1, "distributed video sparse attention requires ring degree to be 1"
226+
ring_world_size = get_sp_ring_world_size()
227+
assert ring_world_size == 1, "distributed video sparse attention requires ring degree to be 1"
216228
sp_ulysses_group = get_sp_ulysses_group()
217229

218230
q = SeqAllToAll4D.apply(sp_ulysses_group, q, scatter_idx, gather_idx)

diffsynth_engine/pipelines/wan_video.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -650,7 +650,7 @@ def has_any_key(*xs):
650650
dit_type = "wan2.2-i2v-a14b"
651651
elif model_state_dict["high_noise_model"]["patch_embedding.weight"].shape[1] == 16:
652652
dit_type = "wan2.2-t2v-a14b"
653-
elif model_state_dict["patch_embedding.weight"].shape[1] == 48:
653+
elif has_any_key("patch_embedding.weight") and model_state_dict["patch_embedding.weight"].shape[1] == 48:
654654
dit_type = "wan2.2-ti2v-5b"
655655
elif has_any_key("img_emb.emb_pos", "condition_embedder.image_embedder.pos_embed"):
656656
dit_type = "wan2.1-flf2v-14b"
@@ -680,6 +680,30 @@ def has_any_key(*xs):
680680
if config.attn_params is None:
681681
config.attn_params = VideoSparseAttentionParams(sparsity=0.9)
682682

683+
def update_weights(self, state_dicts: WanStateDicts) -> None:
684+
is_dual_model_state_dict = (isinstance(state_dicts.model, dict) and
685+
("high_noise_model" in state_dicts.model or "low_noise_model" in state_dicts.model))
686+
is_dual_model_pipeline = self.dit2 is not None
687+
688+
if is_dual_model_state_dict != is_dual_model_pipeline:
689+
raise ValueError(
690+
f"Model structure mismatch: pipeline has {'dual' if is_dual_model_pipeline else 'single'} model "
691+
f"but state_dict is for {'dual' if is_dual_model_state_dict else 'single'} model. "
692+
f"Cannot update weights between WAN 2.1 (single model) and WAN 2.2 (dual model)."
693+
)
694+
695+
if is_dual_model_state_dict:
696+
if "high_noise_model" in state_dicts.model:
697+
self.update_component(self.dit, state_dicts.model["high_noise_model"], self.config.device, self.config.model_dtype)
698+
if "low_noise_model" in state_dicts.model:
699+
self.update_component(self.dit2, state_dicts.model["low_noise_model"], self.config.device, self.config.model_dtype)
700+
else:
701+
self.update_component(self.dit, state_dicts.model, self.config.device, self.config.model_dtype)
702+
703+
self.update_component(self.text_encoder, state_dicts.t5, self.config.device, self.config.t5_dtype)
704+
self.update_component(self.vae, state_dicts.vae, self.config.device, self.config.vae_dtype)
705+
self.update_component(self.image_encoder, state_dicts.image_encoder, self.config.device, self.config.image_encoder_dtype)
706+
683707
def compile(self):
684708
self.dit.compile_repeated_blocks()
685709
if self.dit2 is not None:

diffsynth_engine/utils/parallel.py

Lines changed: 23 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -21,117 +21,33 @@
2121
import diffsynth_engine.models.basic.attention as attention_ops
2222
from diffsynth_engine.utils.platform import empty_cache
2323
from diffsynth_engine.utils import logging
24+
from diffsynth_engine.utils.process_group import (
25+
PROCESS_GROUP,
26+
get_cfg_group,
27+
get_cfg_world_size,
28+
get_cfg_rank,
29+
get_cfg_ranks,
30+
get_sp_group,
31+
get_sp_world_size,
32+
get_sp_rank,
33+
get_sp_ranks,
34+
get_sp_ulysses_group,
35+
get_sp_ulysses_world_size,
36+
get_sp_ulysses_rank,
37+
get_sp_ulysses_ranks,
38+
get_sp_ring_group,
39+
get_sp_ring_world_size,
40+
get_sp_ring_rank,
41+
get_sp_ring_ranks,
42+
get_tp_group,
43+
get_tp_world_size,
44+
get_tp_rank,
45+
get_tp_ranks,
46+
)
2447

2548
logger = logging.get_logger(__name__)
2649

2750

28-
class Singleton:
29-
_instance = None
30-
31-
def __new__(cls, *args, **kwargs):
32-
if not cls._instance:
33-
cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
34-
return cls._instance
35-
36-
37-
class ProcessGroupSingleton(Singleton):
38-
def __init__(self):
39-
self.CFG_GROUP: Optional[dist.ProcessGroup] = None
40-
self.SP_GROUP: Optional[dist.ProcessGroup] = None
41-
self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
42-
self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
43-
self.TP_GROUP: Optional[dist.ProcessGroup] = None
44-
45-
self.CFG_RANKS: List[int] = []
46-
self.SP_RANKS: List[int] = []
47-
self.SP_ULYSSUES_RANKS: List[int] = []
48-
self.SP_RING_RANKS: List[int] = []
49-
self.TP_RANKS: List[int] = []
50-
51-
52-
PROCESS_GROUP = ProcessGroupSingleton()
53-
54-
55-
def get_cfg_group():
56-
return PROCESS_GROUP.CFG_GROUP
57-
58-
59-
def get_cfg_world_size():
60-
return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
61-
62-
63-
def get_cfg_rank():
64-
return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
65-
66-
67-
def get_cfg_ranks():
68-
return PROCESS_GROUP.CFG_RANKS
69-
70-
71-
def get_sp_group():
72-
return PROCESS_GROUP.SP_GROUP
73-
74-
75-
def get_sp_world_size():
76-
return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
77-
78-
79-
def get_sp_rank():
80-
return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
81-
82-
83-
def get_sp_ranks():
84-
return PROCESS_GROUP.SP_RANKS
85-
86-
87-
def get_sp_ulysses_group():
88-
return PROCESS_GROUP.SP_ULYSSUES_GROUP
89-
90-
91-
def get_sp_ulysses_world_size():
92-
return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
93-
94-
95-
def get_sp_ulysses_rank():
96-
return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
97-
98-
99-
def get_sp_ulysses_ranks():
100-
return PROCESS_GROUP.SP_ULYSSUES_RANKS
101-
102-
103-
def get_sp_ring_group():
104-
return PROCESS_GROUP.SP_RING_GROUP
105-
106-
107-
def get_sp_ring_world_size():
108-
return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
109-
110-
111-
def get_sp_ring_rank():
112-
return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
113-
114-
115-
def get_sp_ring_ranks():
116-
return PROCESS_GROUP.SP_RING_RANKS
117-
118-
119-
def get_tp_group():
120-
return PROCESS_GROUP.TP_GROUP
121-
122-
123-
def get_tp_world_size():
124-
return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
125-
126-
127-
def get_tp_rank():
128-
return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
129-
130-
131-
def get_tp_ranks():
132-
return PROCESS_GROUP.TP_RANKS
133-
134-
13551
def init_parallel_pgs(
13652
cfg_degree: int = 1,
13753
sp_ulysses_degree: int = 1,
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
Process group management for distributed training.
3+
4+
This module provides singleton-based process group management for distributed training,
5+
including support for CFG parallelism, sequence parallelism (Ulysses + Ring), and tensor parallelism.
6+
"""
7+
8+
import torch.distributed as dist
9+
from typing import Optional, List
10+
11+
12+
class Singleton:
13+
_instance = None
14+
15+
def __new__(cls, *args, **kwargs):
16+
if not cls._instance:
17+
cls._instance = super(Singleton, cls).__new__(cls, *args, **kwargs)
18+
return cls._instance
19+
20+
21+
class ProcessGroupSingleton(Singleton):
22+
def __init__(self):
23+
if not hasattr(self, 'initialized'):
24+
self.CFG_GROUP: Optional[dist.ProcessGroup] = None
25+
self.SP_GROUP: Optional[dist.ProcessGroup] = None
26+
self.SP_ULYSSUES_GROUP: Optional[dist.ProcessGroup] = None
27+
self.SP_RING_GROUP: Optional[dist.ProcessGroup] = None
28+
self.TP_GROUP: Optional[dist.ProcessGroup] = None
29+
30+
self.CFG_RANKS: List[int] = []
31+
self.SP_RANKS: List[int] = []
32+
self.SP_ULYSSUES_RANKS: List[int] = []
33+
self.SP_RING_RANKS: List[int] = []
34+
self.TP_RANKS: List[int] = []
35+
36+
self.initialized = True
37+
38+
39+
PROCESS_GROUP = ProcessGroupSingleton()
40+
41+
42+
# CFG parallel group functions
43+
def get_cfg_group():
44+
return PROCESS_GROUP.CFG_GROUP
45+
46+
47+
def get_cfg_world_size():
48+
return PROCESS_GROUP.CFG_GROUP.size() if PROCESS_GROUP.CFG_GROUP is not None else 1
49+
50+
51+
def get_cfg_rank():
52+
return PROCESS_GROUP.CFG_GROUP.rank() if PROCESS_GROUP.CFG_GROUP is not None else 0
53+
54+
55+
def get_cfg_ranks():
56+
return PROCESS_GROUP.CFG_RANKS
57+
58+
59+
# Sequence parallel group functions
60+
def get_sp_group():
61+
return PROCESS_GROUP.SP_GROUP
62+
63+
64+
def get_sp_world_size():
65+
return PROCESS_GROUP.SP_GROUP.size() if PROCESS_GROUP.SP_GROUP is not None else 1
66+
67+
68+
def get_sp_rank():
69+
return PROCESS_GROUP.SP_GROUP.rank() if PROCESS_GROUP.SP_GROUP is not None else 0
70+
71+
72+
def get_sp_ranks():
73+
return PROCESS_GROUP.SP_RANKS
74+
75+
76+
# Sequence parallel Ulysses group functions
77+
def get_sp_ulysses_group():
78+
return PROCESS_GROUP.SP_ULYSSUES_GROUP
79+
80+
81+
def get_sp_ulysses_world_size():
82+
return PROCESS_GROUP.SP_ULYSSUES_GROUP.size() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 1
83+
84+
85+
def get_sp_ulysses_rank():
86+
return PROCESS_GROUP.SP_ULYSSUES_GROUP.rank() if PROCESS_GROUP.SP_ULYSSUES_GROUP is not None else 0
87+
88+
89+
def get_sp_ulysses_ranks():
90+
return PROCESS_GROUP.SP_ULYSSUES_RANKS
91+
92+
93+
# Sequence parallel Ring group functions
94+
def get_sp_ring_group():
95+
return PROCESS_GROUP.SP_RING_GROUP
96+
97+
98+
def get_sp_ring_world_size():
99+
return PROCESS_GROUP.SP_RING_GROUP.size() if PROCESS_GROUP.SP_RING_GROUP is not None else 1
100+
101+
102+
def get_sp_ring_rank():
103+
return PROCESS_GROUP.SP_RING_GROUP.rank() if PROCESS_GROUP.SP_RING_GROUP is not None else 0
104+
105+
106+
def get_sp_ring_ranks():
107+
return PROCESS_GROUP.SP_RING_RANKS
108+
109+
110+
# Tensor parallel group functions
111+
def get_tp_group():
112+
return PROCESS_GROUP.TP_GROUP
113+
114+
115+
def get_tp_world_size():
116+
return PROCESS_GROUP.TP_GROUP.size() if PROCESS_GROUP.TP_GROUP is not None else 1
117+
118+
119+
def get_tp_rank():
120+
return PROCESS_GROUP.TP_GROUP.rank() if PROCESS_GROUP.TP_GROUP is not None else 0
121+
122+
123+
def get_tp_ranks():
124+
return PROCESS_GROUP.TP_RANKS
125+
126+
127+
__all__ = [
128+
"PROCESS_GROUP",
129+
"get_cfg_group",
130+
"get_cfg_world_size",
131+
"get_cfg_rank",
132+
"get_cfg_ranks",
133+
"get_sp_group",
134+
"get_sp_world_size",
135+
"get_sp_rank",
136+
"get_sp_ranks",
137+
"get_sp_ulysses_group",
138+
"get_sp_ulysses_world_size",
139+
"get_sp_ulysses_rank",
140+
"get_sp_ulysses_ranks",
141+
"get_sp_ring_group",
142+
"get_sp_ring_world_size",
143+
"get_sp_ring_rank",
144+
"get_sp_ring_ranks",
145+
"get_tp_group",
146+
"get_tp_world_size",
147+
"get_tp_rank",
148+
"get_tp_ranks",
149+
]

0 commit comments

Comments
 (0)