Commit 8f10d05

improve tests; add docs
1 parent 24f9273 commit 8f10d05

File tree

6 files changed (+347, -45 lines changed)

6 files changed

+347
-45
lines changed

docs/source/en/optimization/memory.md

Lines changed: 40 additions & 0 deletions
@@ -158,6 +158,46 @@ In order to properly offload models after they're called, it is required to run
 
 </Tip>
 
+## Group offloading
+
+Group offloading is a middle ground between the two methods above. It works by offloading groups of internal layers (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than model-level offloading. It is also faster than sequential-level offloading, as the number of device synchronizations is reduced.
+
+Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching with CUDA streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the current layer is being executed, which slightly increases the memory requirements. Note that this implementation also supports a leaf-level offloading mode, which can be made much faster when streams are used.
+
+To enable group offloading, either call the [`~ModelMixin.enable_group_offloading`] method on the model or use [`~hooks.group_offloading.apply_group_offloading`]:
+
+```python
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+# Load the pipeline
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+
+# We can utilize the enable_group_offloading method for Diffusers model implementations
+pipe.transformer.enable_group_offloading(onload_device=onload_device, offload_device=offload_device, offload_type="leaf_level", use_stream=True)
+
+# For any other model implementations, the apply_group_offloading function can be used
+apply_group_offloading(pipe.text_encoder, onload_device=onload_device, offload_type="block_level", num_blocks_per_group=2)
+apply_group_offloading(pipe.vae, onload_device=onload_device, offload_type="leaf_level")
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance."
+)
+video = pipe(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+# This example used about 14.79 GB. Memory can be further reduced by enabling tiling and using leaf_level offloading throughout the pipeline.
+print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
 ## FP8 layerwise weight-casting
 
 PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
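The stream-based prefetching described in the new documentation can be illustrated with plain PyTorch. The sketch below is only a conceptual illustration, assuming a CUDA device and a toy `layers` stack invented for this example; the real implementation in `diffusers.hooks.group_offloading` additionally manages pinned CPU buffers, parameter and buffer bookkeeping, and lazy discovery of the execution order.

```python
import torch
import torch.nn as nn

# Toy stack of layers kept on the CPU between uses (hypothetical sizes).
layers = nn.ModuleList([nn.Linear(4096, 4096) for _ in range(8)])

prefetch_stream = torch.cuda.Stream()
x = torch.randn(16, 4096, device="cuda")

# Onload the first layer before the loop starts.
layers[0].to("cuda")

for i, layer in enumerate(layers):
    # Start copying the next layer to the GPU on a side stream while the current
    # layer computes on the default stream.
    if i + 1 < len(layers):
        with torch.cuda.stream(prefetch_stream):
            layers[i + 1].to("cuda", non_blocking=True)

    x = layer(x)

    # Make sure the prefetch has finished before the next iteration uses that layer,
    # then move the current layer back to the CPU.
    torch.cuda.current_stream().wait_stream(prefetch_stream)
    layer.to("cpu")

torch.cuda.synchronize()
print(x.shape)
```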

src/diffusers/hooks/group_offloading.py

Lines changed: 81 additions & 35 deletions
@@ -37,7 +37,8 @@
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
     torch.nn.Linear,
-    torch.nn.LayerNorm, torch.nn.GroupNorm,
+    # TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
+    # because of double invocation of the same norm layer in CogVideoXLayerNorm
 )
 # fmt: on

@@ -120,15 +121,13 @@ class GroupOffloadingHook(ModelHook):
     def __init__(
         self,
         group: ModuleGroup,
-        offload_on_init: bool = True,
         next_group: Optional[ModuleGroup] = None,
     ) -> None:
         self.group = group
-        self.offload_on_init = offload_on_init
         self.next_group = next_group

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
-        if self.offload_on_init and self.group.offload_leader == module:
+        if self.group.offload_leader == module:
             self.group.offload_()
         return module
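The hook changes above drop the `offload_on_init` flag, so a group's offload leader is now always offloaded when the hook initializes. For readers unfamiliar with the hook mechanism, the following is a minimal plain-PyTorch analogue of what such a hook does around a module's forward pass; it uses hypothetical helper names and standard `torch.nn.Module` hooks rather than the `ModelHook`/`HookRegistry` machinery used in this file.

```python
import torch
import torch.nn as nn


def attach_offloading_hooks(block: nn.Module, onload_device: torch.device, offload_device: torch.device) -> None:
    """Move `block` to the accelerator right before its forward runs, and back right after."""

    def pre_hook(module, args):
        module.to(onload_device)

    def post_hook(module, args, output):
        module.to(offload_device)
        return output

    block.register_forward_pre_hook(pre_hook)
    block.register_forward_hook(post_hook)


# Keep every block on the CPU and onload each one just-in-time for its own forward pass.
blocks = nn.Sequential(*[nn.Linear(256, 256) for _ in range(4)]).to("cpu")
for block in blocks:
    attach_offloading_hooks(block, torch.device("cuda"), torch.device("cpu"))

x = torch.randn(2, 256, device="cuda")
y = blocks(x)
print(y.device)  # cuda:0 (each Linear was on the GPU only while it ran)
```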

@@ -262,14 +261,78 @@ def pre_forward(self, module, *args, **kwargs):

 def apply_group_offloading(
     module: torch.nn.Module,
+    onload_device: torch.device,
+    offload_device: torch.device = torch.device("cpu"),
     offload_type: str = "block_level",
     num_blocks_per_group: Optional[int] = None,
-    offload_device: torch.device = torch.device("cpu"),
-    onload_device: torch.device = torch.device("cuda"),
-    force_offload: bool = True,
     non_blocking: bool = False,
     use_stream: bool = False,
 ) -> None:
+    r"""
+    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
+    where it is beneficial, we need to first provide some context on how other supported offloading methods work.
+
+    Typically, offloading is done at two levels:
+    - Module-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_model_cpu_offload()` method. It
+    works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator device
+    when needed for computation. This method is more memory-efficient than keeping all components on the accelerator,
+    but the memory requirements are still quite high. For this method to work, one needs memory equivalent to the size
+    of the model in runtime dtype + the size of the largest intermediate activation tensors to be able to complete the
+    forward pass.
+    - Leaf-level: In Diffusers, this can be enabled using the `DiffusionPipeline::enable_sequential_cpu_offload()` method.
+    It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+    onloading only the leaves to the accelerator device for computation. This uses the lowest amount of accelerator
+    memory, but can be slower due to the excessive number of device synchronizations.
+
+    Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
+    (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method is more memory-efficient than module-level
+    offloading. It is also faster than leaf-level offloading, as the number of device synchronizations is reduced.
+
+    Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability to
+    overlap data transfer and computation to reduce the overall execution time. This is enabled using layer prefetching
+    with streams, i.e., the layer that is to be executed next starts onloading to the accelerator device while the
+    current layer is being executed, which slightly increases the memory requirements. Note that this implementation
+    also supports a leaf-level offloading mode, which can be made much faster when streams are used.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module to which group offloading is applied.
+        onload_device (`torch.device`):
+            The device to which the group of modules are onloaded.
+        offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+            The device to which the group of modules are offloaded. This should typically be the CPU.
+        offload_type (`str`, defaults to "block_level"):
+            The type of offloading to be applied. Can be one of "block_level" or "leaf_level".
+        num_blocks_per_group (`int`, *optional*):
+            The number of blocks per group when using offload_type="block_level". This is required when using
+            offload_type="block_level".
+        non_blocking (`bool`, defaults to `False`):
+            If True, offloading and onloading is done with non-blocking data transfer.
+        use_stream (`bool`, defaults to `False`):
+            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+            overlapping computation and data transfer.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> from diffusers import CogVideoXTransformer3DModel
+    >>> from diffusers.hooks import apply_group_offloading
+
+    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+    ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+    ... )
+
+    >>> apply_group_offloading(
+    ...     transformer,
+    ...     onload_device=torch.device("cuda"),
+    ...     offload_device=torch.device("cpu"),
+    ...     offload_type="block_level",
+    ...     num_blocks_per_group=2,
+    ...     use_stream=True,
+    ... )
+    ```
+    """
+
     stream = None
     if use_stream:
         if torch.cuda.is_available():

@@ -279,15 +342,13 @@ def apply_group_offloading(

     if offload_type == "block_level":
         if num_blocks_per_group is None:
-            raise ValueError("num_blocks_per_group must be provided when using offload_group_patterns='block_level'.")
+            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, force_offload, non_blocking, stream=stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, force_offload, non_blocking, stream=stream
-        )
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
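The "block_level" mode described in the arguments above groups consecutive blocks of a `torch.nn.ModuleList` or `torch.nn.Sequential` so that each group is moved between devices as a single unit. A minimal sketch of that grouping step is shown below; the helper name `chunk_module_list` is made up for illustration and is not part of the diffusers API.

```python
import torch.nn as nn


def chunk_module_list(blocks: nn.ModuleList, num_blocks_per_group: int):
    """Split a ModuleList into consecutive groups of at most `num_blocks_per_group` blocks.

    Each group is what "block_level" offloading would onload and offload as one unit.
    """
    groups = []
    for start in range(0, len(blocks), num_blocks_per_group):
        groups.append(list(blocks[start : start + num_blocks_per_group]))
    return groups


blocks = nn.ModuleList([nn.Linear(64, 64) for _ in range(7)])
print([len(group) for group in chunk_module_list(blocks, num_blocks_per_group=2)])  # [2, 2, 2, 1]
```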

@@ -297,7 +358,6 @@ def _apply_group_offloading_block_level(
     num_blocks_per_group: int,
     offload_device: torch.device,
     onload_device: torch.device,
-    force_offload: bool,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
 ) -> None:

@@ -312,9 +372,6 @@ def _apply_group_offloading_block_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
-        force_offload (`bool`):
-            If True, all module groups are offloaded to the offload_device. If False, only layers that match
-            `offload_group_patterns` are offloaded to the offload_device.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.

@@ -362,10 +419,9 @@ def _apply_group_offloading_block_level(
         next_group = (
             matched_module_groups[i + 1] if i + 1 < len(matched_module_groups) and stream is not None else None
         )
-        should_offload = force_offload or i > 0

         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, should_offload, next_group)
+            _apply_group_offloading_hook(group_module, group, next_group)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not

@@ -392,14 +448,13 @@ def _apply_group_offloading_block_level(
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, force_offload, next_group)
+    _apply_group_offloading_hook(module, unmatched_group, next_group)


 def _apply_group_offloading_leaf_level(
     module: torch.nn.Module,
     offload_device: torch.device,
     onload_device: torch.device,
-    force_offload: bool,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
 ) -> None:

@@ -416,9 +471,6 @@ def _apply_group_offloading_leaf_level(
             The device to which the group of modules are offloaded. This should typically be the CPU.
         onload_device (`torch.device`):
             The device to which the group of modules are onloaded.
-        force_offload (`bool`):
-            If True, all module groups are offloaded to the offload_device. If False, only layers that match
-            `offload_group_patterns` are offloaded to the offload_device.
         non_blocking (`bool`):
             If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
             and data transfer.

@@ -450,7 +502,7 @@ def _apply_group_offloading_leaf_level(
            cpu_param_dict=cpu_param_dict,
            onload_self=True,
        )
-        _apply_group_offloading_hook(submodule, group, True, None)
+        _apply_group_offloading_hook(submodule, group, None)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass

@@ -495,7 +547,7 @@ def _apply_group_offloading_leaf_level(
            cpu_param_dict=cpu_param_dict,
            onload_self=True,
        )
-        _apply_group_offloading_hook(parent_module, group, True, None)
+        _apply_group_offloading_hook(parent_module, group, None)

     # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module
     unmatched_group = ModuleGroup(

@@ -516,38 +568,36 @@ def _apply_group_offloading_leaf_level(
     # and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
     # execution order and apply prefetching in the correct order.
     if stream is None:
-        _apply_group_offloading_hook(module, unmatched_group, force_offload, None)
+        _apply_group_offloading_hook(module, unmatched_group, None)
     else:
-        _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None)
+        _apply_lazy_group_offloading_hook(module, unmatched_group, None)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    offload_on_init: bool,
     next_group: Optional[ModuleGroup] = None,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, offload_on_init, next_group)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    offload_on_init: bool,
     next_group: Optional[ModuleGroup] = None,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, offload_on_init, next_group)
+        hook = GroupOffloadingHook(group, next_group)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
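The lazy prefetching path above needs to know the order in which modules execute, which is only discovered during the first forward pass. The sketch below shows one generic way to record such an execution order with plain `torch.nn.Module` pre-hooks; it illustrates the idea and is not the `LazyPrefetchGroupOffloadingHook` implementation itself.

```python
import torch
import torch.nn as nn


def record_execution_order(model: nn.Module, example_input: torch.Tensor):
    """Run one forward pass and record the order in which leaf modules execute."""
    order, handles = [], []

    def make_hook(name):
        def hook(module, args):
            order.append(name)

        return hook

    for name, submodule in model.named_modules():
        if len(list(submodule.children())) == 0:  # leaf modules only
            handles.append(submodule.register_forward_pre_hook(make_hook(name)))

    with torch.no_grad():
        model(example_input)

    for handle in handles:
        handle.remove()
    return order


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
print(record_execution_order(model, torch.randn(1, 8)))  # ['0', '1', '2']
```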
@@ -561,14 +611,12 @@ def _gather_parameters_with_no_group_offloading_parent(
     for name, parameter in module.named_parameters():
         has_parent_with_group_offloading = False
         atoms = name.split(".")
-
         while len(atoms) > 0:
             parent_name = ".".join(atoms)
             if parent_name in modules_with_group_offloading:
                 has_parent_with_group_offloading = True
                 break
             atoms.pop()
-
         if not has_parent_with_group_offloading:
             parameters.append((name, parameter))
     return parameters

@@ -581,14 +629,12 @@ def _gather_buffers_with_no_group_offloading_parent(
     for name, buffer in module.named_buffers():
         has_parent_with_group_offloading = False
         atoms = name.split(".")
-
         while len(atoms) > 0:
             parent_name = ".".join(atoms)
             if parent_name in modules_with_group_offloading:
                 has_parent_with_group_offloading = True
                 break
             atoms.pop()
-
         if not has_parent_with_group_offloading:
             buffers.append((name, buffer))
     return buffers
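The two helpers above decide whether a parameter or buffer already has an ancestor module with a group offloading hook by walking up its dotted name. A simplified, self-contained analogue of that check looks like this (the helper name and the example module names are made up for illustration):

```python
def has_offloaded_ancestor(param_name: str, modules_with_group_offloading: set) -> bool:
    """Walk up the dotted module path of a parameter/buffer name and check whether
    any ancestor module already has a group offloading hook attached."""
    atoms = param_name.split(".")
    while atoms:
        if ".".join(atoms) in modules_with_group_offloading:
            return True
        atoms.pop()
    return False


offloaded = {"transformer_blocks.0", "transformer_blocks.1"}
print(has_offloaded_ancestor("transformer_blocks.0.attn.to_q.weight", offloaded))  # True
print(has_offloaded_ancestor("patch_embed.proj.weight", offloaded))  # False
```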

src/diffusers/models/modeling_utils.py

Lines changed: 50 additions & 1 deletion
@@ -33,7 +33,7 @@
 from torch import Tensor, nn

 from .. import __version__
-from ..hooks import apply_layerwise_casting
+from ..hooks import apply_group_offloading, apply_layerwise_casting
 from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
 from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (

@@ -446,6 +446,55 @@ def enable_layerwise_casting(
             self, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
         )

+    def enable_group_offloading(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+    ) -> None:
+        r"""
+        Activates group offloading for the current model.
+
+        See [`~hooks.group_offloading.apply_group_offloading`] for more information.
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from diffusers import CogVideoXTransformer3DModel
+
+        >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
+        ...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
+        ... )
+
+        >>> transformer.enable_group_offloading(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="leaf_level",
+        ...     use_stream=True,
+        ... )
+        ```
+        """
+        if getattr(self, "enable_tiling", None) is not None and getattr(self, "use_tiling", False) and use_stream:
+            msg = (
+                "Applying group offloading on autoencoders, with CUDA streams, may not work as expected if the first "
+                "forward pass is executed with tiling enabled. Please make sure to either:\n"
+                "1. Run a forward pass with small input shapes.\n"
+                "2. Or, run a forward pass with tiling disabled (can still use small dummy inputs)."
+            )
+            logger.warning(msg)
+        if not self._supports_group_offloading:
+            raise ValueError(
+                f"{self.__class__.__name__} does not support group offloading. Please make sure to set the boolean attribute "
+                f"`_supports_group_offloading` to `True` in the class definition. If you believe this is a mistake, please "
+                f"open an issue at https://github.com/huggingface/diffusers/issues."
+            )
+        apply_group_offloading(
+            self, onload_device, offload_device, offload_type, num_blocks_per_group, non_blocking, use_stream
+        )
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
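The `_supports_group_offloading` gate added above means a model class must opt in explicitly before `enable_group_offloading` will run. A rough sketch of such an opt-in for a custom model is shown below; the `TinyBlockModel` class, its block structure, and the availability of a CUDA device are assumptions made for this illustration, not part of the commit.

```python
import torch
from diffusers import ModelMixin


class TinyBlockModel(ModelMixin):
    # Opting in to group offloading, as required by the check in enable_group_offloading.
    _supports_group_offloading = True

    def __init__(self, hidden_size: int = 64, num_blocks: int = 4):
        super().__init__()
        self.blocks = torch.nn.ModuleList(
            [torch.nn.Linear(hidden_size, hidden_size) for _ in range(num_blocks)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return x


model = TinyBlockModel()
model.enable_group_offloading(
    onload_device=torch.device("cuda"),
    offload_type="block_level",
    num_blocks_per_group=2,
)
out = model(torch.randn(1, 64, device="cuda"))
```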

0 commit comments
