Commit 1475026

sliding-window

1 parent 878eb4c commit 1475026

File tree

1 file changed: +77 -2 lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 77 additions & 2 deletions
@@ -29,6 +29,47 @@
 logger = get_logger(__name__) # pylint: disable=invalid-name
 
 
+class PinnedGroupManager:
+    """
+    Manages a sliding window of pinned module groups to limit total CPU memory usage.
+    Only keeps up to max_pinned_groups pinned at any given time, unpinning oldest
+    ones when the limit is reached.
+    """
+    def __init__(self, max_pinned_groups: Optional[int] = None):
+        self.max_pinned_groups = max_pinned_groups
+        self.pinned_groups = []
+
+    def register_group(self, group: "ModuleGroup") -> None:
+        """Register a group with the manager"""
+        if self.max_pinned_groups is not None:
+            group._pinned_group_manager = self
+
+    def on_group_pinned(self, group: "ModuleGroup") -> None:
+        """Called when a group is pinned, manages the sliding window"""
+        if self.max_pinned_groups is None:
+            return
+
+        # Add the newly pinned group to our tracking list
+        if group not in self.pinned_groups:
+            self.pinned_groups.append(group)
+
+        # If we've exceeded our limit, unpin the oldest group(s)
+        while len(self.pinned_groups) > self.max_pinned_groups:
+            oldest_group = self.pinned_groups.pop(0)
+            # Only unpin if it's not the group we just pinned
+            if oldest_group != group and oldest_group.pinned_memory:
+                oldest_group.unpin_memory_()
+
+    def on_group_unpinned(self, group: "ModuleGroup") -> None:
+        """Called when a group is manually unpinned"""
+        if self.max_pinned_groups is None:
+            return
+
+        # Remove the group from our tracking list
+        if group in self.pinned_groups:
+            self.pinned_groups.remove(group)
+
+
 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -58,6 +99,7 @@ def __init__(
         stream: Optional[torch.cuda.Stream] = None,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
+        pinned_group_manager: Optional["PinnedGroupManager"] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -72,6 +114,8 @@
         self.onload_self = onload_self
         # Track if we've pinned our group's memory
         self.pinned_memory = False
+        # Reference to the pinned group manager for sliding window functionality
+        self._pinned_group_manager = pinned_group_manager
 
         if self.stream is not None and self.cpu_param_dict is None:
             raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
@@ -94,6 +138,10 @@ def pin_memory_(self):
             self.cpu_param_dict[param] = pinned_data
             param.data = pinned_data
         self.pinned_memory = True
+
+        # Notify the manager that this group has been pinned
+        if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
+            self._pinned_group_manager.on_group_pinned(self)
 
     def unpin_memory_(self):
         r"""Unpin the memory of this group's parameters to free up CPU RAM."""
@@ -112,6 +160,10 @@ def unpin_memory_(self):
         # Clear the CPU param dict
         self.cpu_param_dict = {}
         self.pinned_memory = False
+
+        # Notify the manager that this group has been unpinned
+        if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
+            self._pinned_group_manager.on_group_unpinned(self)
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -321,6 +373,7 @@ def apply_group_offloading(
     non_blocking: bool = False,
     use_stream: bool = False,
     unpin_after_use: bool = False,
+    max_pinned_groups: Optional[int] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -371,6 +424,10 @@ def apply_group_offloading(
             If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU memory
             usage at the cost of potentially slower data transfer when the group is loaded again. Useful for large models
             with limited CPU RAM.
+        max_pinned_groups (`int`, *optional*):
+            If set, limits the number of module groups that can have pinned memory at any given time. When this limit
+            is reached, the oldest pinned groups are unpinned to make room for new ones. This implements a sliding window
+            approach to manage CPU memory usage while maintaining good performance for the most recently used groups.
 
     Example:
         ```python
@@ -389,6 +446,7 @@ def apply_group_offloading(
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
         ...     unpin_after_use=False,  # Set to True to reduce CPU memory usage
+        ...     max_pinned_groups=5,  # Limit to 5 pinned groups at a time
         ... )
         ```
     """
@@ -401,16 +459,26 @@
         raise ValueError("Using streams for data transfer requires a CUDA device.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
+
+    # Create a pinned group manager if max_pinned_groups is set
+    pinned_group_manager = None
+    if max_pinned_groups is not None and stream is not None:
+        pinned_group_manager = PinnedGroupManager(max_pinned_groups)
+        logger.info(f"Using sliding window approach with maximum of {max_pinned_groups} pinned groups")
 
     if offload_type == "block_level":
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream, unpin_after_use
+            module, num_blocks_per_group, offload_device, onload_device, non_blocking,
+            stream, unpin_after_use, pinned_group_manager
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream, unpin_after_use)
+        _apply_group_offloading_leaf_level(
+            module, offload_device, onload_device, non_blocking,
+            stream, unpin_after_use, pinned_group_manager
+        )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -423,6 +491,7 @@ def _apply_group_offloading_block_level(
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
     unpin_after_use: bool = False,
+    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -469,6 +538,7 @@ def _apply_group_offloading_block_level(
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
+            pinned_group_manager=pinned_group_manager,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -506,6 +576,7 @@ def _apply_group_offloading_block_level(
         stream=None,
         cpu_param_dict=None,
         onload_self=True,
+        pinned_group_manager=pinned_group_manager,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
     _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
@@ -518,6 +589,7 @@ def _apply_group_offloading_leaf_level(
    non_blocking: bool,
    stream: Optional[torch.cuda.Stream] = None,
    unpin_after_use: bool = False,
+    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -559,6 +631,7 @@ def _apply_group_offloading_leaf_level(
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
+            pinned_group_manager=pinned_group_manager,
         )
         _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
         modules_with_group_offloading.add(name)
@@ -604,6 +677,7 @@ def _apply_group_offloading_leaf_level(
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
+            pinned_group_manager=pinned_group_manager,
         )
         _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)
 
@@ -623,6 +697,7 @@ def _apply_group_offloading_leaf_level(
         stream=None,
         cpu_param_dict=None,
         onload_self=True,
+        pinned_group_manager=pinned_group_manager,
     )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 
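Note on the eviction semantics: below is a minimal, self-contained sketch of the sliding-window behavior the new `PinnedGroupManager` implements. `FakeGroup` and `SlidingWindowSketch` are illustrative stand-ins for `ModuleGroup` and `PinnedGroupManager`, not names from this diff.

```python
from typing import List, Optional


class FakeGroup:
    """Stand-in for ModuleGroup with only the attributes the manager touches."""

    def __init__(self, name: str):
        self.name = name
        self.pinned_memory = False

    def pin_memory_(self, manager: "SlidingWindowSketch") -> None:
        self.pinned_memory = True
        manager.on_group_pinned(self)  # mirrors the notify call added to ModuleGroup.pin_memory_

    def unpin_memory_(self) -> None:
        self.pinned_memory = False


class SlidingWindowSketch:
    """Same eviction logic as PinnedGroupManager.on_group_pinned above."""

    def __init__(self, max_pinned_groups: Optional[int] = None):
        self.max_pinned_groups = max_pinned_groups
        self.pinned_groups: List[FakeGroup] = []

    def on_group_pinned(self, group: FakeGroup) -> None:
        if self.max_pinned_groups is None:
            return
        if group not in self.pinned_groups:
            self.pinned_groups.append(group)
        # FIFO eviction: unpin the oldest group(s) once the cap is exceeded
        while len(self.pinned_groups) > self.max_pinned_groups:
            oldest = self.pinned_groups.pop(0)
            if oldest is not group and oldest.pinned_memory:
                oldest.unpin_memory_()


manager = SlidingWindowSketch(max_pinned_groups=2)
for g in [FakeGroup(f"block_{i}") for i in range(4)]:
    g.pin_memory_(manager)
    print([x.name for x in manager.pinned_groups])
# ['block_0']
# ['block_0', 'block_1']
# ['block_1', 'block_2']  <- block_0 was unpinned to stay under the cap
# ['block_2', 'block_3']  <- block_1 was unpinned
```

Because `on_group_pinned` only appends a group the first time it is pinned, the window evicts in first-pin order (FIFO) rather than least-recently-used order: re-pinning a still-tracked group does not refresh its position.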

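A hedged usage sketch for the `leaf_level` path, which this diff threads the same manager through. It assumes a diffusers build containing this branch (`unpin_after_use` and `max_pinned_groups` are introduced here, not upstream arguments); the model choice is illustrative only.

```python
import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # num_blocks_per_group is only required for block_level
    use_stream=True,            # the manager is only created when a stream is in use
    max_pinned_groups=8,        # at most 8 leaf groups hold pinned CPU memory at once
)
```

Note that `max_pinned_groups` is silently ignored without `use_stream=True`, since `apply_group_offloading` only constructs the `PinnedGroupManager` when `stream is not None`.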