Commit ea6fde6

sequence parallel + FSDP (#29)
* sequence parallel
* fix long_context_attention
* fsdp
* tp
* fix
1 parent 3888f1d commit ea6fde6

File tree

12 files changed: +549 -191 lines changed


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.9.10
+    rev: v0.11.5
     hooks:
       # Run the linter.
       - id: ruff

diffsynth_engine/models/basic/attention.py

Lines changed: 75 additions & 12 deletions
@@ -2,6 +2,9 @@
 import torch.nn as nn
 from einops import rearrange
 from typing import Optional
+from yunchang import LongContextAttention
+from yunchang.kernels import AttnType
+
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
     FLASH_ATTN_3_AVAILABLE,
@@ -12,12 +15,15 @@
     SPARGE_ATTN_AVAILABLE,
 )
 
+logger = logging.get_logger(__name__)
+
+
 if FLASH_ATTN_3_AVAILABLE:
     from flash_attn_interface import flash_attn_func as flash_attn3
 if FLASH_ATTN_2_AVAILABLE:
     from flash_attn import flash_attn_func as flash_attn2
 if XFORMERS_AVAILABLE:
-    import xformers.ops.memory_efficient_attention as xformers_attn
+    from xformers.ops import memory_efficient_attention as xformers_attn
 if SDPA_AVAILABLE:
 
     def sdpa_attn(q, k, v, attn_mask=None, scale=None):
@@ -50,20 +56,28 @@ def sparge_attn(self, q, k, v, attn_mask=None, scale=None):
         return out.transpose(1, 2)
 
 
-logger = logging.get_logger(__name__)
-
-
-def eager_attn(query, key, value, attn_mask=None, scale=None):
-    scale = 1 / query.shape[-1] ** 0.5 if scale is None else scale
-    query = query * scale
-    attn = torch.matmul(query, key.transpose(-2, -1))
+def eager_attn(q, k, v, attn_mask=None, scale=None):
+    q = q.transpose(1, 2)
+    k = k.transpose(1, 2)
+    v = v.transpose(1, 2)
+    scale = 1 / q.shape[-1] ** 0.5 if scale is None else scale
+    q = q * scale
+    attn = torch.matmul(q, k.transpose(-2, -1))
     if attn_mask is not None:
         attn = attn + attn_mask
     attn = attn.softmax(-1)
-    return attn @ value
-
-
-def attention(q, k, v, attn_mask=None, attn_impl: Optional[str] = None, scale: Optional[float] = None):
+    out = attn @ v
+    return out.transpose(1, 2)
+
+
+def attention(
+    q,
+    k,
+    v,
+    attn_impl: Optional[str] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
     """
     q: [B, Lq, Nq, C1]
     k: [B, Lk, Nk, C1]
@@ -152,3 +166,52 @@ def forward(
         out = attention(q, k, v, attn_mask=attn_mask, attn_impl=self.attn_impl, scale=self.scale)
         out = rearrange(out, "b s n d -> b s (n d)", n=self.num_heads)
         return self.to_out(out)
+
+
+def long_context_attention(
+    q,
+    k,
+    v,
+    attn_impl: Optional[str] = None,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+):
+    """
+    q: [B, Lq, Nq, C1]
+    k: [B, Lk, Nk, C1]
+    v: [B, Lk, Nk, C2]
+    """
+    assert attn_impl in [
+        None,
+        "auto",
+        "eager",
+        "flash_attn_2",
+        "flash_attn_3",
+        "xformers",
+        "sdpa",
+        "sage_attn",
+        "sparge_attn",
+    ]
+    if attn_impl is None or attn_impl == "auto":
+        if FLASH_ATTN_3_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.FA3)
+        elif FLASH_ATTN_2_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.FA)
+        elif SDPA_AVAILABLE:
+            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+        else:
+            raise ValueError("No available long context attention implementation")
+    else:
+        if attn_impl == "flash_attn_3":
+            attn_func = LongContextAttention(attn_type=AttnType.FA3)
+        elif attn_impl == "flash_attn_2":
+            attn_func = LongContextAttention(attn_type=AttnType.FA)
+        elif attn_impl == "sdpa":
+            attn_func = LongContextAttention(attn_type=AttnType.TORCH)
+        elif attn_impl == "sage_attn":
+            attn_func = LongContextAttention(attn_type=AttnType.SAGE_FP8)
+        elif attn_impl == "sparge_attn":
+            attn_func = LongContextAttention(attn_type=AttnType.SPARSE_SAGE)
+        else:
+            raise ValueError(f"Invalid long context attention implementation: {attn_impl}")
+    return attn_func(q, k, v, softmax_scale=scale)
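
Note: both `attention` and the new `long_context_attention` keep the `[B, L, N, C]` layout documented in their docstrings; `long_context_attention` additionally assumes yunchang's sequence-parallel process group has been set up beforehand. A minimal single-device usage sketch of the non-parallel entry point, assuming `diffsynth_engine` is importable and the SDPA backend is flagged as available (tensor sizes are made up for illustration):

import torch

from diffsynth_engine.models.basic.attention import attention

# q/k/v follow the documented [B, L, N, C] layout; the sizes below are arbitrary.
B, L, N, C = 1, 16, 8, 64
q = torch.randn(B, L, N, C)
k = torch.randn(B, L, N, C)
v = torch.randn(B, L, N, C)

out = attention(q, k, v, attn_impl="sdpa")  # or leave attn_impl unset to let the module pick a backend
print(out.shape)  # torch.Size([1, 16, 8, 64])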

diffsynth_engine/models/wan/wan_dit.py

Lines changed: 88 additions & 23 deletions
@@ -2,19 +2,24 @@
 import json
 import torch
 import torch.nn as nn
+import torch.distributed as dist
 from typing import Tuple, Optional
 from einops import rearrange
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
+from diffsynth_engine.models.basic.attention import attention, long_context_attention
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.constants import (
     WAN_DIT_1_3B_T2V_CONFIG_FILE,
     WAN_DIT_14B_I2V_CONFIG_FILE,
     WAN_DIT_14B_T2V_CONFIG_FILE,
 )
-
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.utils.parallel import (
+    get_sp_group,
+    get_sp_world_size,
+    get_sp_rank,
+)
 
 
 def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
@@ -99,7 +104,21 @@ def forward(self, x, freqs):
         q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
         k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
         v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
-        x = attention(q=rope_apply(q, freqs), k=rope_apply(k, freqs), v=v, attn_impl=self.attn_impl).flatten(2)
+        if getattr(self, "use_usp", False):
+            x = long_context_attention(
+                q=rope_apply(q, freqs),
+                k=rope_apply(k, freqs),
+                v=v,
+                attn_impl=self.attn_impl,
+            )
+        else:
+            x = attention(
+                q=rope_apply(q, freqs),
+                k=rope_apply(k, freqs),
+                v=v,
+                attn_impl=self.attn_impl,
+            )
+        x = x.flatten(2)
         return self.o(x)
 
 
@@ -259,6 +278,7 @@ def __init__(
         num_layers: int,
         has_image_input: bool,
         attn_impl: Optional[str] = None,
+        use_usp: bool = False,
         device: str = "cpu",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -301,6 +321,11 @@ def __init__(
         if has_image_input:
             self.img_emb = MLP(1280, dim, device=device, dtype=dtype)  # clip_feature_dim = 1280
 
+        if use_usp:
+            setattr(self, "use_usp", True)
+            for block in self.blocks:
+                setattr(block.self_attn, "use_usp", True)
+
     def patchify(self, x: torch.Tensor):
         x = self.patch_embedding(x)  # b c f h w -> b 4c f h/2 w/2
         grid_size = x.shape[2:]
@@ -348,15 +373,34 @@ def forward(
             .reshape(f * h * w, 1, -1)
             .to(x.device)
         )
+        if getattr(self, "use_usp", False):
+            s, p = x.size(1), get_sp_world_size()  # (sequence_length, parallelism)
+            split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
+            x = torch.split(x, split_size, dim=1)[get_sp_rank()]
+            freqs = torch.split(freqs, split_size, dim=0)[get_sp_rank()]
+
         for block in self.blocks:
             x = block(x, context, t_mod, freqs)
         x = self.head(x, t)
+
+        if getattr(self, "use_usp", False):
+            b, d = x.size(0), x.size(2)  # (batch_size, out_dim)
+            xs = [torch.zeros((b, s, d), dtype=x.dtype, device=x.device) for s in split_size]
+            dist.all_gather(xs, x, group=get_sp_group())
+            x = torch.concat(xs, dim=1)
         x = self.unpatchify(x, (f, h, w))
         return x
 
     @classmethod
     def from_state_dict(
-        cls, state_dict, device, dtype, model_type="1.3b-t2v", attn_impl: Optional[str] = None, assign=True
+        cls,
+        state_dict,
+        device,
+        dtype,
+        model_type="1.3b-t2v",
+        attn_impl: Optional[str] = None,
+        use_usp=False,
+        assign=True,
     ):
         if model_type == "1.3b-t2v":
             config = json.load(open(WAN_DIT_1_3B_T2V_CONFIG_FILE, "r"))
@@ -367,7 +411,9 @@ def from_state_dict(
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
         with no_init_weights():
-            model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_impl=attn_impl)
+            model = torch.nn.utils.skip_init(
+                cls, **config, device=device, dtype=dtype, attn_impl=attn_impl, use_usp=use_usp
+            )
             model = model.requires_grad_(False)
             model.load_state_dict(state_dict, assign=assign)
             model.to(device=device, dtype=dtype)
@@ -377,7 +423,7 @@ def get_tp_plan(self):
         from torch.distributed.tensor.parallel import (
             ColwiseParallel,
             RowwiseParallel,
-            SequenceParallel,
+            PrepareModuleInput,
             PrepareModuleOutput,
         )
         from torch.distributed.tensor import Replicate, Shard
@@ -388,45 +434,64 @@
             "time_embedding.0": ColwiseParallel(),
             "time_embedding.2": RowwiseParallel(),
             "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
+            "blocks.0": PrepareModuleInput(
+                input_layouts=(Replicate(), None, None, None),
+                desired_input_layouts=(Shard(1), None, None, None),  # sequence parallel
+                use_local_output=True,
+            ),
+            "head": PrepareModuleOutput(
+                output_layouts=Shard(1),
+                desired_output_layouts=Replicate(),
+                use_local_output=True,
+            ),
         }
         for idx in range(len(self.blocks)):
             tp_plan.update(
                 {
-                    f"blocks.{idx}.norm1": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.norm2": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.norm3": SequenceParallel(use_local_output=True),
-                    f"blocks.{idx}.ffn.0": ColwiseParallel(),
-                    f"blocks.{idx}.ffn.2": RowwiseParallel(),
-                    f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Replicate()),
-                    f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Replicate()),
+                    f"blocks.{idx}.self_attn": PrepareModuleInput(
+                        input_layouts=(Shard(1), None),
+                        desired_input_layouts=(Replicate(), None),
+                    ),
+                    f"blocks.{idx}.self_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                    f"blocks.{idx}.self_attn.k": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.self_attn.v": ColwiseParallel(),
-                    f"blocks.{idx}.self_attn.o": RowwiseParallel(),
+                    f"blocks.{idx}.self_attn.o": RowwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.self_attn.norm_q": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
                     f"blocks.{idx}.self_attn.norm_k": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
-                    f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Replicate()),
-                    f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Replicate()),
+                    f"blocks.{idx}.cross_attn": PrepareModuleInput(
+                        input_layouts=(Shard(1), None),
+                        desired_input_layouts=(Replicate(), None),
+                    ),
+                    f"blocks.{idx}.cross_attn.q": ColwiseParallel(output_layouts=Shard(1)),
+                    f"blocks.{idx}.cross_attn.k": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.v": ColwiseParallel(),
-                    f"blocks.{idx}.cross_attn.o": RowwiseParallel(),
+                    f"blocks.{idx}.cross_attn.o": RowwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.norm_q": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
                     f"blocks.{idx}.cross_attn.norm_k": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
-                    f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Replicate()),
+                    f"blocks.{idx}.cross_attn.k_img": ColwiseParallel(output_layouts=Shard(1)),
                     f"blocks.{idx}.cross_attn.v_img": ColwiseParallel(),
                     f"blocks.{idx}.cross_attn.norm_k_img": PrepareModuleOutput(
-                        output_layouts=Replicate(),
+                        output_layouts=Shard(1),
                         desired_output_layouts=Shard(-1),
                     ),
+                    f"blocks.{idx}.ffn": PrepareModuleInput(
+                        input_layouts=(Shard(1),),
+                        desired_input_layouts=(Replicate(),),
+                    ),
+                    f"blocks.{idx}.ffn.0": ColwiseParallel(),
+                    f"blocks.{idx}.ffn.2": RowwiseParallel(output_layouts=Shard(1)),
                 }
            )
        return tp_plan
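
Note on the sequence-parallel path added to `forward`: the token dimension is split unevenly across the `p` sequence-parallel ranks (the first `s % p` ranks take one extra token), each rank runs the transformer blocks on its shard, and the shards are reassembled with `dist.all_gather`. A minimal single-process sketch of just that split/reassemble arithmetic, with made-up sizes:

import torch

# Made-up sizes: s tokens split across p sequence-parallel ranks.
s, p = 10, 4
split_size = [s // p + 1 if i < s % p else s // p for i in range(p)]
print(split_size)                                   # [3, 3, 2, 2] -- sums back to s

x = torch.randn(1, s, 16)                           # [batch, seq, dim]
shards = torch.split(x, split_size, dim=1)          # rank i would keep shards[i]
assert torch.equal(torch.concat(shards, dim=1), x)  # concatenating the shards restores the full sequence

In the distributed code the concatenation step is performed across processes by `dist.all_gather` into per-rank buffers sized according to `split_size`.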

diffsynth_engine/pipelines/base.py

Lines changed: 2 additions & 2 deletions
@@ -205,12 +205,12 @@ def load_models_to_device(self, load_model_names: List[str] | None = None):
         for model_name in self.model_names:
             if model_name not in load_model_names:
                 model = getattr(self, model_name)
-                if model is not None and next(model.parameters()).device != "cpu":
+                if model is not None and (p := next(model.parameters(), None)) is not None and p.device != "cpu":
                     model.to("cpu")
         # load the needed models to device
         for model_name in load_model_names:
             model = getattr(self, model_name)
-            if model is not None and next(model.parameters()).device != self.device:
+            if model is not None and (p := next(model.parameters(), None)) is not None and p.device != self.device:
                 model.to(self.device)
         # fresh the cuda cache
         torch.cuda.empty_cache()
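
This change guards against registered models that expose no parameters: `next(model.parameters())` raises `StopIteration` on an empty iterator, while `next(model.parameters(), None)` returns `None`, so the walrus-bound `p` can be checked before touching `p.device`. A tiny illustration (the bare `nn.Module()` stands in for any parameter-less model):

import torch.nn as nn

empty = nn.Module()                    # a module with no registered parameters
linear = nn.Linear(4, 4)

print(next(empty.parameters(), None))  # None -- the old next(empty.parameters()) would raise StopIteration
if (p := next(linear.parameters(), None)) is not None:
    print(p.device)                    # cpu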
