
Commit fa5d017

refactor
1 parent 1ffc03e commit fa5d017

File tree: 5 files changed, +166 −101 lines

src/diffusers/hooks/context_parallel.py

Lines changed: 20 additions & 41 deletions
@@ -23,17 +23,15 @@
     ContextParallelInput,
     ContextParallelModelPlan,
     ContextParallelOutput,
-    ParallelConfig,
+    _InternalParallelConfig,
 )
-from ..models.attention_dispatch import _parallel_context
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from .hooks import HookRegistry, ModelHook
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
-_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
 _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
 _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
 
@@ -76,7 +74,7 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None)
 
 def apply_context_parallel(
     module: torch.nn.Module,
-    parallel_config: ParallelConfig,
+    parallel_config: _InternalParallelConfig,
     plan: Dict[str, ContextParallelModelPlan],
 ) -> None:
     """Apply context parallel on a model."""
@@ -105,45 +103,26 @@ def apply_context_parallel(
             registry = HookRegistry.check_if_exists_or_initialize(m)
             registry.register_hook(hook, hook_name)
 
-    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
-    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
-    # available in pytorch to set the parallel context before/after the forward/backward pass.
-    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
-    # The previous/older implementation simply did this:
-    #     def new_forward(self, ...):
-    #         with _parallel_context(parallel_config):
-    #             return self.fn_ref.original_forward(*args, **kwargs)
-    # TODO: ask help from Pytorch team on how to improve this
-    @torch.compiler.disable
-    def forward_pre_hook(module, args):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
-
-    @torch.compiler.disable
-    def forward_hook(module, args, output):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
-
-    @torch.compiler.disable
-    def backward_pre_hook(module, grad_output):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
-
-    @torch.compiler.disable
-    def backward_hook(module, grad_output, grad_input):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
-
-    module.register_forward_pre_hook(forward_pre_hook)
-    module.register_forward_hook(forward_hook)
-    module.register_full_backward_pre_hook(backward_pre_hook)
-    module.register_full_backward_hook(backward_hook)
+
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+    for module_id, cp_model_plan in plan.items():
+        submodule = _get_submodule_by_name(module, module_id)
+        if not isinstance(submodule, list):
+            submodule = [submodule]
+
+        for m in submodule:
+            registry = HookRegistry.check_if_exists_or_initialize(m)
+            if isinstance(cp_model_plan, dict):
+                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+            else:
+                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)
 
 
 class ContextParallelSplitHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
@@ -228,7 +207,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
 
 
 class ContextParallelGatherHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
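
The deleted block was a workaround: Dynamo could not trace a context manager inside an overridden forward, so the previous code entered and exited `_parallel_context` through `@torch.compiler.disable`-wrapped forward/backward module hooks. This commit drops that mechanism and adds `remove_context_parallel`, which re-derives the `cp_input---{}` / `cp_output---{}` hook names from the plan and unregisters them. Below is a minimal sketch of how the two entry points could be paired, assuming the import path mirrors the file path and that an `_InternalParallelConfig` instance is built elsewhere (its construction is not part of this diff); `module`, `internal_config`, and `cp_plan` are placeholder names.

```python
# Hypothetical pairing of apply/remove; names here are placeholders, not taken from this commit.
from contextlib import contextmanager

from diffusers.hooks.context_parallel import apply_context_parallel, remove_context_parallel


@contextmanager
def context_parallel_hooks(module, internal_config, cp_plan):
    """Install the split/gather hooks for the duration of the block, then tear them down."""
    apply_context_parallel(module, internal_config, cp_plan)
    try:
        yield module
    finally:
        # remove_context_parallel only needs the plan: the hook names are derived
        # deterministically from the module ids via the cp_input/cp_output templates.
        remove_context_parallel(module, cp_plan)
```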

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -25,6 +25,7 @@
 _import_structure = {}
 
 if is_torch_available():
+    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
     _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
     _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
     _import_structure["auto_model"] = ["AutoModel"]
@@ -112,6 +113,7 @@
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
+        from ._modeling_parallel import ParallelConfig
         from .adapter import MultiAdapter, T2IAdapter
         from .attention_dispatch import AttentionBackendName, attention_backend
         from .auto_model import AutoModel
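
Registering `_modeling_parallel` in `_import_structure` and mirroring it under the `TYPE_CHECKING` branch exposes the new public config through the `diffusers.models` namespace. A hedged example of what this enables, assuming the rest of the lazy-import machinery is unchanged:

```python
# ParallelConfig becomes resolvable from the models subpackage; the defining module
# is only imported on first attribute access by the lazy-import machinery.
from diffusers.models import ParallelConfig

print(ParallelConfig)  # expected: <class 'diffusers.models._modeling_parallel.ParallelConfig'>
```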

src/diffusers/models/_modeling_parallel.py

Lines changed: 12 additions & 0 deletions
@@ -35,6 +35,18 @@
 
 @dataclass
 class ParallelConfig:
+    ring_degree: Optional[int] = None
+    ulysses_degree: Optional[int] = None
+
+    def __post_init__(self):
+        if self.ring_degree is None:
+            self.ring_degree = 1
+        if self.ulysses_degree is None:
+            self.ulysses_degree = 1
+
+
+@dataclass
+class _InternalParallelConfig:
     rank: int
     world_size: int
     ring_degree: int
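
The dataclass split leaves the user-facing `ParallelConfig` with just the parallelism degrees, normalized to 1 by `__post_init__` when unset, while the previous field layout (`rank`, `world_size`, `ring_degree`, ...) moves to the private `_InternalParallelConfig`. A small usage sketch under that reading; the world-size arithmetic at the end is illustrative and not taken from this diff:

```python
from diffusers.models import ParallelConfig

cfg = ParallelConfig(ring_degree=2)  # ulysses_degree left unset
assert cfg.ring_degree == 2
assert cfg.ulysses_degree == 1       # defaulted by __post_init__

# Hypothetical consumer-side check: a context-parallel group would typically
# span ring_degree * ulysses_degree ranks.
required_world_size = cfg.ring_degree * cfg.ulysses_degree
```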
