
Commit 878eb4c

update

1 parent 9add071 commit 878eb4c
1 file changed (+74 −19)

src/diffusers/hooks/group_offloading.py

Lines changed: 74 additions & 19 deletions
```diff
@@ -70,12 +70,55 @@ def __init__(
         self.stream = stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
+        # Track if we've pinned our group's memory
+        self.pinned_memory = False

         if self.stream is not None and self.cpu_param_dict is None:
             raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")

+    def pin_memory_(self):
+        r"""Pin the memory of this group's parameters for faster transfer."""
+        if self.stream is not None and not self.pinned_memory:
+            # Create the pinned memory dict just for this group's parameters
+            self.cpu_param_dict = {}
+            for group_module in self.modules:
+                for param in group_module.parameters():
+                    if param.device == self.offload_device:
+                        pinned_data = param.data.pin_memory()
+                        self.cpu_param_dict[param] = pinned_data
+                        param.data = pinned_data
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param.device == self.offload_device:
+                        pinned_data = param.data.pin_memory()
+                        self.cpu_param_dict[param] = pinned_data
+                        param.data = pinned_data
+            self.pinned_memory = True
+
+    def unpin_memory_(self):
+        r"""Unpin the memory of this group's parameters to free up CPU RAM."""
+        if self.stream is not None and self.pinned_memory:
+            # Only unpin if we're currently on CPU (i.e., offloaded)
+            for group_module in self.modules:
+                if not any(p.device == self.onload_device for p in group_module.parameters()):
+                    for param in group_module.parameters():
+                        if param in self.cpu_param_dict and param.device == self.offload_device:
+                            # Create a new non-pinned copy and replace
+                            param.data = param.data.clone()
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param in self.cpu_param_dict and param.device == self.offload_device:
+                        param.data = param.data.clone()
+            # Clear the CPU param dict
+            self.cpu_param_dict = {}
+            self.pinned_memory = False
+
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
+        # Pin memory before onloading
+        if self.stream is not None and not self.pinned_memory:
+            self.pin_memory_()
+
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
```
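For background on why `pin_memory_()` matters here: `non_blocking` host-to-device copies can only overlap with computation when the source CPU tensor is page-locked (pinned). A minimal sketch of that underlying PyTorch pattern, independent of this commit:

```python
import torch

# Pinned (page-locked) CPU memory enables truly asynchronous
# host-to-device copies; pageable memory forces a synchronous staging copy.
if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    cpu_tensor = torch.randn(1024, 1024)

    # Pin once (what pin_memory_() does per group); the copy below can
    # then run asynchronously on a side stream.
    pinned = cpu_tensor.pin_memory()
    with torch.cuda.stream(stream):
        gpu_tensor = pinned.to("cuda", non_blocking=True)

    # Make the default stream wait before consuming the result.
    torch.cuda.current_stream().wait_stream(stream)
```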
```diff
@@ -107,6 +150,9 @@ def offload_(self):
         if self.buffers is not None:
             for buffer in self.buffers:
                 buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+        # After offloading, we can unpin the memory if configured to do so
+        # We'll keep it pinned by default for better performance


 class GroupOffloadingHook(ModelHook):
```
```diff
@@ -123,9 +169,11 @@ def __init__(
         self,
         group: ModuleGroup,
         next_group: Optional[ModuleGroup] = None,
+        unpin_after_use: bool = False,
     ) -> None:
         self.group = group
         self.next_group = next_group
+        self.unpin_after_use = unpin_after_use

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
```
```diff
@@ -154,6 +202,10 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
     def post_forward(self, module: torch.nn.Module, output):
         if self.group.offload_leader == module:
             self.group.offload_()
+            # After offloading, we can optionally unpin memory to free up CPU RAM
+            # This is most useful for large models where CPU RAM is limited
+            if self.unpin_after_use and self.group.pinned_memory:
+                self.group.unpin_memory_()
         return output

```
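The hook pair above gives each group a pin-on-onload, optionally-unpin-on-offload lifecycle. A hedged sketch of the round trip a single parameter makes under `unpin_after_use=True` (a standalone illustration; in the real hook the pinned buffer is tracked through `cpu_param_dict`):

```python
import torch

if torch.cuda.is_available():
    param = torch.nn.Parameter(torch.randn(512, 512))

    # pre_forward -> pin_memory_(): move storage into a pinned buffer.
    pinned = param.data.pin_memory()
    param.data = pinned

    # Onload: asynchronous copy to the GPU; the forward pass runs here.
    param.data = param.data.to("cuda", non_blocking=True)

    # post_forward -> offload_(): copy back into the pinned CPU buffer.
    pinned.copy_(param.data)
    param.data = pinned

    # unpin_memory_(): clone() allocates ordinary pageable memory,
    # releasing the page-locked buffer back to the OS.
    param.data = param.data.clone()
```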
```diff
@@ -268,6 +320,7 @@ def apply_group_offloading(
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
     use_stream: bool = False,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
```
```diff
@@ -314,6 +367,10 @@ def apply_group_offloading(
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
+        unpin_after_use (`bool`, defaults to `False`):
+            If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU memory
+            usage at the cost of potentially slower data transfer when the group is loaded again. Useful for large models
+            with limited CPU RAM.

     Example:
         ```python
```
```diff
@@ -331,6 +388,7 @@ def apply_group_offloading(
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
+        ...     unpin_after_use=False,  # Set to True to reduce CPU memory usage
         ... )
         ```
     """
```
```diff
@@ -349,10 +407,10 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, unpin_after_use
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, unpin_after_use)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")

```
```diff
@@ -364,6 +422,7 @@ def _apply_group_offloading_block_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
```
```diff
@@ -384,12 +443,9 @@ def _apply_group_offloading_block_level(
         for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
+    # and pin memory only when needed by each group
+    cpu_param_dict = {} if stream is not None else None

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
```
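The replacement above is the core behavioral change: instead of pinning every parameter of the model up front, the dict starts empty and each group pins lazily via `pin_memory_()` on its first onload. A rough sketch of the difference in peak page-locked memory (illustrative names, not from this file):

```python
import torch

def pageable_params(n):
    # Stand-ins for a model's per-block parameters.
    return [torch.randn(1024, 1024) for _ in range(n)]

if torch.cuda.is_available():
    params = pageable_params(8)

    # Old approach: pin everything eagerly -> all 8 buffers stay
    # page-locked for the lifetime of the model.
    # cpu_param_dict = {p: p.pin_memory() for p in params}

    # New approach: start empty and pin per group on first onload,
    # so only the active group's buffers are page-locked at once.
    cpu_param_dict = {}
    active_group = params[:2]
    for p in active_group:
        cpu_param_dict[p] = p.pin_memory()
```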
```diff
@@ -425,7 +481,7 @@ def _apply_group_offloading_block_level(
         )

         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, next_group)
+            _apply_group_offloading_hook(group_module, group, next_group, unpin_after_use)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
```
```diff
@@ -452,7 +508,7 @@ def _apply_group_offloading_block_level(
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
-    _apply_group_offloading_hook(module, unmatched_group, next_group)
+    _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)


 def _apply_group_offloading_leaf_level(
```
```diff
@@ -461,6 +517,7 @@ def _apply_group_offloading_leaf_level(
     onload_device: torch.device,
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
```
```diff
@@ -483,12 +540,9 @@ def _apply_group_offloading_leaf_level(
         for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # With progressive pinning approach, we'll initialize an empty CPU parameter dict
+    # and pin memory only when needed by each group
+    cpu_param_dict = {} if stream is not None else None

     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
```
```diff
@@ -506,7 +560,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(submodule, group, None)
+        _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
```
```diff
@@ -551,7 +605,7 @@ def _apply_group_offloading_leaf_level(
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
-        _apply_group_offloading_hook(parent_module, group, None)
+        _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)

     if stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
```
```diff
@@ -577,13 +631,14 @@ def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
     next_group: Optional[ModuleGroup] = None,
+    unpin_after_use: bool = False,
 ) -> None:
     registry = HookRegistry.check_if_exists_or_initialize(module)

     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group)
+        hook = GroupOffloadingHook(group, next_group, unpin_after_use)
         registry.register_hook(hook, _GROUP_OFFLOADING)

```