
Commit 7973626

refactor
1 parent 51fed50 commit 7973626

3 files changed: +26 additions, -8 deletions


src/diffusers/hooks/context_parallel.py

Lines changed: 4 additions & 4 deletions
@@ -34,8 +34,8 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name

 _CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
-_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}"
-_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
+_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
+_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"


 # TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@@ -92,14 +92,14 @@ def apply_context_parallel(
         for m in submodule:
             if isinstance(cp_model_plan, dict):
                 hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
-                hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id)
+                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
             elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                 if isinstance(cp_model_plan, ContextParallelOutput):
                     cp_model_plan = [cp_model_plan]
                 if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                     raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                 hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
-                hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id)
+                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
             else:
                 raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
             registry = HookRegistry.check_if_exists_or_initialize(m)
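The rename drops the redundant "SUBMODULE" infix; the template strings themselves are unchanged, so the generated hook names keep the same `cp_input---…` / `cp_output---…` shape. A minimal sketch of how the templates turn module ids from the plan into hook names (the module ids below are hypothetical, for illustration only):

```python
# Sketch of the naming scheme produced by the renamed templates.
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"

# Hypothetical module ids; a real cp_plan maps ids like these to split/gather metadata.
for module_id in ("blocks.0", "blocks.1"):
    print(_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id))   # cp_input---blocks.0, cp_input---blocks.1
    print(_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id))  # cp_output---blocks.0, cp_output---blocks.1
```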

src/diffusers/models/attention_dispatch.py

Lines changed: 18 additions & 4 deletions
@@ -751,6 +751,11 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        scale: Optional[float],
+        is_causal: bool,
+        enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
     ):
@@ -773,7 +778,7 @@ def forward(
             value = kv[key.numel() :].reshape_as(value)
             next_rank = (next_rank + 1) % world_size

-            out, lse = op.apply(query, key, value, None, 0.0, None, False, False, True)
+            out, lse = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, True)

             if parallel_config.convert_to_fp32:
                 out = out.to(torch.float32)
@@ -806,6 +811,11 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
+        dropout_p: float,
+        scale: Optional[float],
+        is_causal: bool,
+        enable_gqa: bool,
         return_lse: bool,
         op: torch.autograd.Function,
     ):
@@ -823,7 +833,7 @@ def forward(
         query, key, value = (funcol.all_to_all_single(x, None, None, group=group).wait() for x in (query, key, value))
         query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

-        out = op.apply(query, key, value, None, 0.0, None, False, False, return_lse)
+        out = op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
         if return_lse:
             out, lse, *_ = out

@@ -872,9 +882,13 @@ def _templated_context_parallel_attention(
     parallel_config = _AttentionBackendRegistry._parallel_config
     # TODO: add support for unified attention with ring/ulysses degree both being > 1
     if parallel_config.ring_degree > 1:
-        return TemplatedRingAttention.apply(query, key, value, return_lse, op)
+        return TemplatedRingAttention.apply(
+            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+        )
     elif parallel_config.ulysses_degree > 1:
-        return TemplatedUlyssesAttention.apply(query, key, value, return_lse, op)
+        return TemplatedUlyssesAttention.apply(
+            query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse, op
+        )
     else:
         return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
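The substantive change in this file is that `TemplatedRingAttention` and `TemplatedUlyssesAttention` now accept the full SDPA-style argument set and forward it to the backend `op`, where the old code hard-coded `None, 0.0, None, False, False`. A simplified, self-contained sketch of that pass-through pattern (the class name is made up; the real classes also perform the ring / all-to-all communication shown in the hunks above):

```python
from typing import Optional

import torch


class _SketchTemplatedAttention(torch.autograd.Function):
    """Illustration of the forwarded signature only; not the diffusers implementation."""

    @staticmethod
    def forward(
        ctx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        dropout_p: float,
        scale: Optional[float],
        is_causal: bool,
        enable_gqa: bool,
        return_lse: bool,
        op,
    ):
        # Before this commit the equivalent call pinned the middle arguments:
        #   op.apply(query, key, value, None, 0.0, None, False, False, return_lse)
        # Now the caller-supplied values flow through unchanged.
        return op.apply(query, key, value, attn_mask, dropout_p, scale, is_causal, enable_gqa, return_lse)
```

With this, per-call options such as a causal mask or a custom softmax scale are no longer silently replaced by defaults on the context-parallel path.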

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 0 deletions
@@ -1535,6 +1535,10 @@ def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=
             device=device,
             cp_mesh=cp_mesh,
         )
+        if cp_plan is None and self._cp_plan is None:
+            raise ValueError(
+                "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
+            )
         cp_plan = cp_plan if cp_plan is not None else self._cp_plan

         apply_context_parallel(self, parallel_config, cp_plan)
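The added guard makes `parallelize()` fail fast when neither the `cp_plan` argument nor the model's class-level `_cp_plan` attribute is set, rather than handing `None` to `apply_context_parallel`. A small standalone sketch of the resolution order (the helper name and plan contents are hypothetical, used only to illustrate the precedence):

```python
from typing import Optional


def _resolve_cp_plan(cp_plan: Optional[dict], model_cp_plan: Optional[dict]) -> dict:
    # Mirrors the new check in parallelize(): an explicit argument wins,
    # otherwise fall back to the class-level plan, otherwise raise early.
    if cp_plan is None and model_cp_plan is None:
        raise ValueError(
            "`cp_plan` must be provided either as an argument or set in the model's `_cp_plan` attribute."
        )
    return cp_plan if cp_plan is not None else model_cp_plan


print(_resolve_cp_plan({"blocks.0": "split"}, None))   # explicit plan takes precedence
print(_resolve_cp_plan(None, {"blocks.0": "split"}))   # falls back to the class default
```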
