     ContextParallelInput,
     ContextParallelModelPlan,
     ContextParallelOutput,
-    ParallelConfig,
+    _InternalParallelConfig,
 )
-from ..models.attention_dispatch import _parallel_context
 from ..utils import get_logger
 from ..utils.torch_utils import unwrap_module
 from .hooks import HookRegistry, ModelHook
 
 
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
-_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
 _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
 _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"
 
@@ -76,7 +74,7 @@ def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None)
 
 def apply_context_parallel(
     module: torch.nn.Module,
-    parallel_config: ParallelConfig,
+    parallel_config: _InternalParallelConfig,
     plan: Dict[str, ContextParallelModelPlan],
 ) -> None:
     """Apply context parallel on a model."""
@@ -105,45 +103,26 @@ def apply_context_parallel(
         registry = HookRegistry.check_if_exists_or_initialize(m)
         registry.register_hook(hook, hook_name)
 
-    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
-    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
-    # available in pytorch to set the parallel context before/after the forward/backward pass.
-    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
-    # The previous/older implementation simply did this:
-    # def new_forward(self, ...):
-    #     with _parallel_context(parallel_config):
-    #         return self.fn_ref.original_forward(*args, **kwargs)
-    # TODO: ask help from Pytorch team on how to improve this
-    @torch.compiler.disable
-    def forward_pre_hook(module, args):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
-
-    @torch.compiler.disable
-    def forward_hook(module, args, output):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
-
-    @torch.compiler.disable
-    def backward_pre_hook(module, grad_output):
-        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
-        module._diffusers_parallel_config_setter_context.__enter__()
-
-    @torch.compiler.disable
-    def backward_hook(module, grad_output, grad_input):
-        if module._diffusers_parallel_config_setter_context is not None:
-            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
-        module._diffusers_parallel_config_setter_context = None
-
-    module.register_forward_pre_hook(forward_pre_hook)
-    module.register_forward_hook(forward_hook)
-    module.register_full_backward_pre_hook(backward_pre_hook)
-    module.register_full_backward_hook(backward_hook)
+
+def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
+    for module_id, cp_model_plan in plan.items():
+        submodule = _get_submodule_by_name(module, module_id)
+        if not isinstance(submodule, list):
+            submodule = [submodule]
+
+        for m in submodule:
+            registry = HookRegistry.check_if_exists_or_initialize(m)
+            if isinstance(cp_model_plan, dict):
+                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
+            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
+                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
+            else:
+                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)
 
 
 class ContextParallelSplitHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
@@ -228,7 +207,7 @@ def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) ->
 
 
 class ContextParallelGatherHook(ModelHook):
-    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
+    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
         super().__init__()
         self.metadata = metadata
         self.parallel_config = parallel_config
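
For orientation, a minimal usage sketch of how the newly added `remove_context_parallel` pairs with the existing `apply_context_parallel`. Only the two functions and their signatures come from this diff; the import path and the `transformer`, `parallel_config`, `cp_plan`, and `inputs` objects are placeholders for illustration:

# Hypothetical sketch, not part of the PR: apply the context-parallel hooks,
# run one sharded forward pass, then unregister the hooks again.
# `transformer`, `parallel_config` (an _InternalParallelConfig), `cp_plan`
# (a Dict[str, ContextParallelModelPlan]) and `inputs` are assumed to exist;
# the import path below is a guess at where this module lives.
from diffusers.hooks.context_parallel import apply_context_parallel, remove_context_parallel

apply_context_parallel(transformer, parallel_config=parallel_config, plan=cp_plan)
try:
    output = transformer(**inputs)  # input hooks split, output hooks gather, per the plan
finally:
    # Unregisters the hooks added under cp_input---{module_id} / cp_output---{module_id}
    remove_context_parallel(transformer, plan=cp_plan)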