|
| 1 | +# Copyright 2025 - Pruna AI GmbH. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from __future__ import annotations |
| 16 | + |
| 17 | +import contextlib |
| 18 | +import functools |
| 19 | +from collections.abc import Iterable |
| 20 | +from types import ModuleType |
| 21 | +from typing import Any, List, Optional, Union |
| 22 | + |
| 23 | +import torch |
| 24 | +import torch.distributed as dist |
| 25 | +from ConfigSpace import CategoricalHyperparameter |
| 26 | +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel |
| 27 | +from diffusers.models.transformers.transformer_wan import WanTransformer3DModel |
| 28 | +from torch.distributed.tensor.device_mesh import DeviceMesh |
| 29 | + |
| 30 | +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase |
| 31 | +from pruna.algorithms.base.tags import AlgorithmTag |
| 32 | +from pruna.algorithms.ring_attn.utils.ring_utils import RingDistributedContext |
| 33 | +from pruna.config.hyperparameters import Boolean |
| 34 | +from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper |
| 35 | +from pruna.engine.save import SAVE_FUNCTIONS |
| 36 | + |
| 37 | +ring_attention: ModuleType | None = None |
| 38 | + |
| 39 | +with contextlib.suppress(ImportError): |
| 40 | + # see "import_algorithm_packages" for further explanation |
| 41 | + import torch.distributed.tensor.experimental._attention as ring_attention |
| 42 | + |
| 43 | + |
| 44 | +class RingAttn(PrunaAlgorithmBase): |
| 45 | + """ |
| 46 | + Distributed attention on multiple GPUs computation by using the torch native ring attention implementation. |
| 47 | +
|
| 48 | + Each GPU stores only its own slice of Q/K/V and participates in a Ring Attention shuffle that lets every query |
| 49 | + attend to every key/value. The result is lower KV-cache/activation memory per GPU and higher arithmetic intensity. |
| 50 | + """ |
| 51 | + |
| 52 | + algorithm_name: str = "ring_attn" |
| 53 | + group_tags: list[AlgorithmTag] = [AlgorithmTag.KERNEL] |
| 54 | + save_fn = SAVE_FUNCTIONS.reapply |
| 55 | + references = { |
| 56 | + "Implementation": "https://docs.pytorch.org/tutorials/prototype/context_parallel.html", |
| 57 | + "Paper": "https://arxiv.org/pdf/2310.01889", |
| 58 | + } |
| 59 | + tokenizer_required: bool = False |
| 60 | + processor_required: bool = False |
| 61 | + runs_on: list[str] = ["cuda"] |
| 62 | + dataset_required: bool = False |
| 63 | + compatible_before: Iterable[str | AlgorithmTag] = [ |
| 64 | + "qkv_diffusers", |
| 65 | + "padding_pruning", |
| 66 | + ] |
| 67 | + compatible_after: Iterable[str | AlgorithmTag] = ["torch_compile"] |
| 68 | + |
| 69 | + def get_hyperparameters(self) -> list: |
| 70 | + """ |
| 71 | + Get the hyperparameters for the RingAttn. |
| 72 | +
|
| 73 | + Returns |
| 74 | + ------- |
| 75 | + list |
| 76 | + A list of hyperparameters. |
| 77 | + """ |
| 78 | + return [ |
| 79 | + Boolean( |
| 80 | + "convert_to_f32", |
| 81 | + default=True, |
| 82 | + meta=dict(desc="Allowing intermediate computations in the attention mechanism to be upcast to 32-bit."), |
| 83 | + ), |
| 84 | + CategoricalHyperparameter( |
| 85 | + "rotate_method", |
| 86 | + default_value="ALL_TO_ALL", |
| 87 | + meta=dict(desc="The method to use for rotating the computations."), |
| 88 | + choices=["ALL_TO_ALL", "ALL_GATHER"], |
| 89 | + ), |
| 90 | + ] |
| 91 | + |
| 92 | + def model_check_fn(self, model: Any) -> bool: |
| 93 | + """ |
| 94 | + Check if the model is supported by the RingAttn. |
| 95 | +
|
| 96 | + Parameters |
| 97 | + ---------- |
| 98 | + model : Any |
| 99 | + The model to check. |
| 100 | +
|
| 101 | + Returns |
| 102 | + ------- |
| 103 | + bool |
| 104 | + True if the model is supported, False otherwise. |
| 105 | + """ |
| 106 | + if torch.cuda.device_count() < 2: |
| 107 | + raise ValueError("RingAttn requires at least 2 GPUs") |
| 108 | + |
| 109 | + return hasattr(model, "transformer") and isinstance( |
| 110 | + model.transformer, (FluxTransformer2DModel, WanTransformer3DModel) |
| 111 | + ) |
| 112 | + |
| 113 | + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: |
| 114 | + |
| 115 | + # configure the ring attention hyperparameters |
| 116 | + _cp_options = ring_attention._cp_options # type: ignore |
| 117 | + _cp_options.convert_to_f32 = smash_config["convert_to_f32"] |
| 118 | + _cp_options.enable_load_balance = False |
| 119 | + _cp_options.rotate_method = getattr(ring_attention._RotateMethod, smash_config["rotate_method"]) # type: ignore |
| 120 | + |
| 121 | + wrap_pipeline_call(model, torch.cuda.device_count()) |
| 122 | + |
| 123 | + mesh = dist.init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("ring_dim",)) |
| 124 | + rank = dist.get_rank() |
| 125 | + world_size = torch.cuda.device_count() |
| 126 | + |
| 127 | + if isinstance(model.transformer, FluxTransformer2DModel): |
| 128 | + wrap_flux2d_transformer_forward( |
| 129 | + model.transformer, |
| 130 | + world_size, |
| 131 | + smash_config._base_config, |
| 132 | + rank, |
| 133 | + mesh, |
| 134 | + cache_helper=getattr(model, "cache_helper", None), |
| 135 | + ) |
| 136 | + elif isinstance(model.transformer, WanTransformer3DModel): |
| 137 | + wrap_wan3d_transformer_forward(model.transformer, world_size, smash_config._base_config, rank, mesh) |
| 138 | + else: |
| 139 | + raise ValueError(f"Unsupported transformer type: {type(model.transformer)}") |
| 140 | + |
| 141 | + return model |
| 142 | + |
| 143 | + def import_algorithm_packages(self) -> dict[str, Any]: |
| 144 | + """ |
| 145 | + Import the algorithm packages. |
| 146 | +
|
| 147 | + Returns |
| 148 | + ------- |
| 149 | + dict[str, Any] |
| 150 | + The algorithm packages. |
| 151 | + """ |
| 152 | + # even though it is a torch import we isolate it, as experimental modules can often change the interface |
| 153 | + # we import the package even though we dont use it directly to make sure it is available |
| 154 | + # additionally, we can not pass it as module to the distributed setup (not picklable) |
| 155 | + # nor as a string (the import massively irritates torch.compile) |
| 156 | + # we import it on the top of the file if available |
| 157 | + import torch.distributed.tensor.experimental._attention as ring_attention # noqa: F401 |
| 158 | + |
| 159 | + return dict() |
| 160 | + |
| 161 | + |
| 162 | +def wrap_wan3d_transformer_forward( |
| 163 | + model: Any, |
| 164 | + world_size: int, |
| 165 | + smash_config: Union[SmashConfig, SmashConfigPrefixWrapper], |
| 166 | + rank: int, |
| 167 | + mesh: DeviceMesh, |
| 168 | +) -> Any: |
| 169 | + """ |
| 170 | + Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function. |
| 171 | +
|
| 172 | + Parameters |
| 173 | + ---------- |
| 174 | + model : Any |
| 175 | + The transformer model to wrap. |
| 176 | + world_size : int |
| 177 | + The number of GPUs to distribute the model on. |
| 178 | + smash_config : SmashConfig |
| 179 | + The SmashConfig to use. |
| 180 | + rank : int |
| 181 | + The rank of the current process. |
| 182 | + mesh : DeviceMesh |
| 183 | + The mesh to use for the distributed attention. |
| 184 | + """ |
| 185 | + for i, block in enumerate(model.blocks): |
| 186 | + block_original = block.forward |
| 187 | + |
| 188 | + @functools.wraps(block_original) |
| 189 | + def block_forward( |
| 190 | + self, |
| 191 | + hidden_states: torch.Tensor, |
| 192 | + encoder_hidden_states: torch.Tensor, |
| 193 | + temb: torch.Tensor, |
| 194 | + rotary_emb: torch.Tensor, |
| 195 | + _block_ref=block, |
| 196 | + _original_forward=block_original, |
| 197 | + _layer_id=i, |
| 198 | + _num_layers=len(model.blocks), |
| 199 | + ) -> torch.Tensor: |
| 200 | + # on the first layer, we chunk the hidden states |
| 201 | + if _layer_id == 0: |
| 202 | + hidden_states = hidden_states.chunk(world_size, dim=-2)[rank] |
| 203 | + |
| 204 | + rotary_emb = rotary_emb.chunk(world_size, dim=-2)[rank] |
| 205 | + |
| 206 | + # Use compiled version if available, otherwise use original (not the wrapped one!) |
| 207 | + forward_to_call = getattr(_block_ref, "compiled_forward", _original_forward) |
| 208 | + |
| 209 | + with RingDistributedContext(mesh, smash_config): |
| 210 | + hidden_states = forward_to_call(hidden_states, encoder_hidden_states, temb, rotary_emb) |
| 211 | + |
| 212 | + # on the last layer, we sync back the hidden states |
| 213 | + if _layer_id == _num_layers - 1: |
| 214 | + return sync_tensor(hidden_states, dim=-2, group=dist.distributed_c10d._get_default_group()) |
| 215 | + |
| 216 | + return hidden_states |
| 217 | + |
| 218 | + block.original_forward = block_original |
| 219 | + block.forward = block_forward.__get__(block) # type: ignore |
| 220 | + |
| 221 | + |
| 222 | +def wrap_pipeline_call(model: Any, world_size: int) -> Any: |
| 223 | + """ |
| 224 | + Wrap the model forward pass to set up a generator with rank-specific device. |
| 225 | +
|
| 226 | + Parameters |
| 227 | + ---------- |
| 228 | + model : Any |
| 229 | + The model to wrap. |
| 230 | + world_size : int |
| 231 | + The number of GPUs to distribute the model on. |
| 232 | + """ |
| 233 | + # Set up generator with rank-specific device, if it is not explicitly specified the different |
| 234 | + # processes might sample different seeds, we have to sync this |
| 235 | + original_forward = model.__call__ |
| 236 | + |
| 237 | + @functools.wraps(original_forward) |
| 238 | + def new_forward( |
| 239 | + *args, |
| 240 | + **kwargs, |
| 241 | + ): |
| 242 | + rank = kwargs.pop("rank") if "rank" in kwargs else dist.get_rank() |
| 243 | + if "generator" not in kwargs: |
| 244 | + # if we distributed manually, we can not use "dist" to get the rank, in this case we pass the rank ourselves |
| 245 | + seed_t = torch.randint(0, torch.iinfo(torch.int64).max, [1], dtype=torch.int64, device=f"cuda:{rank}") |
| 246 | + seed_t = sync_tensor(seed_t, dim=0, group=None) |
| 247 | + seed_t = seed_t.chunk(world_size, dim=0)[0] |
| 248 | + seed = seed_t.item() |
| 249 | + seed -= torch.iinfo(torch.int64).min |
| 250 | + generator = torch.Generator(f"cuda:{rank}").manual_seed(seed) |
| 251 | + kwargs["generator"] = generator |
| 252 | + |
| 253 | + return original_forward(*args, **kwargs) |
| 254 | + |
| 255 | + model.__call__ = new_forward # type: ignore |
| 256 | + |
| 257 | + |
| 258 | +def wrap_flux2d_transformer_forward( |
| 259 | + model: Any, |
| 260 | + world_size: int, |
| 261 | + smash_config: Union[SmashConfig, SmashConfigPrefixWrapper], |
| 262 | + rank: int, |
| 263 | + mesh: DeviceMesh, |
| 264 | + cache_helper: Any | None = None, |
| 265 | +) -> Any: |
| 266 | + """ |
| 267 | + Wrap the transformer forward pass to chunk the inputs and intercept the torch attention function. |
| 268 | +
|
| 269 | + Parameters |
| 270 | + ---------- |
| 271 | + model : Any |
| 272 | + The transformer model to wrap. |
| 273 | + world_size : int |
| 274 | + The number of GPUs to distribute the model on. |
| 275 | + smash_config : SmashConfig |
| 276 | + The SmashConfig to use. |
| 277 | + rank : int |
| 278 | + The rank of the current process. |
| 279 | + mesh : DeviceMesh |
| 280 | + The mesh to use for the distributed attention. |
| 281 | + cache_helper : Any | None |
| 282 | + The cache helper if one is present in the pipe. |
| 283 | + """ |
| 284 | + original_forward = model.forward |
| 285 | + |
| 286 | + @functools.wraps(original_forward) |
| 287 | + def new_forward( |
| 288 | + self, |
| 289 | + hidden_states: torch.Tensor, |
| 290 | + encoder_hidden_states: Optional[torch.Tensor] = None, |
| 291 | + img_ids: torch.Tensor | None = None, |
| 292 | + txt_ids: torch.Tensor | None = None, |
| 293 | + *args, |
| 294 | + **kwargs, |
| 295 | + ): |
| 296 | + # split all input tensors along the sequence length dimension and get chunk for this process (rank) |
| 297 | + # we do the forward pass on two separate chunks and only "sync" when the attention is computed |
| 298 | + # for intuition: number of chunks = number of GPUs |
| 299 | + hidden_states = hidden_states.chunk(world_size, dim=1)[rank] |
| 300 | + encoder_hidden_states = ( |
| 301 | + encoder_hidden_states.chunk(world_size, dim=1)[rank] if encoder_hidden_states is not None else None |
| 302 | + ) |
| 303 | + img_ids = img_ids.chunk(world_size, dim=0)[rank] if img_ids is not None else None |
| 304 | + txt_ids = txt_ids.chunk(world_size, dim=0)[rank] if txt_ids is not None else None |
| 305 | + |
| 306 | + # this context basically intercepts any call to F.scaled_dot_product_attention |
| 307 | + # and replaces it with the ring attention implementation |
| 308 | + with RingDistributedContext(mesh, smash_config): |
| 309 | + output = self.inner_forward( |
| 310 | + hidden_states, |
| 311 | + encoder_hidden_states, |
| 312 | + *args, |
| 313 | + img_ids=img_ids, |
| 314 | + txt_ids=txt_ids, |
| 315 | + **kwargs, |
| 316 | + ) |
| 317 | + |
| 318 | + # before we output the result, we attach the separate chunks together again |
| 319 | + sample = output[0] |
| 320 | + sample = sync_tensor(sample, dim=-2, group=dist.distributed_c10d._get_default_group()) |
| 321 | + return (sample, *output[1:]) |
| 322 | + |
| 323 | + model.forward = new_forward.__get__(model) # type: ignore |
| 324 | + model.inner_forward = original_forward.__get__(model if cache_helper is None else cache_helper) # type: ignore |
| 325 | + |
| 326 | + |
| 327 | +def sync_tensor(tensor: torch.Tensor, dim: int, group: dist.ProcessGroup | None) -> torch.Tensor: |
| 328 | + """ |
| 329 | + Sync a tensor across a process group. |
| 330 | +
|
| 331 | + Parameters |
| 332 | + ---------- |
| 333 | + tensor : torch.Tensor |
| 334 | + The tensor to sync. |
| 335 | + dim : int |
| 336 | + The dimension to sync along. |
| 337 | + group : dist.ProcessGroup | None |
| 338 | + The process group to sync across. |
| 339 | +
|
| 340 | + Returns |
| 341 | + ------- |
| 342 | + torch.Tensor |
| 343 | + The synced tensor. |
| 344 | + """ |
| 345 | + tensor = tensor.transpose(0, dim).contiguous() |
| 346 | + |
| 347 | + if group is None: |
| 348 | + group = dist.distributed_c10d._get_default_group() |
| 349 | + |
| 350 | + if isinstance(group, dist.ProcessGroup): |
| 351 | + pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = group |
| 352 | + else: |
| 353 | + pg = group.get_group() |
| 354 | + |
| 355 | + x_shape = tensor.shape |
| 356 | + tensor = tensor.flatten() |
| 357 | + x_numel = tensor.numel() # type: ignore |
| 358 | + tensor = dist._functional_collectives.all_gather_tensor(tensor, group=pg, gather_dim=0) # type: ignore |
| 359 | + if isinstance(tensor, dist._functional_collectives.AsyncCollectiveTensor): |
| 360 | + tensor.wait() |
| 361 | + x_shape = list(x_shape) # type: ignore |
| 362 | + x_shape[0] *= tensor.numel() // x_numel # type: ignore |
| 363 | + tensor = tensor.reshape(x_shape) # type: ignore |
| 364 | + tensor = tensor.transpose(0, dim) |
| 365 | + return tensor |
0 commit comments