
Commit 51fed50

Commit message: update
1 parent 84d2c84 commit 51fed50

File tree: 6 files changed (+972, -149 lines)


src/diffusers/hooks/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@
 
 
 if is_torch_available():
+    from .context_parallel import apply_context_parallel
     from .faster_cache import FasterCacheConfig, apply_faster_cache
     from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
     from .group_offloading import apply_group_offloading
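With this re-export, the new entry point becomes importable from the hooks subpackage whenever torch is available; a minimal sketch (not part of the diff):

# Hypothetical usage of the newly exported symbol (guarded by is_torch_available() in the package).
from diffusers.hooks import apply_context_parallel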
src/diffusers/hooks/context_parallel.py

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Dict, List, Type, Union

import torch
import torch.distributed._functional_collectives as funcol

from ..models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
    ParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook


logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"

# TODO(aryan): consolidate with ._helpers.TransformerBlockMetadata
@dataclass
class ModuleForwardMetadata:
    cached_parameter_indices: Dict[str, int] = None
    _cls: Type = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}

        if identifier in kwargs:
            return kwargs[identifier], True, None

        if self.cached_parameter_indices is not None:
            index = self.cached_parameter_indices.get(identifier, None)
            if index is None:
                raise ValueError(f"Parameter '{identifier}' not found in cached indices.")
            return args[index], False, index

        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")

        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self.cached_parameter_indices = {param: i for i, param in enumerate(parameters)}

        if identifier not in self.cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")

        index = self.cached_parameter_indices[identifier]

        if index >= len(args):
            raise ValueError(f"Expected {index} arguments but got {len(args)}.")

        return args[index], False, index

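A toy illustration of the lookup above (not part of the diff); ToyBlock is hypothetical and ModuleForwardMetadata is the dataclass defined here:

import torch


class ToyBlock(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states=None):
        return hidden_states


# The first lookup inspects ToyBlock.forward, caches {"hidden_states": 0, "encoder_hidden_states": 1},
# and pulls the value out of the positional args; later lookups reuse the cache.
meta = ModuleForwardMetadata(_cls=ToyBlock)
x = torch.randn(1, 8, 4)
value, is_kwarg, index = meta._get_parameter_from_args_kwargs("hidden_states", args=(x,), kwargs={})
assert value is x and is_kwarg is False and index == 0

# Keyword arguments take precedence and skip the signature-index lookup entirely.
value, is_kwarg, index = meta._get_parameter_from_args_kwargs("encoder_hidden_states", kwargs={"encoder_hidden_states": x})
assert is_kwarg is True and index is None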
def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""
    logger.debug(f"Applying context parallel with CP mesh: {parallel_config.cp_mesh} and plan: {plan}")

    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        logger.debug(f"Applying ContextParallelHook to {module_id=} identifying a total of {len(submodule)} modules")

        for m in submodule:
            if isinstance(cp_model_plan, dict):
                hook = ContextParallelSplitHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_SUBMODULE_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                if isinstance(cp_model_plan, ContextParallelOutput):
                    cp_model_plan = [cp_model_plan]
                if not all(isinstance(x, ContextParallelOutput) for x in cp_model_plan):
                    raise ValueError(f"Expected all elements of cp_model_plan to be CPOutput, but got {cp_model_plan}")
                hook = ContextParallelGatherHook(cp_model_plan, parallel_config)
                hook_name = _CONTEXT_PARALLEL_SUBMODULE_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)

    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = ContextParallelModelHook(parallel_config)
    registry.register_hook(hook, _CONTEXT_PARALLEL_MODEL_HOOK)

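For orientation, a minimal end-to-end sketch of how this entry point could be driven (not part of the diff). It assumes a 2-process torchrun launch on CPU with the gloo backend; ToyTransformer, the plan keys, and the chosen dimensions are hypothetical and only illustrate the plan-to-hook wiring above.

# Launch with: torchrun --nproc-per-node=2 cp_sketch.py  (hypothetical script name)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

from diffusers.hooks import apply_context_parallel
from diffusers.models._modeling_parallel import ContextParallelInput, ContextParallelOutput, ParallelConfig


class ToyTransformer(torch.nn.Module):
    # Hypothetical stand-in for a transformer; two Linear "blocks" keep the sketch cheap.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(2)])

    def forward(self, hidden_states):
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states


dist.init_process_group("gloo")
rank, world_size = dist.get_rank(), dist.get_world_size()

# A 2-D mesh with the dimension names ParallelConfig.__post_init__ expects ("ring", "ulysses");
# all parallelism is placed on the ring dimension here.
cp_mesh = init_device_mesh("cpu", (world_size, 1), mesh_dim_names=("ring", "ulysses"))
parallel_config = ParallelConfig(
    rank=rank,
    world_size=world_size,
    ring_degree=world_size,
    ulysses_degree=1,
    device=torch.device("cpu"),
    cp_mesh=cp_mesh,
)

# Split the sequence dimension of `hidden_states` on the way into the model (dict value ->
# ContextParallelSplitHook) and gather it back after the last block (ContextParallelOutput ->
# ContextParallelGatherHook).
plan = {
    "": {"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3)},
    "blocks.1": ContextParallelOutput(gather_dim=1, expected_dims=3),
}

model = ToyTransformer()
apply_context_parallel(model, parallel_config, plan)

# Each rank feeds the full-length sequence; the hooks shard dim 1 going in and all-gather it coming out.
x = torch.randn(2, 4 * world_size, 64)
out = model(x)
assert out.shape == x.shape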
class ContextParallelModelHook(ModelHook):
    def __init__(self, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.parallel_config = parallel_config

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        with _parallel_context(self.parallel_config):
            return self.fn_ref.original_forward(*args, **kwargs)


class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config
        self.module_forward_metadata = None

    def initialize_hook(self, module):
        cls = unwrap_module(module).__class__
        self.module_forward_metadata = ModuleForwardMetadata(_cls=cls)
        return module

    def pre_forward(self, module, *args, **kwargs):
        args_list = list(args)

        for name, cpm in self.metadata.items():
            if isinstance(cpm, ContextParallelInput) and cpm.split_output:
                continue

            # Maybe the parameter was passed as a keyword argument
            input_val, is_kwarg, index = self.module_forward_metadata._get_parameter_from_args_kwargs(
                name, args_list, kwargs
            )

            if input_val is None:
                continue

            # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
            # the output instead of input for a particular layer by setting split_output=True
            if isinstance(input_val, torch.Tensor):
                input_val = self._prepare_cp_input(input_val, cpm)
            elif isinstance(input_val, (list, tuple)):
                if len(input_val) != len(cpm):
                    raise ValueError(
                        f"Expected input model plan to have {len(input_val)} elements, but got {len(cpm)}."
                    )
                sharded_input_val = []
                for i, x in enumerate(input_val):
                    if torch.is_tensor(x) and not cpm[i].split_output:
                        x = self._prepare_cp_input(x, cpm[i])
                    sharded_input_val.append(x)
                input_val = sharded_input_val
            else:
                raise ValueError(f"Unsupported input type: {type(input_val)}")

            if is_kwarg:
                kwargs[name] = input_val
            elif index is not None and index < len(args_list):
                args_list[index] = input_val
            else:
                raise ValueError(
                    f"An unexpected error occurred while processing the input '{name}'. Please open an "
                    f"issue at https://github.com/huggingface/diffusers/issues and provide a minimal reproducible "
                    f"example along with the full stack trace."
                )

        return tuple(args_list), kwargs

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)
        is_tensor_list = isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)

        if not is_tensor and not is_tensor_list:
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = [output] if is_tensor else list(output)
        for index, cpm in self.metadata.items():
            if not isinstance(cpm, ContextParallelInput) or not cpm.split_output:
                continue
            if index >= len(output):
                raise ValueError(f"Index {index} out of bounds for output of length {len(output)}.")
            current_output = output[index]
            current_output = self._prepare_cp_input(current_output, cpm)
            output[index] = current_output

        return output[0] if is_tensor else tuple(output)

    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
        if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
            raise ValueError(
                f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
            )
        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)


class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

    def post_forward(self, module, output):
        is_tensor = isinstance(output, torch.Tensor)

        if is_tensor:
            output = [output]
        elif not (isinstance(output, (list, tuple)) and all(isinstance(x, torch.Tensor) for x in output)):
            raise ValueError(f"Expected output to be a tensor or a list/tuple of tensors, but got {type(output)}.")

        output = list(output)

        if len(output) != len(self.metadata):
            raise ValueError(f"Expected output to have {len(self.metadata)} elements, but got {len(output)}.")

        for i, cpm in enumerate(self.metadata):
            if cpm is None:
                continue
            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

        return output[0] if is_tensor else tuple(output)

class EquipartitionSharder:
    @classmethod
    @torch.compiler.disable
    def shard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        assert tensor.size()[dim] % mesh.size() == 0
        return tensor.chunk(mesh.size(), dim=dim)[mesh.get_rank()]

    @classmethod
    @torch.compiler.disable
    def unshard(cls, tensor: torch.Tensor, dim: int, mesh: torch.distributed.device_mesh.DeviceMesh) -> torch.Tensor:
        tensor = tensor.contiguous()
        tensor = funcol.all_gather_tensor(tensor, dim, group=mesh.get_group())
        return tensor

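A round-trip sketch for the sharder above (not part of the diff), assuming EquipartitionSharder is in scope and the script is launched with torchrun on 2 CPU processes using the gloo backend:

# torchrun --nproc-per-node=2 sharder_sketch.py  (hypothetical script name)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# Every rank starts from the same full tensor; shard() keeps only this rank's chunk along dim 0,
# and unshard() all-gathers the chunks back into the full tensor.
full = torch.arange(8, dtype=torch.float32).reshape(4, 2)
local = EquipartitionSharder.shard(full, dim=0, mesh=mesh)        # shape (2, 2) on each rank
restored = EquipartitionSharder.unshard(local, dim=0, mesh=mesh)  # shape (4, 2) again
assert torch.equal(restored, full)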
def _get_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name.count("*") > 1:
        raise ValueError("Wildcard '*' can only be used once in the name")
    return _find_submodule_by_name(model, name)


def _find_submodule_by_name(model: torch.nn.Module, name: str) -> Union[torch.nn.Module, List[torch.nn.Module]]:
    if name == "":
        return model
    first_atom, remaining_name = name.split(".", 1) if "." in name else (name, "")
    if first_atom == "*":
        if not isinstance(model, torch.nn.ModuleList):
            raise ValueError("Wildcard '*' can only be used with ModuleList")
        submodules = []
        for submodule in model:
            subsubmodules = _find_submodule_by_name(submodule, remaining_name)
            if not isinstance(subsubmodules, list):
                subsubmodules = [subsubmodules]
            submodules.extend(subsubmodules)
        return submodules
    else:
        if hasattr(model, first_atom):
            submodule = getattr(model, first_atom)
            return _find_submodule_by_name(submodule, remaining_name)
        else:
            raise ValueError(f"'{first_atom}' is not a submodule of '{model.__class__.__name__}'")
src/diffusers/models/_modeling_parallel.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# Experimental parallelism support for Diffusers.
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Tuple, Union

import torch

from ..utils import get_logger


logger = get_logger(__name__)  # pylint: disable=invalid-name


# TODO(aryan): add support for the following:
# - Unified Attention
# - More dispatcher attention backends
# - CFG/Data Parallel
# - Tensor Parallel

@dataclass
class ParallelConfig:
    rank: int
    world_size: int
    ring_degree: int
    ulysses_degree: int
    device: torch.device
    cp_mesh: torch.distributed.device_mesh.DeviceMesh

    # Whether to convert output and LSE to float32 for ring attention numerical stability
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"

    _flattened_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ring_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _ring_local_rank: int = None
    _ulysses_local_rank: int = None

    def __post_init__(self):
        if self.rotate_method != "allgather":
            raise ValueError(f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}.")
        if self._flattened_mesh is None:
            self._flattened_mesh = self.cp_mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self.cp_mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self.cp_mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = self._ring_mesh.get_local_rank()
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()


@dataclass(frozen=True)
class ContextParallelInput:
    split_dim: int
    expected_dims: Optional[int] = None
    split_output: bool = False

    def __repr__(self):
        return f"ContextParallelInput(split_dim={self.split_dim}, expected_dims={self.expected_dims}, split_output={self.split_output})"


@dataclass(frozen=True)
class ContextParallelOutput:
    gather_dim: int
    expected_dims: Optional[int] = None

    def __repr__(self):
        return f"ContextParallelOutput(gather_dim={self.gather_dim}, expected_dims={self.expected_dims})"

# A dictionary where keys denote the input to be split across context parallel region, and the
# value denotes the sharding configuration.
# If the key is a string, it denotes the name of the parameter in the forward function.
# If the key is an integer, split_output must be set to True, and it denotes the index of the output
# to be split across context parallel region.
ContextParallelInputType = Dict[
    Union[str, int], Union[ContextParallelInput, List[ContextParallelInput], Tuple[ContextParallelInput, ...]]
]

# The output(s) to be gathered across context parallel region, along with the gathering
# configuration for each.
ContextParallelOutputType = Union[
    ContextParallelOutput, List[ContextParallelOutput], Tuple[ContextParallelOutput, ...]
]

# A dictionary where keys denote the module id, and the value denotes how the inputs/outputs of
# the module should be split/gathered across context parallel region.
ContextParallelModelPlan = Dict[str, Union[ContextParallelInputType, ContextParallelOutputType]]
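Putting the aliases together, a hypothetical ContextParallelModelPlan for a transformer-like model might look as follows (not part of the diff); the module ids ("rope", "proj_out") and dimensions are illustrative, not taken from any diffusers model:

# Keys are module ids (a single "*" wildcard is allowed); values describe how that module's
# inputs are split or its outputs are gathered across the context parallel region.
_example_cp_plan: ContextParallelModelPlan = {
    # String keys inside the nested dict name forward() parameters of the root module.
    "": {
        "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
        "encoder_hidden_states": ContextParallelInput(split_dim=1, expected_dims=3),
    },
    # An integer key with split_output=True shards the indexed output of the module instead of an input.
    "rope": {0: ContextParallelInput(split_dim=1, expected_dims=3, split_output=True)},
    # A ContextParallelOutput value gathers the module's output back to the full sequence length.
    "proj_out": ContextParallelOutput(gather_dim=1, expected_dims=3),
}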
