Skip to content

Commit 3b5772e

Browse files
feat: add distributor algorithm (#459)
* ring_attn added
* fixed typo
* adding hook and tweaking tests
* add support for diffusers 0.35 and evaluation agent for ring attention #224
* Co-authored-by: Johanna Sommer <johanna@mail-sommer.com>
* distributed setup logic moved from smash.py to server_utils.py
---------
Co-authored-by: Johanna Sommer <johanna@mail-sommer.com>
1 parent e053186 commit 3b5772e

File tree

9 files changed

+900
-6
lines changed

9 files changed

+900
-6
lines changed

src/pruna/algorithms/qkv_diffusers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class QKVFusing(PrunaAlgorithmBase):
5050
"deepcache",
5151
"fora",
5252
"torch_compile",
53+
"ring_attn",
5354
]
5455

5556
def model_check_fn(self, model: Any) -> bool:
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from pruna.algorithms.ring_attn.ring import RingAttn
16+
from pruna.algorithms.ring_attn.utils.server_utils import DistributedServer
17+
18+
__all__ = ["RingAttn", "DistributedServer"]
Lines changed: 365 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,365 @@
1+
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import contextlib
18+
import functools
19+
from collections.abc import Iterable
20+
from types import ModuleType
21+
from typing import Any, List, Optional, Union
22+
23+
import torch
24+
import torch.distributed as dist
25+
from ConfigSpace import CategoricalHyperparameter
26+
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
27+
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
28+
from torch.distributed.tensor.device_mesh import DeviceMesh
29+
30+
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
31+
from pruna.algorithms.base.tags import AlgorithmTag
32+
from pruna.algorithms.ring_attn.utils.ring_utils import RingDistributedContext
33+
from pruna.config.hyperparameters import Boolean
34+
from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
35+
from pruna.engine.save import SAVE_FUNCTIONS
36+
37+
ring_attention: ModuleType | None = None
38+
39+
with contextlib.suppress(ImportError):
40+
# see "import_algorithm_packages" for further explanation
41+
import torch.distributed.tensor.experimental._attention as ring_attention
42+
43+
44+
class RingAttn(PrunaAlgorithmBase):
    """
    Distributed attention on multiple GPUs computation by using the torch native ring attention implementation.

    Each GPU stores only its own slice of Q/K/V and participates in a Ring Attention shuffle that lets every query
    attend to every key/value. The result is lower KV-cache/activation memory per GPU and higher arithmetic intensity.
    """

    algorithm_name: str = "ring_attn"
    group_tags: list[AlgorithmTag] = [AlgorithmTag.KERNEL]
    # the wrappers are monkey-patches on live objects, so the model is saved by re-applying the algorithm
    save_fn = SAVE_FUNCTIONS.reapply
    references = {
        "Implementation": "https://docs.pytorch.org/tutorials/prototype/context_parallel.html",
        "Paper": "https://arxiv.org/pdf/2310.01889",
    }
    tokenizer_required: bool = False
    processor_required: bool = False
    runs_on: list[str] = ["cuda"]
    dataset_required: bool = False
    compatible_before: Iterable[str | AlgorithmTag] = [
        "qkv_diffusers",
        "padding_pruning",
    ]
    compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"]

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the RingAttn.

        Returns
        -------
        list
            A list of hyperparameters.
        """
        return [
            Boolean(
                "convert_to_f32",
                default=True,
                meta=dict(desc="Allowing intermediate computations in the attention mechanism to be upcast to 32-bit."),
            ),
            CategoricalHyperparameter(
                "rotate_method",
                default_value="ALL_TO_ALL",
                meta=dict(desc="The method to use for rotating the computations."),
                choices=["ALL_TO_ALL", "ALL_GATHER"],
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is supported by the RingAttn.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is supported, False otherwise.

        Raises
        ------
        ValueError
            If fewer than two CUDA devices are visible; ring attention cannot run on a single GPU.
        """
        if torch.cuda.device_count() < 2:
            raise ValueError("RingAttn requires at least 2 GPUs")

        return hasattr(model, "transformer") and isinstance(
            model.transformer, (FluxTransformer2DModel, WanTransformer3DModel)
        )

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply ring attention to the pipeline's transformer and wrap the pipeline call.

        Parameters
        ----------
        model : Any
            The diffusers pipeline holding a Flux or Wan transformer.
        smash_config : SmashConfigPrefixWrapper
            The configuration carrying the ring attention hyperparameters.

        Returns
        -------
        Any
            The model with the distributed ring attention wrappers installed.

        Raises
        ------
        ImportError
            If the experimental torch ring attention module could not be imported at module load time.
        ValueError
            If the transformer is neither a Flux nor a Wan transformer.
        """
        if ring_attention is None:
            # the guarded top-of-file import failed; fail loudly here instead of hitting an opaque
            # AttributeError on None when touching _cp_options below
            raise ImportError(
                "torch.distributed.tensor.experimental._attention is unavailable; "
                "the installed torch build does not support ring attention."
            )

        # configure the ring attention hyperparameters
        _cp_options = ring_attention._cp_options  # type: ignore
        _cp_options.convert_to_f32 = smash_config["convert_to_f32"]
        _cp_options.enable_load_balance = False
        _cp_options.rotate_method = getattr(ring_attention._RotateMethod, smash_config["rotate_method"])  # type: ignore

        # make every rank draw the same generator seed so the denoising trajectories agree
        wrap_pipeline_call(model, torch.cuda.device_count())

        mesh = dist.init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("ring_dim",))
        rank = dist.get_rank()
        world_size = torch.cuda.device_count()

        if isinstance(model.transformer, FluxTransformer2DModel):
            wrap_flux2d_transformer_forward(
                model.transformer,
                world_size,
                smash_config._base_config,
                rank,
                mesh,
                cache_helper=getattr(model, "cache_helper", None),
            )
        elif isinstance(model.transformer, WanTransformer3DModel):
            wrap_wan3d_transformer_forward(model.transformer, world_size, smash_config._base_config, rank, mesh)
        else:
            raise ValueError(f"Unsupported transformer type: {type(model.transformer)}")

        return model

    def import_algorithm_packages(self) -> dict[str, Any]:
        """
        Import the algorithm packages.

        Returns
        -------
        dict[str, Any]
            The algorithm packages.
        """
        # even though it is a torch import we isolate it, as experimental modules can often change the interface
        # we import the package even though we dont use it directly to make sure it is available
        # additionally, we can not pass it as module to the distributed setup (not picklable)
        # nor as a string (the import massively irritates torch.compile)
        # we import it on the top of the file if available
        import torch.distributed.tensor.experimental._attention as ring_attention  # noqa: F401

        return dict()
160+
161+
162+
def wrap_wan3d_transformer_forward(
    model: Any,
    world_size: int,
    smash_config: Union[SmashConfig, SmashConfigPrefixWrapper],
    rank: int,
    mesh: DeviceMesh,
) -> None:
    """
    Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function.

    The model is modified in place: every block's ``forward`` is replaced by a wrapper that chunks the
    sequence across ranks on the first block, runs the block under ``RingDistributedContext`` (which
    swaps scaled-dot-product attention for ring attention), and gathers the hidden states back on the
    last block.

    Parameters
    ----------
    model : Any
        The transformer model to wrap.
    world_size : int
        The number of GPUs to distribute the model on.
    smash_config : SmashConfig
        The SmashConfig to use.
    rank : int
        The rank of the current process.
    mesh : DeviceMesh
        The mesh to use for the distributed attention.
    """
    for i, block in enumerate(model.blocks):
        block_original = block.forward

        @functools.wraps(block_original)
        def block_forward(
            self,
            hidden_states: torch.Tensor,
            encoder_hidden_states: torch.Tensor,
            temb: torch.Tensor,
            rotary_emb: torch.Tensor,
            # default values bind the per-iteration objects eagerly, avoiding Python's
            # late-binding closure pitfall inside this loop
            _block_ref=block,
            _original_forward=block_original,
            _layer_id=i,
            _num_layers=len(model.blocks),
        ) -> torch.Tensor:
            # on the first layer, we chunk the hidden states
            if _layer_id == 0:
                hidden_states = hidden_states.chunk(world_size, dim=-2)[rank]

            # rotary embeddings are re-chunked on every layer; assumes each block receives the
            # full-length embeddings from the outer model loop -- TODO confirm against the pipeline
            rotary_emb = rotary_emb.chunk(world_size, dim=-2)[rank]

            # Use compiled version if available, otherwise use original (not the wrapped one!)
            forward_to_call = getattr(_block_ref, "compiled_forward", _original_forward)

            # the context intercepts calls to F.scaled_dot_product_attention and substitutes ring attention
            with RingDistributedContext(mesh, smash_config):
                hidden_states = forward_to_call(hidden_states, encoder_hidden_states, temb, rotary_emb)

            # on the last layer, we sync back the hidden states
            if _layer_id == _num_layers - 1:
                return sync_tensor(hidden_states, dim=-2, group=dist.distributed_c10d._get_default_group())

            return hidden_states

        # keep a handle on the unwrapped forward so it can be restored or compiled later
        block.original_forward = block_original
        block.forward = block_forward.__get__(block)  # type: ignore
220+
221+
222+
def wrap_pipeline_call(model: Any, world_size: int) -> Any:
    """
    Wrap the model forward pass to set up a generator with rank-specific device.

    Parameters
    ----------
    model : Any
        The model to wrap.
    world_size : int
        The number of GPUs to distribute the model on.
    """
    # Without an explicit generator each process would draw its own seed and the ranks would
    # sample different noise; we draw one seed, synchronize it across ranks, and hand every
    # rank an identical generator pinned to its own device.
    wrapped_call = model.__call__

    @functools.wraps(wrapped_call)
    def synced_call(*args, **kwargs):
        # when the processes were spawned manually, "dist" cannot report the rank,
        # so callers may pass it explicitly via the "rank" keyword instead
        if "rank" in kwargs:
            rank = kwargs.pop("rank")
        else:
            rank = dist.get_rank()

        if "generator" not in kwargs:
            int64_info = torch.iinfo(torch.int64)
            device = f"cuda:{rank}"
            local_seed = torch.randint(0, int64_info.max, [1], dtype=torch.int64, device=device)
            # gather every rank's draw and keep the first chunk so all ranks agree on one value
            shared_seed = sync_tensor(local_seed, dim=0, group=None).chunk(world_size, dim=0)[0]
            # shift into the non-negative range accepted by manual_seed
            seed = shared_seed.item() - int64_info.min
            kwargs["generator"] = torch.Generator(device).manual_seed(seed)

        return wrapped_call(*args, **kwargs)

    # NOTE(review): assigning __call__ on the instance only affects explicit model.__call__(...)
    # invocations; "model(...)" resolves __call__ on the type -- confirm callers use the former
    model.__call__ = synced_call  # type: ignore
256+
257+
258+
def wrap_flux2d_transformer_forward(
    model: Any,
    world_size: int,
    smash_config: Union[SmashConfig, SmashConfigPrefixWrapper],
    rank: int,
    mesh: DeviceMesh,
    cache_helper: Any | None = None,
) -> None:
    """
    Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function.

    The model is modified in place: ``model.forward`` becomes a wrapper that keeps only this rank's
    chunk of the sequence, runs the original forward under ``RingDistributedContext``, and gathers
    the output sample back across ranks.

    Parameters
    ----------
    model : Any
        The transformer model to wrap.
    world_size : int
        The number of GPUs to distribute the model on.
    smash_config : SmashConfig
        The SmashConfig to use.
    rank : int
        The rank of the current process.
    mesh : DeviceMesh
        The mesh to use for the distributed attention.
    cache_helper : Any | None
        The cache helper if one is present in the pipe.
    """
    original_forward = model.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        img_ids: torch.Tensor | None = None,
        txt_ids: torch.Tensor | None = None,
        *args,
        **kwargs,
    ):
        # split all input tensors along the sequence length dimension and get chunk for this process (rank)
        # we do the forward pass on two separate chunks and only "sync" when the attention is computed
        # for intuition: number of chunks = number of GPUs
        # hidden/encoder states are chunked on dim 1, the id tensors on dim 0
        # (assumes dim 1 resp. dim 0 is the sequence axis -- TODO confirm for both Flux variants)
        hidden_states = hidden_states.chunk(world_size, dim=1)[rank]
        encoder_hidden_states = (
            encoder_hidden_states.chunk(world_size, dim=1)[rank] if encoder_hidden_states is not None else None
        )
        img_ids = img_ids.chunk(world_size, dim=0)[rank] if img_ids is not None else None
        txt_ids = txt_ids.chunk(world_size, dim=0)[rank] if txt_ids is not None else None

        # this context basically intercepts any call to F.scaled_dot_product_attention
        # and replaces it with the ring attention implementation
        with RingDistributedContext(mesh, smash_config):
            output = self.inner_forward(
                hidden_states,
                encoder_hidden_states,
                *args,
                img_ids=img_ids,
                txt_ids=txt_ids,
                **kwargs,
            )

        # before we output the result, we attach the separate chunks together again
        # only output[0] (the sample) is gathered; any extra outputs are forwarded unchanged
        sample = output[0]
        sample = sync_tensor(sample, dim=-2, group=dist.distributed_c10d._get_default_group())
        return (sample, *output[1:])

    model.forward = new_forward.__get__(model)  # type: ignore
    # NOTE(review): when a cache helper is present, the original forward is re-bound onto it via
    # __get__; if original_forward is already a bound method this __get__ may simply return it
    # unchanged -- confirm against the cache helper's wrapping of model.forward
    model.inner_forward = original_forward.__get__(model if cache_helper is None else cache_helper)  # type: ignore
325+
326+
327+
def sync_tensor(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup | None) -> torch.Tensor:
    """
    Sync a tensor across a process group.

    All-gathers the tensor from every rank and concatenates the per-rank chunks along ``dim``,
    so the output's size along ``dim`` is ``world_size`` times the input's.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to sync.
    dim : int
        The dimension to sync along.
    group : dist.ProcessGroup | None
        The process group to sync across. ``None`` selects the default group; objects exposing
        ``get_group()`` (e.g. a DeviceMesh) are also accepted.

    Returns
    -------
    torch.Tensor
        The synced tensor.
    """
    # move the gather dimension to the front so the all-gather can run along dim 0
    tensor = tensor.transpose(0, dim).contiguous()

    if group is None:
        group = dist.distributed_c10d._get_default_group()

    if isinstance(group, dist.ProcessGroup):
        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group
    else:
        # e.g. a DeviceMesh, which exposes its process group via get_group()
        pg = group.get_group()

    x_shape = tensor.shape
    tensor = tensor.flatten()
    x_numel = tensor.numel()  # type: ignore
    tensor = dist._functional_collectives.all_gather_tensor(tensor, group=pg, gather_dim=0)  # type: ignore
    if isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor):
        # wait() returns the resolved tensor; the original discarded this return value and relied on
        # the implicit wait triggered by later ops -- assign it so the reshape below runs on a plain tensor
        tensor = tensor.wait()

    # the gathered flat tensor is world_size times larger; scale the leading dim accordingly
    out_shape = list(x_shape)
    out_shape[0] *= tensor.numel() // x_numel  # type: ignore
    tensor = tensor.reshape(out_shape)  # type: ignore
    # move the gathered dimension back to its original position
    return tensor.transpose(0, dim)

0 commit comments

Comments
 (0)