Skip to content

Commit e7e92b4

Browse files
authored
fix cfg parallel (#104)
1 parent eadfdda commit e7e92b4

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def forward(
         )

         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
+        use_cfg = hidden_states.shape[0] > 1
         with (
             fp8_inference(fp8_linear_enabled),
             gguf_inference(),
@@ -426,7 +427,8 @@ def forward(
                     image_ids,
                     *controlnet_double_block_output,
                     *controlnet_single_block_output,
-                )
+                ),
+                use_cfg=use_cfg,
             ),
         ):
             # warning: keep the order of time_embedding + guidance_embedding + pooled_text_embedding
@@ -483,7 +485,7 @@ def forward(
             (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,))

         hidden_states = self.unpatchify(hidden_states, h, w)
-        (hidden_states,) = cfg_parallel_unshard((hidden_states,))
+        (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg)
         return hidden_states

     @classmethod

diffsynth_engine/utils/parallel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,9 +330,9 @@ def __init__(
         device: str = "cuda",
     ):
         current_method = mp.get_start_method(allow_none=True)
-        if current_method is None or current_method != 'spawn':
+        if current_method is None or current_method != "spawn":
             try:
-                mp.set_start_method('spawn')
+                mp.set_start_method("spawn")
             except RuntimeError as e:
                 raise RuntimeError("Failed to set start method to spawn:", e)
         super().__init__()
@@ -404,8 +404,8 @@ def __del__(self):


 @contextmanager
-def cfg_parallel(tensors: List[torch.Tensor]):
-    if get_cfg_world_size() == 1:
+def cfg_parallel(tensors: List[torch.Tensor], use_cfg=True):
+    if get_cfg_world_size() == 1 or not use_cfg:
         yield
         return

@@ -426,8 +426,8 @@ def cfg_parallel(tensors: List[torch.Tensor]):
             tensor.copy_(original_tensor)


-def cfg_parallel_unshard(tensors: List[torch.Tensor]):
-    if get_cfg_world_size() == 1:
+def cfg_parallel_unshard(tensors: List[torch.Tensor], use_cfg=True):
+    if get_cfg_world_size() == 1 or not use_cfg:
         return tensors

     unshard_tensors = []

0 commit comments

Comments (0)