Commit 005e51b

Merge branch 'feature/group-offload-pinning' of https://github.com/bconstantine/diffusers into feature/group-offload-pinning

2 parents 7a2f3f0 + 3ef894d

19 files changed: +2102 -115 lines

docs/source/en/training/distributed_inference.md

Lines changed: 67 additions & 24 deletions
@@ -237,6 +237,8 @@ By selectively loading and unloading the models you need at a given stage and sh
 
 Use [`~ModelMixin.set_attention_backend`] to switch to a more optimized attention backend. Refer to this [table](../optimization/attention_backends#available-backends) for a complete list of available backends.
 
+Most attention backends are compatible with context parallelism. Open an [issue](https://github.com/huggingface/diffusers/issues/new) if a backend is not compatible.
+
 ### Ring Attention
 
 Key (K) and value (V) representations communicate between devices using [Ring Attention](https://huggingface.co/papers/2310.01889). This ensures each split sees every other token's K/V. Each GPU computes attention for its local K/V and passes it to the next GPU in the ring. No single GPU holds the full sequence, which reduces communication latency.
@@ -245,40 +247,60 @@ Pass a [`ContextParallelConfig`] to the `parallel_config` argument of the transf
 
 ```py
 import torch
-from diffusers import AutoModel, QwenImagePipeline, ContextParallelConfig
-
-try:
-    torch.distributed.init_process_group("nccl")
-    rank = torch.distributed.get_rank()
-    device = torch.device("cuda", rank % torch.cuda.device_count())
+from torch import distributed as dist
+from diffusers import DiffusionPipeline, ContextParallelConfig
+
+def setup_distributed():
+    if not dist.is_initialized():
+        dist.init_process_group(backend="nccl")
+    rank = dist.get_rank()
+    device = torch.device(f"cuda:{rank}")
     torch.cuda.set_device(device)
-
-    transformer = AutoModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16, parallel_config=ContextParallelConfig(ring_degree=2))
-    pipeline = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16, device_map="cuda")
-    pipeline.transformer.set_attention_backend("flash")
+    return device
+
+def main():
+    device = setup_distributed()
+    world_size = dist.get_world_size()
+
+    pipeline = DiffusionPipeline.from_pretrained(
+        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
+    )
+    pipeline.transformer.set_attention_backend("_native_cudnn")
+
+    cp_config = ContextParallelConfig(ring_degree=world_size)
+    pipeline.transformer.enable_parallelism(config=cp_config)
 
     prompt = """
     cinematic film still of a cat sipping a margarita in a pool in Palm Springs, California
    highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain
     """
-
+
     # Must specify generator so all ranks start with same latents (or pass your own)
    generator = torch.Generator().manual_seed(42)
-    image = pipeline(prompt, num_inference_steps=50, generator=generator).images[0]
-
-    if rank == 0:
-        image.save("output.png")
-
-except Exception as e:
-    print(f"An error occurred: {e}")
-    torch.distributed.breakpoint()
-    raise
-
-finally:
-    if torch.distributed.is_initialized():
-        torch.distributed.destroy_process_group()
+    image = pipeline(
+        prompt,
+        guidance_scale=3.5,
+        num_inference_steps=50,
+        generator=generator,
+    ).images[0]
+
+    if dist.get_rank() == 0:
+        image.save("output.png")
+
+    if dist.is_initialized():
+        dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    main()
 ```
 
+The script above needs to be run with a PyTorch-compatible distributed launcher, such as [torchrun](https://docs.pytorch.org/docs/stable/elastic/run.html). Set `--nproc-per-node` to the number of available GPUs.
+
+```shell
+torchrun --nproc-per-node 2 above_script.py
+```
+
 ### Ulysses Attention
 
 [Ulysses Attention](https://huggingface.co/papers/2309.14509) splits a sequence across GPUs and performs an *all-to-all* communication (every device sends/receives data to every other device). Each GPU ends up with all tokens for only a subset of attention heads. Each GPU computes attention locally on all tokens for its head, then performs another all-to-all to regroup results by tokens for the next layer.
@@ -288,5 +310,26 @@ finally:
 Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
 
 ```py
+# Depending on the number of GPUs available.
 pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
+```
+
+### parallel_config
+
+Pass `parallel_config` during model initialization to enable context parallelism.
+
+```py
+CKPT_ID = "black-forest-labs/FLUX.1-dev"
+
+cp_config = ContextParallelConfig(ring_degree=2)
+transformer = AutoModel.from_pretrained(
+    CKPT_ID,
+    subfolder="transformer",
+    torch_dtype=torch.bfloat16,
+    parallel_config=cp_config
+)
+
+pipeline = DiffusionPipeline.from_pretrained(
+    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16,
+).to(device)
 ```
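
For reference, the `parallel_config` snippet above is not self-contained (it omits the imports, the process-group setup, and the `device` it moves the pipeline to). A minimal sketch that fills those in, assuming the same NCCL/torchrun launch as the Ring Attention script:

```py
# Minimal sketch (assumptions: torchrun launch, NCCL backend, 2 GPUs as in the docs example).
import torch
from torch import distributed as dist
from diffusers import AutoModel, DiffusionPipeline, ContextParallelConfig

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{dist.get_rank()}")
torch.cuda.set_device(device)

CKPT_ID = "black-forest-labs/FLUX.1-dev"

# Passing parallel_config at load time enables context parallelism without a separate
# enable_parallelism() call.
cp_config = ContextParallelConfig(ring_degree=2)
transformer = AutoModel.from_pretrained(
    CKPT_ID,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=cp_config,
)

pipeline = DiffusionPipeline.from_pretrained(
    CKPT_ID, transformer=transformer, torch_dtype=torch.bfloat16
).to(device)
```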

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -404,6 +404,8 @@
 else:
     _import_structure["modular_pipelines"].extend(
         [
+            "Flux2AutoBlocks",
+            "Flux2ModularPipeline",
             "FluxAutoBlocks",
             "FluxKontextAutoBlocks",
             "FluxKontextModularPipeline",
@@ -1111,6 +1113,8 @@
         from .utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .modular_pipelines import (
+            Flux2AutoBlocks,
+            Flux2ModularPipeline,
             FluxAutoBlocks,
             FluxKontextAutoBlocks,
             FluxKontextModularPipeline,

src/diffusers/hooks/group_offloading.py

Lines changed: 71 additions & 27 deletions
@@ -60,6 +60,8 @@ class GroupOffloadingConfig:
     offload_to_disk_path: Optional[str] = None
     stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
     block_modules: Optional[List[str]] = None
+    exclude_kwargs: Optional[List[str]] = None
+    module_prefix: Optional[str] = ""
     pin_groups: Optional[Union[str, Callable]] = None
 
 
@@ -156,7 +158,7 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor):
+    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream=None):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
         if self.record_stream:
             tensor.data.record_stream(self._torch_accelerator_module.current_stream())
@@ -211,6 +213,7 @@ def _onload_from_memory(self):
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
@@ -308,13 +311,16 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                 self.next_group.onload_()
 
             should_synchronize = (
-                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+                not self.group.onload_self
+                and self.group.stream is not None
+                and not should_onload_next_group
+                and not self.group.record_stream
             )
             if should_synchronize:
                 self.group.stream.synchronize()
 
             args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-            kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+            kwargs = self._send_kwargs_to_device(kwargs)
             return args, kwargs
 
         # If the current module is the onload_leader of the group, we onload the group if it is supposed
@@ -329,7 +335,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                 self.next_group.onload_()
 
             should_synchronize = (
-                not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
+                not self.group.onload_self
+                and self.group.stream is not None
+                and not should_onload_next_group
+                and not self.group.record_stream
             )
             if should_synchronize:
                 # If this group didn't onload itself, it means it was asynchronously onloaded by the
@@ -341,7 +350,7 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
                 self.group.stream.synchronize()
 
         args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
-        kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+        kwargs = self._send_kwargs_to_device(kwargs)
         return args, kwargs
 
     def post_forward(self, module: torch.nn.Module, output):
@@ -352,6 +361,28 @@ def post_forward(self, module: torch.nn.Module, output):
             self.group.offload_()
         return output
 
+    def _is_group_on_device(self) -> bool:
+        tensors = []
+        for group_module in self.group.modules:
+            tensors.extend(list(group_module.parameters()))
+            tensors.extend(list(group_module.buffers()))
+        tensors.extend(self.group.parameters)
+        tensors.extend(self.group.buffers)
+
+        return len(tensors) > 0 and all(t.device == self.group.onload_device for t in tensors)
+
+    def _send_kwargs_to_device(self, kwargs):
+        exclude_kwargs = self.config.exclude_kwargs or []
+        if exclude_kwargs:
+            moved_kwargs = send_to_device(
+                {k: v for k, v in kwargs.items() if k not in exclude_kwargs},
+                self.group.onload_device,
+                non_blocking=self.group.non_blocking,
+            )
+            kwargs.update(moved_kwargs)
+            return kwargs
+        return send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
+
     def _is_group_on_device(self) -> bool:
         tensors = []
         for group_module in self.group.modules:
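
The exclusion logic above exists because `send_to_device` rebuilds containers, which breaks object identity for mutable kwargs. A small sketch of the behavior it works around (assuming accelerate's `send_to_device`; not part of this codebase):

```py
import torch
from accelerate.utils import send_to_device

cache = []  # mutable state a caller expects to keep mutating across forward passes
kwargs = {"hidden_states": torch.ones(2, 4), "cache": cache}

moved = send_to_device(kwargs, "cpu")
print(moved["cache"] is cache)  # False: the list was rebuilt, identity is lost

# Excluding the key (as _send_kwargs_to_device does) keeps the original object.
moved = send_to_device({k: v for k, v in kwargs.items() if k != "cache"}, "cpu")
moved["cache"] = cache
print(moved["cache"] is cache)  # True
```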
@@ -524,6 +555,17 @@ def pre_forward(self, module, *args, **kwargs):
         return args, kwargs
 
 
+def _normalize_pin_groups(pin_groups: Optional[Union[str, Callable]]) -> Optional[Union[str, Callable]]:
+    if isinstance(pin_groups, str):
+        normalized_pin_groups = pin_groups.lower()
+        if normalized_pin_groups not in {"first_last", "all"}:
+            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+        return normalized_pin_groups
+    if pin_groups is not None and not callable(pin_groups):
+        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+    return pin_groups
+
+
 def apply_group_offloading(
     module: torch.nn.Module,
     onload_device: Union[str, torch.device],
@@ -536,6 +578,7 @@
     low_cpu_mem_usage: bool = False,
     offload_to_disk_path: Optional[str] = None,
     block_modules: Optional[List[str]] = None,
+    exclude_kwargs: Optional[List[str]] = None,
     pin_groups: Optional[Union[str, Callable]] = None,
 ) -> None:
     r"""
@@ -595,11 +638,15 @@
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
         block_modules (`List[str]`, *optional*):
-            List of module names that should be treated as blocks for offloading. If provided, only these modules
-            will be considered for block-level offloading. If not provided, the default block detection logic will be used.
+            List of module names that should be treated as blocks for offloading. If provided, only these modules will
+            be considered for block-level offloading. If not provided, the default block detection logic will be used.
+        exclude_kwargs (`List[str]`, *optional*):
+            List of kwarg keys that should not be processed by `send_to_device`. This is useful for mutable state like
+            caching lists that need to maintain their object identity across forward passes. If not provided, will be
+            inferred from the module's `_skip_keys` attribute if it exists.
         pin_groups (`"first_last"` or `"all"` or `Callable`, *optional*, defaults to `None`):
-            Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first
-            and last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
+            Optionally keeps selected groups on the onload device permanently. Use `"first_last"` to pin the first and
+            last parameter-bearing groups, `"all"` to pin every parameter-bearing group, or pass a callable that
             receives a module (and optionally the module name and index) and returns `True` to pin that group.
 
     Example:
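
For illustration, a minimal usage sketch of the two options documented above, based on the `apply_group_offloading` signature shown in this diff; the excluded kwarg name (`cache_context`) and the checkpoint are placeholders:

```py
import torch
from diffusers import AutoModel
from diffusers.hooks import apply_group_offloading

transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
    # Keep the first and last parameter-bearing groups resident on the onload device.
    pin_groups="first_last",
    # Placeholder kwarg key: left untouched by send_to_device so it keeps object identity.
    exclude_kwargs=["cache_context"],
)
```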
@@ -640,19 +687,14 @@
     if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
         raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
 
-    normalized_pin_groups = pin_groups
-    if isinstance(pin_groups, str):
-        normalized_pin_groups = pin_groups.lower()
-        if normalized_pin_groups not in {"first_last", "all"}:
-            raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
-    elif pin_groups is not None and not callable(pin_groups):
-        raise ValueError("`pin_groups` must be one of `None`, 'first_last', 'all', or a callable.")
+    pin_groups = _normalize_pin_groups(pin_groups)
+    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    pin_groups = normalized_pin_groups
+    if block_modules is None:
+        block_modules = getattr(module, "_group_offload_block_modules", None)
 
-    _raise_error_if_accelerate_model_or_sequential_hook_present(module)
-    registry = HookRegistry.check_if_exists_or_initialize(module)
-    registry._group_offload_pin_groups = pin_groups
+    if exclude_kwargs is None:
+        exclude_kwargs = getattr(module, "_skip_keys", None)
 
     config = GroupOffloadingConfig(
         onload_device=onload_device,
@@ -665,6 +707,8 @@
         low_cpu_mem_usage=low_cpu_mem_usage,
         offload_to_disk_path=offload_to_disk_path,
         block_modules=block_modules,
+        exclude_kwargs=exclude_kwargs,
+        module_prefix="",
         pin_groups=pin_groups,
     )
     _apply_group_offloading(module, config)
@@ -706,7 +750,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
 
     for name, submodule in module.named_children():
         # Check if this is an explicitly defined block module
-        if name in block_modules:
+        if block_modules and name in block_modules:
             # Apply block offloading to the specified submodule
             _apply_block_offloading_to_submodule(
                 submodule, name, config, modules_with_group_offloading, matched_module_groups
@@ -802,7 +846,7 @@ def _apply_block_offloading_to_submodule(
         if len(current_modules) == 0:
             continue
 
-        group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
+        group_id = f"{config.module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
         group = ModuleGroup(
             modules=current_modules,
             offload_device=config.offload_device,
@@ -834,7 +878,7 @@
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
         )
         matched_module_groups.append(group)
         modules_with_group_offloading.add(name)
@@ -864,7 +908,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
         )
         _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)
@@ -911,7 +955,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOff
             record_stream=config.record_stream,
             low_cpu_mem_usage=config.low_cpu_mem_usage,
             onload_self=True,
-            group_id=name,
+            group_id=f"{config.module_prefix}{name}",
         )
         _apply_group_offloading_hook(parent_module, group, config=config)
 
@@ -966,8 +1010,8 @@ def _apply_lazy_group_offloading_hook(
     if registry.get_hook(_GROUP_OFFLOADING) is None:
         hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)
-
-    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups = config.pin_groups)
+
+    lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook(pin_groups=config.pin_groups)
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
 
 