Commit b39dbd6

parallel wrapper for pipeline/model (#102)

* parallel wrapper for pipeline/model
* fix error msg

1 parent bb38e70 commit b39dbd6

File tree: 12 files changed (+250, -212 lines)


diffsynth_engine/models/base.py

Lines changed: 7 additions & 0 deletions
@@ -14,6 +14,7 @@ def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor

 class PreTrainedModel(nn.Module):
     converter = StateDictConverter()
+    _supports_parallelization = False

     def load_state_dict(self, state_dict: Dict[str, torch.Tensor], strict: bool = True, assign: bool = False):
         state_dict = self.converter.convert(state_dict)
@@ -55,6 +56,12 @@ def unload_loras(self):
             if isinstance(module, (LoRALinear, LoRAConv2d)):
                 module.clear()

+    def get_tp_plan(self):
+        raise NotImplementedError(f"{self.__class__.__name__} does not support TP")
+
+    def get_fsdp_modules(self):
+        raise NotImplementedError(f"{self.__class__.__name__} does not support FSDP")
+

 def split_suffix(name: str):
     suffix_list = [
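
The two hooks above define the contract a model must satisfy before the new parallel wrapper can shard it. A minimal sketch of a subclass opting in (not part of this commit; the TP-plan mapping below is a hypothetical placeholder, while the FSDP list mirrors what FluxDiT and WanDiT return further down):

# Illustrative sketch only -- not from this commit. The TP-plan keys/values are
# hypothetical placeholders; the real format is whatever
# diffsynth_engine.utils.parallel expects.
from diffsynth_engine.models.base import PreTrainedModel


class MyDiT(PreTrainedModel):
    _supports_parallelization = True  # base class defaults to False

    def get_tp_plan(self):
        # Placeholder mapping for illustration only.
        return {"blocks.*.attn": "<tensor-parallel style>"}

    def get_fsdp_modules(self):
        # Attribute names of the transformer-block containers FSDP should wrap.
        return ["blocks"]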

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 31 additions & 11 deletions
@@ -18,7 +18,12 @@
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
-from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
+from diffsynth_engine.utils.parallel import (
+    cfg_parallel,
+    cfg_parallel_unshard,
+    sequence_parallel,
+    sequence_parallel_unshard,
+)
 from diffsynth_engine.utils import logging

@@ -323,12 +328,12 @@ def forward(self, x, t_emb, rope_emb, image_emb=None):

 class FluxDiT(PreTrainedModel):
     converter = FluxDiTStateDictConverter()
+    _supports_parallelization = True

     def __init__(
         self,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
-        use_usp: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -354,8 +359,6 @@ def __init__(
         self.final_norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.final_proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)

-        self.use_usp = use_usp
-
     def patchify(self, hidden_states):
         hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
         return hidden_states
@@ -398,6 +401,8 @@ def forward(
         **kwargs,
     ):
         h, w = hidden_states.shape[-2:]
+        if image_ids is None:
+            image_ids = self.prepare_image_ids(hidden_states)
         controlnet_double_block_output = (
             controlnet_double_block_output if controlnet_double_block_output is not None else ()
         )
@@ -406,10 +411,24 @@ def forward(
         )

         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
-        with fp8_inference(fp8_linear_enabled), gguf_inference():
-            if image_ids is None:
-                image_ids = self.prepare_image_ids(hidden_states)
-
+        with (
+            fp8_inference(fp8_linear_enabled),
+            gguf_inference(),
+            cfg_parallel(
+                (
+                    hidden_states,
+                    timestep,
+                    prompt_emb,
+                    pooled_prompt_emb,
+                    image_emb,
+                    guidance,
+                    text_ids,
+                    image_ids,
+                    *controlnet_double_block_output,
+                    *controlnet_single_block_output,
+                )
+            ),
+        ):
             # warning: keep the order of time_embedding + guidance_embedding + pooled_text_embedding
             # addition of floating point numbers does not meet commutative law
             conditioning = self.time_embedder(timestep, hidden_states.dtype)
@@ -439,7 +458,6 @@ def forward(
                    *(1 for _ in controlnet_double_block_output),
                    *(1 for _ in controlnet_single_block_output),
                ),
-                enabled=self.use_usp,
            ):
                hidden_states = self.x_embedder(hidden_states)
                prompt_emb = self.context_embedder(prompt_emb)
@@ -465,6 +483,7 @@ def forward(
                (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))

            hidden_states = self.unpatchify(hidden_states, h, w)
+            (hidden_states,) = cfg_parallel_unshard((hidden_states,))
            return hidden_states

    @classmethod
@@ -475,7 +494,6 @@ def from_state_dict(
         dtype: torch.dtype,
         in_channel: int = 64,
         attn_impl: Optional[str] = None,
-        use_usp: bool = False,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
@@ -484,9 +502,11 @@ def from_state_dict(
                 dtype=dtype,
                 in_channel=in_channel,
                 attn_impl=attn_impl,
-                use_usp=use_usp,
             )
         model = model.requires_grad_(False)  # for loading gguf
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
         return model
+
+    def get_fsdp_modules(self):
+        return ["blocks", "single_blocks"]

diffsynth_engine/models/wan/wan_dit.py

Lines changed: 17 additions & 10 deletions
@@ -16,7 +16,12 @@
     WAN_DIT_14B_FLF2V_CONFIG_FILE,
 )
 from diffsynth_engine.utils.gguf import gguf_inference
-from diffsynth_engine.utils.parallel import sequence_parallel, sequence_parallel_unshard
+from diffsynth_engine.utils.parallel import (
+    cfg_parallel,
+    cfg_parallel_unshard,
+    sequence_parallel,
+    sequence_parallel_unshard,
+)

 T5_TOKEN_NUM = 512
 FLF_TOKEN_NUM = 257 * 2
@@ -244,6 +249,7 @@ def convert(self, state_dict):

 class WanDiT(PreTrainedModel):
     converter = WanDiTStateDictConverter()
+    _supports_parallelization = True

     def __init__(
         self,
@@ -260,7 +266,6 @@ def __init__(
         has_image_input: bool,
         flf_pos_emb: bool = False,
         attn_impl: Optional[str] = None,
-        use_usp: bool = False,
         device: str = "cpu",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -303,8 +308,6 @@ def __init__(
         if has_image_input:
             self.img_emb = MLP(1280, dim, flf_pos_emb, device=device, dtype=dtype)  # clip_feature_dim = 1280

-        self.use_usp = use_usp
-
     def patchify(self, x: torch.Tensor):
         x = self.patch_embedding(x)  # b c f h w -> b 4c f h/2 w/2
         grid_size = x.shape[2:]
@@ -331,7 +334,10 @@ def forward(
         clip_feature: Optional[torch.Tensor] = None,  # clip_vision_encoder(img)
         y: Optional[torch.Tensor] = None,  # vae_encoder(img)
     ):
-        with gguf_inference():
+        with (
+            gguf_inference(),
+            cfg_parallel((x, context, timestep, clip_feature, y)),
+        ):
             t = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, timestep))
             t_mod = self.time_projection(t).unflatten(1, (6, self.dim))
             context = self.text_embedding(context)
@@ -353,12 +359,13 @@ def forward(
                .to(x.device)
            )

-            with sequence_parallel([x, freqs], seq_dims=(1, 0), enabled=self.use_usp):
+            with sequence_parallel((x, freqs), seq_dims=(1, 0)):
                for block in self.blocks:
                    x = block(x, context, t_mod, freqs)
                x = self.head(x, t)
                (x,) = sequence_parallel_unshard((x,), seq_dims=(1,), seq_lens=(f * h * w,))
            x = self.unpatchify(x, (f, h, w))
+            (x,) = cfg_parallel_unshard((x,))
            return x

    @classmethod
@@ -369,7 +376,6 @@ def from_state_dict(
         dtype,
         model_type="1.3b-t2v",
         attn_impl: Optional[str] = None,
-        use_usp=False,
         assign=True,
     ):
         if model_type == "1.3b-t2v":
@@ -383,9 +389,7 @@ def from_state_dict(
         else:
             raise ValueError(f"Unsupported model type: {model_type}")
         with no_init_weights():
-            model = torch.nn.utils.skip_init(
-                cls, **config, device=device, dtype=dtype, attn_impl=attn_impl, use_usp=use_usp
-            )
+            model = torch.nn.utils.skip_init(cls, **config, device=device, dtype=dtype, attn_impl=attn_impl)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=assign)
         model.to(device=device, dtype=dtype)
@@ -467,3 +471,6 @@ def get_tp_plan(self):
             }
         )
         return tp_plan
+
+    def get_fsdp_modules(self):
+        return ["blocks"]

diffsynth_engine/pipelines/base.py

Lines changed: 7 additions & 10 deletions
@@ -31,7 +31,7 @@ def __init__(
         vae_tiled: bool = False,
         vae_tile_size: int = -1,
         vae_tile_stride: int = -1,
-        device="cuda:0",
+        device="cuda",
         dtype=torch.float16,
     ):
         super().__init__()
@@ -47,15 +47,15 @@
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | ModelConfig,
-        device: str = "cuda:0",
+        device: str = "cuda",
         dtype: torch.dtype = torch.float16,
         offload_mode: str | None = None,
     ) -> "BasePipeline":
         raise NotImplementedError()

     @classmethod
     def from_state_dict(
-        cls, state_dict: Dict[str, torch.Tensor], device: str = "cuda:0", dtype: torch.dtype = torch.float16
+        cls, state_dict: Dict[str, torch.Tensor], device: str = "cuda", dtype: torch.dtype = torch.float16
     ) -> "BasePipeline":
         raise NotImplementedError()

@@ -269,21 +269,18 @@ def enable_cpu_offload(self, offload_mode: str):
            logger.warning("must set an non cpu device for pipeline before calling enable_cpu_offload")
            return
        if offload_mode == "cpu_offload":
-            self.enable_model_cpu_offload()
+            self._enable_model_cpu_offload()
        elif offload_mode == "sequential_cpu_offload":
-            self.enable_sequential_cpu_offload()
+            self._enable_sequential_cpu_offload()

-    def enable_model_cpu_offload(self):
+    def _enable_model_cpu_offload(self):
        for model_name in self.model_names:
            model = getattr(self, model_name)
            if model is not None:
                model.to("cpu")
        self.offload_mode = "cpu_offload"

-    def enable_sequential_cpu_offload(self):
-        if self.device == "cpu" or self.device == "mps":
-            logger.warning("must set an non cpu device for pipeline before calling enable_sequential_cpu_offload")
-            return
+    def _enable_sequential_cpu_offload(self):
        for model_name in self.model_names:
            model = getattr(self, model_name)
            if model is not None:
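
On the pipeline side, the public surface is now enable_cpu_offload(offload_mode), with the per-mode methods demoted to private helpers, and the default device string drops the explicit index ("cuda" instead of "cuda:0"), presumably so a per-rank device can be chosen by the wrapper. A usage sketch against the base signatures above; MyPipeline and the checkpoint path are placeholders for a concrete BasePipeline subclass and a real model:

# Usage sketch; MyPipeline and the path below are placeholders, not real names.
import torch

pipe = MyPipeline.from_pretrained(
    "path/to/model.safetensors",   # placeholder checkpoint
    device="cuda",                 # new default; was "cuda:0"
    dtype=torch.float16,
    offload_mode="cpu_offload",    # or "sequential_cpu_offload", or None
)

# Equivalent after construction; dispatches to the private helpers:
pipe.enable_cpu_offload("sequential_cpu_offload")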
