Commit 4345907

sayakpaul and stevhliu authored
[core] feat: support group offloading at the pipeline level (huggingface#12283)
* feat: support group offloading at the pipeline level.

* add tests

* up

* [docs] Pipeline group offloading (huggingface#12286)

init

Co-authored-by: Sayak Paul <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>
1 parent 4067d6c commit 4345907

3 files changed: +241 -3 lines

docs/source/en/optimization/memory.md

Lines changed: 46 additions & 3 deletions
@@ -291,13 +291,53 @@ Group offloading moves groups of internal layers ([torch.nn.ModuleList](https://
 > [!WARNING]
 > Group offloading may not work with all models if the forward implementation contains weight-dependent device casting of inputs because it may clash with group offloading's device casting mechanism.
 
-Call [`~ModelMixin.enable_group_offload`] to enable it for standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
-
-The `offload_type` parameter can be set to `block_level` or `leaf_level`.
+Enable group offloading by configuring the `offload_type` parameter to `block_level` or `leaf_level`.
 
 - `block_level` offloads groups of layers based on the `num_blocks_per_group` parameter. For example, if `num_blocks_per_group=2` on a model with 40 layers, 2 layers are onloaded and offloaded at a time (20 total onloads/offloads). This drastically reduces memory requirements.
 - `leaf_level` offloads individual layers at the lowest level and is equivalent to [CPU offloading](#cpu-offloading). But it can be made faster if you use streams without giving up inference speed.
 
+Group offloading is supported for entire pipelines or individual models. Applying group offloading to the entire pipeline is the easiest option while selectively applying it to individual models gives users more flexibility to use different offloading techniques for different models.
+
+<hfoptions id="group-offloading">
+<hfoption id="pipeline">
+
+Call [`~DiffusionPipeline.enable_group_offload`] on a pipeline.
+
+```py
+import torch
+from diffusers import CogVideoXPipeline
+from diffusers.hooks import apply_group_offloading
+from diffusers.utils import export_to_video
+
+onload_device = torch.device("cuda")
+offload_device = torch.device("cpu")
+
+pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
+pipeline.enable_group_offload(
+    onload_device=onload_device,
+    offload_device=offload_device,
+    offload_type="leaf_level",
+    use_stream=True
+)
+
+prompt = (
+    "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
+    "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
+    "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
+    "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
+    "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
+    "atmosphere of this unique musical performance."
+)
+video = pipeline(prompt=prompt, guidance_scale=6, num_inference_steps=50).frames[0]
+print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
+export_to_video(video, "output.mp4", fps=8)
+```
+
+</hfoption>
+<hfoption id="model">
+
+Call [`~ModelMixin.enable_group_offload`] on standard Diffusers model components that inherit from [`ModelMixin`]. For other model components that don't inherit from [`ModelMixin`], such as a generic [torch.nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html), use [`~hooks.apply_group_offloading`] instead.
+
 ```py
 import torch
 from diffusers import CogVideoXPipeline
@@ -328,6 +368,9 @@ print(f"Max memory reserved: {torch.cuda.max_memory_allocated() / 1024**3:.2f} G
 export_to_video(video, "output.mp4", fps=8)
 ```
 
+</hfoption>
+</hfoptions>
+
 #### CUDA stream
 
 The `use_stream` parameter can be activated for CUDA devices that support asynchronous data transfer streams to reduce overall execution time compared to [CPU offloading](#cpu-offloading). It overlaps data transfer and computation by using layer prefetching. The next layer to be executed is loaded onto the GPU while the current layer is still being executed. It can increase CPU memory significantly so ensure you have 2x the amount of memory as the model size.
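
For comparison with the `leaf_level` example added in this file, the same pipeline-level call also accepts `offload_type="block_level"`. The sketch below is illustrative only and not part of the commit; the checkpoint and the `num_blocks_per_group=2` value are reused from the surrounding docs.

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
# block_level offloading requires num_blocks_per_group; use_stream overlaps
# data transfer with computation on CUDA devices that support streams.
pipeline.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)

video = pipeline(
    prompt="A panda playing an acoustic guitar in a bamboo forest.",
    guidance_scale=6,
    num_inference_steps=50,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```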

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 127 additions & 0 deletions
@@ -1334,6 +1334,133 @@ def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Un
                 offload_buffers = len(model._parameters) > 0
                 cpu_offload(model, device, offload_buffers=offload_buffers)
 
+    def enable_group_offload(
+        self,
+        onload_device: torch.device,
+        offload_device: torch.device = torch.device("cpu"),
+        offload_type: str = "block_level",
+        num_blocks_per_group: Optional[int] = None,
+        non_blocking: bool = False,
+        use_stream: bool = False,
+        record_stream: bool = False,
+        low_cpu_mem_usage=False,
+        offload_to_disk_path: Optional[str] = None,
+        exclude_modules: Optional[Union[str, List[str]]] = None,
+    ) -> None:
+        r"""
+        Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is,
+        and where it is beneficial, we need to first provide some context on how other supported offloading methods
+        work.
+
+        Typically, offloading is done at two levels:
+        - Module-level: In Diffusers, this can be enabled using the `ModelMixin::enable_model_cpu_offload()` method. It
+          works by offloading each component of a pipeline to the CPU for storage, and onloading to the accelerator
+          device when needed for computation. This method is more memory-efficient than keeping all components on the
+          accelerator, but the memory requirements are still quite high. For this method to work, one needs memory
+          equivalent to size of the model in runtime dtype + size of largest intermediate activation tensors to be able
+          to complete the forward pass.
+        - Leaf-level: In Diffusers, this can be enabled using the `ModelMixin::enable_sequential_cpu_offload()` method.
+          It works by offloading the lowest leaf-level parameters of the computation graph to the CPU for storage, and
+          onloading only the leafs to the accelerator device for computation. This uses the lowest amount of
+          accelerator memory, but can be slower due to the excessive number of device synchronizations.
+
+        Group offloading is a middle ground between the two methods. It works by offloading groups of internal layers
+        (either `torch.nn.ModuleList` or `torch.nn.Sequential`). This method uses lower memory than module-level
+        offloading. It is also faster than leaf-level/sequential offloading, as the number of device synchronizations
+        is reduced.
+
+        Another supported feature (for CUDA devices with support for asynchronous data transfer streams) is the ability
+        to overlap data transfer and computation to reduce the overall execution time compared to sequential
+        offloading. This is enabled using layer prefetching with streams, i.e., the layer that is to be executed next
+        starts onloading to the accelerator device while the current layer is being executed - this increases the
+        memory requirements slightly. Note that this implementation also supports leaf-level offloading but can be made
+        much faster when using streams.
+
+        Args:
+            onload_device (`torch.device`):
+                The device to which the group of modules are onloaded.
+            offload_device (`torch.device`, defaults to `torch.device("cpu")`):
+                The device to which the group of modules are offloaded. This should typically be the CPU. Default is
+                CPU.
+            offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
+                The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
+                "block_level".
+            offload_to_disk_path (`str`, *optional*, defaults to `None`):
+                The path to the directory where parameters will be offloaded. Setting this option can be useful in
+                limited RAM environment settings where a reasonable speed-memory trade-off is desired.
+            num_blocks_per_group (`int`, *optional*):
+                The number of blocks per group when using offload_type="block_level". This is required when using
+                offload_type="block_level".
+            non_blocking (`bool`, defaults to `False`):
+                If True, offloading and onloading is done with non-blocking data transfer.
+            use_stream (`bool`, defaults to `False`):
+                If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
+                overlapping computation and data transfer.
+            record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
+                as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to
+                the [PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html)
+                for more details.
+            low_cpu_mem_usage (`bool`, defaults to `False`):
+                If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them.
+                This option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be
+                useful when the CPU memory is a bottleneck but may counteract the benefits of using streams.
+            exclude_modules (`Union[str, List[str]]`, defaults to `None`): List of modules to exclude from offloading.
+
+        Example:
+        ```python
+        >>> from diffusers import DiffusionPipeline
+        >>> import torch
+
+        >>> pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
+
+        >>> pipe.enable_group_offload(
+        ...     onload_device=torch.device("cuda"),
+        ...     offload_device=torch.device("cpu"),
+        ...     offload_type="leaf_level",
+        ...     use_stream=True,
+        ... )
+        >>> image = pipe("a beautiful sunset").images[0]
+        ```
+        """
+        from ..hooks import apply_group_offloading
+
+        if isinstance(exclude_modules, str):
+            exclude_modules = [exclude_modules]
+        elif exclude_modules is None:
+            exclude_modules = []
+
+        unknown = set(exclude_modules) - self.components.keys()
+        if unknown:
+            logger.info(
+                f"The following modules are not present in pipeline: {', '.join(unknown)}. Ignore if this is expected."
+            )
+
+        group_offload_kwargs = {
+            "onload_device": onload_device,
+            "offload_device": offload_device,
+            "offload_type": offload_type,
+            "num_blocks_per_group": num_blocks_per_group,
+            "non_blocking": non_blocking,
+            "use_stream": use_stream,
+            "record_stream": record_stream,
+            "low_cpu_mem_usage": low_cpu_mem_usage,
+            "offload_to_disk_path": offload_to_disk_path,
+        }
+        for name, component in self.components.items():
+            if name not in exclude_modules and isinstance(component, torch.nn.Module):
+                if hasattr(component, "enable_group_offload"):
+                    component.enable_group_offload(**group_offload_kwargs)
+                else:
+                    apply_group_offloading(module=component, **group_offload_kwargs)
+
+        if exclude_modules:
+            for module_name in exclude_modules:
+                module = getattr(self, module_name, None)
+                if module is not None and isinstance(module, torch.nn.Module):
+                    module.to(onload_device)
+                    logger.debug(f"Placed `{module_name}` on {onload_device} device as it was in `exclude_modules`.")
+
     def reset_device_map(self):
         r"""
         Resets the device maps (if any) to None.
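
The `exclude_modules` branch above places excluded components on `onload_device` instead of attaching offloading hooks to them. A rough usage sketch, not part of the diff, reusing the "Qwen/Qwen-Image" checkpoint from the docstring example and assuming the loaded pipeline exposes a `vae` component:

```py
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
# Group-offload every component except the VAE; the excluded module is moved
# to the onload device (here, CUDA) by the code path shown above.
pipe.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,
    exclude_modules="vae",  # assumption: this pipeline has a `vae` component
)
image = pipe("a beautiful sunset").images[0]
```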

tests/pipelines/test_pipelines_common.py

Lines changed: 68 additions & 0 deletions
@@ -9,6 +9,7 @@
 
 import numpy as np
 import PIL.Image
+import pytest
 import torch
 import torch.nn as nn
 from huggingface_hub import ModelCard, delete_repo
@@ -2362,6 +2363,73 @@ def test_pipeline_with_accelerator_device_map(self, expected_max_difference=1e-4
         max_diff = np.abs(to_np(out) - to_np(loaded_out)).max()
         self.assertLess(max_diff, expected_max_difference)
 
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_sanity_checks(self):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        module_names = sorted(
+            [name for name, component in pipe.components.items() if isinstance(component, torch.nn.Module)]
+        )
+        exclude_module_name = module_names[0]
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+            exclude_modules=exclude_module_name,
+        )
+        excluded_module = getattr(pipe, exclude_module_name)
+        self.assertTrue(torch.device(excluded_module.device).type == torch.device(torch_device).type)
+
+        for name, component in pipe.components.items():
+            if name not in [exclude_module_name] and isinstance(component, torch.nn.Module):
+                # `component.device` prints the `onload_device` type. We should probably override the
+                # `device` property in `ModelMixin`.
+                component_device = next(component.parameters())[0].device
+                self.assertTrue(torch.device(component_device).type == torch.device(offload_device).type)
+
+    @require_torch_accelerator
+    def test_pipeline_level_group_offloading_inference(self, expected_max_difference=1e-4):
+        components = self.get_dummy_components()
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+
+        for name, component in pipe.components.items():
+            if hasattr(component, "_supports_group_offloading"):
+                if not component._supports_group_offloading:
+                    pytest.skip(f"{self.pipeline_class.__name__} is not suitable for this test.")
+
+        # Regular inference.
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        torch.manual_seed(0)
+        inputs = self.get_dummy_inputs(torch_device)
+        inputs["generator"] = torch.manual_seed(0)
+        out = pipe(**inputs)[0]
+
+        pipe.to("cpu")
+        del pipe
+
+        # Inference with offloading
+        pipe: DiffusionPipeline = self.pipeline_class(**components)
+        offload_device = "cpu"
+        pipe.enable_group_offload(
+            onload_device=torch_device,
+            offload_device=offload_device,
+            offload_type="leaf_level",
+        )
+        pipe.set_progress_bar_config(disable=None)
+        inputs["generator"] = torch.manual_seed(0)
+        out_offload = pipe(**inputs)[0]
+
+        max_diff = np.abs(to_np(out) - to_np(out_offload)).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @is_staging_test
 class PipelinePushToHubTester(unittest.TestCase):
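
Outside the dummy-component test suite, the device placement asserted by the new sanity check can be spot-checked on a real pipeline. A minimal sketch, not part of the commit, reusing the CogVideoX checkpoint from the docs change above:

```py
import torch
from diffusers import CogVideoXPipeline

pipeline = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipeline.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
)

# As the test comment notes, `component.device` reports the onload device once
# hooks are attached, so inspect parameter devices directly instead.
for name, component in pipeline.components.items():
    if isinstance(component, torch.nn.Module):
        print(name, next(component.parameters()).device)  # expected: cpu
```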
