Commit e74b782

update

1 parent d6392b4 commit e74b782

File tree

1 file changed (+22 -102 lines changed)

src/diffusers/hooks/group_offloading.py

Lines changed: 22 additions & 102 deletions
@@ -29,46 +29,7 @@
 logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class PinnedGroupManager:
-    """
-    Manages a sliding window of pinned module groups to limit total CPU memory usage.
-    Only keeps up to max_pinned_groups pinned at any given time, unpinning oldest
-    ones when the limit is reached.
-    """
-
-    def __init__(self, max_pinned_groups: Optional[int] = None):
-        self.max_pinned_groups = max_pinned_groups
-        self.pinned_groups = []
-
-    def register_group(self, group: "ModuleGroup") -> None:
-        """Register a group with the manager"""
-        if self.max_pinned_groups is not None:
-            group._pinned_group_manager = self
-
-    def on_group_pinned(self, group: "ModuleGroup") -> None:
-        """Called when a group is pinned, manages the sliding window"""
-        if self.max_pinned_groups is None:
-            return
-
-        # Add the newly pinned group to our tracking list
-        if group not in self.pinned_groups:
-            self.pinned_groups.append(group)
-
-        # If we've exceeded our limit, unpin the oldest group(s)
-        while len(self.pinned_groups) > self.max_pinned_groups:
-            oldest_group = self.pinned_groups.pop(0)
-            # Only unpin if it's not the group we just pinned
-            if oldest_group != group and oldest_group.pinned_memory:
-                oldest_group.unpin_memory_()
-
-    def on_group_unpinned(self, group: "ModuleGroup") -> None:
-        """Called when a group is manually unpinned"""
-        if self.max_pinned_groups is None:
-            return
-
-        # Remove the group from our tracking list
-        if group in self.pinned_groups:
-            self.pinned_groups.remove(group)
+# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
 
 
 # fmt: off
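The removed class amounted to a FIFO sliding window over pinned groups: pin a new group, and once the window overflows, the oldest still-pinned group is evicted. A minimal standalone sketch of that eviction policy, using hypothetical `Group` and `SlidingWindowPinManager` names rather than the real `ModuleGroup`/`PinnedGroupManager` classes:

```python
# Hypothetical sketch of the sliding-window eviction the removed class did;
# only the bookkeeping is shown, with pin/unpin reduced to a boolean flag.
from collections import deque


class Group:
    def __init__(self, name: str):
        self.name = name
        self.pinned = False

    def unpin(self) -> None:
        # The real code re-materialized parameters in pageable memory here.
        self.pinned = False


class SlidingWindowPinManager:
    def __init__(self, max_pinned: int):
        self.max_pinned = max_pinned
        self.window = deque()  # FIFO of currently pinned groups

    def on_pinned(self, group: Group) -> None:
        group.pinned = True
        if group not in self.window:
            self.window.append(group)
        # Evict the oldest groups once the window overflows.
        while len(self.window) > self.max_pinned:
            oldest = self.window.popleft()
            if oldest is not group and oldest.pinned:
                oldest.unpin()


manager = SlidingWindowPinManager(max_pinned=2)
for g in [Group(f"g{i}") for i in range(4)]:
    manager.on_pinned(g)
print([g.name for g in manager.window if g.pinned])  # ['g2', 'g3']
```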
@@ -100,7 +61,6 @@ def __init__(
         stream: Optional[torch.cuda.Stream] = None,
         cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
-        pinned_group_manager: Optional["PinnedGroupManager"] = None,
     ) -> None:
         self.modules = modules
         self.offload_device = offload_device
@@ -113,58 +73,38 @@
         self.stream = stream
         self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self
-        # Track if we've pinned our group's memory
-        self.pinned_memory = False
-        # Reference to the pinned group manager for sliding window functionality
-        self._pinned_group_manager = pinned_group_manager
+        # We still track if we've prepared the CPU dict for compatibility
+        self.cpu_dict_prepared = False
 
         if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
+            # Now we'll create the dict on demand
+            self.cpu_param_dict = {}
 
     def pin_memory_(self):
-        r"""Pin the memory of this group's parameters for faster transfer."""
-        if self.stream is not None and not self.pinned_memory:
-            # Create the pinned memory dict just for this group's parameters
-            self.cpu_param_dict = {}
+        r"""Prepare the CPU parameter dictionary for reference (no pinning)."""
+        if self.stream is not None and not self.cpu_dict_prepared:
+            # Create a reference-only CPU parameter dict without pinning
             for group_module in self.modules:
                 for param in group_module.parameters():
                     if param.device == self.offload_device:
-                        pinned_data = param.data.pin_memory()
-                        self.cpu_param_dict[param] = pinned_data
-                        param.data = pinned_data
+                        # Store a reference without pinning
+                        self.cpu_param_dict[param] = param.data
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device == self.offload_device:
-                        pinned_data = param.data.pin_memory()
-                        self.cpu_param_dict[param] = pinned_data
-                        param.data = pinned_data
+                        # Store a reference without pinning
+                        self.cpu_param_dict[param] = param.data
+            self.cpu_dict_prepared = True
+            # For API compatibility
             self.pinned_memory = True
 
-            # Notify the manager that this group has been pinned
-            if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
-                self._pinned_group_manager.on_group_pinned(self)
-
     def unpin_memory_(self):
-        r"""Unpin the memory of this group's parameters to free up CPU RAM."""
+        r"""No-op method kept for API compatibility."""
+        # No need to unpin since we're not pinning memory anymore
         if self.stream is not None and self.pinned_memory:
-            # Only unpin if we're currently on CPU (i.e., offloaded)
-            for group_module in self.modules:
-                if not any(p.device == self.onload_device for p in group_module.parameters()):
-                    for param in group_module.parameters():
-                        if param in self.cpu_param_dict and param.device == self.offload_device:
-                            # Create a new non-pinned copy and replace
-                            param.data = param.data.clone()
-            if self.parameters is not None:
-                for param in self.parameters:
-                    if param in self.cpu_param_dict and param.device == self.offload_device:
-                        param.data = param.data.clone()
-            # Clear the CPU param dict
-            self.cpu_param_dict = {}
+            # Just mark as unpinned for compatibility
             self.pinned_memory = False
-
-            # Notify the manager that this group has been unpinned
-            if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
-                self._pinned_group_manager.on_group_unpinned(self)
+            self.cpu_dict_prepared = False
 
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
@@ -378,7 +318,6 @@ def apply_group_offloading(
     non_blocking: bool = False,
     use_stream: bool = False,
     unpin_after_use: bool = False,
-    max_pinned_groups: Optional[int] = None,
 ) -> None:
     r"""
     Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -426,13 +365,7 @@
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
             overlapping computation and data transfer.
         unpin_after_use (`bool`, defaults to `False`):
-            If True, pinned memory is released after a module group has been offloaded to CPU. This reduces CPU memory
-            usage at the cost of potentially slower data transfer when the group is loaded again. Useful for large models
-            with limited CPU RAM.
-        max_pinned_groups (`int`, *optional*):
-            If set, limits the number of module groups that can have pinned memory at any given time. When this limit
-            is reached, the oldest pinned groups are unpinned to make room for new ones. This implements a sliding window
-            approach to manage CPU memory usage while maintaining good performance for the most recently used groups.
+            Legacy parameter kept for API compatibility. Has no effect as we no longer use pinned memory.
 
     Example:
         ```python
@@ -450,8 +383,7 @@
         ...     offload_type="block_level",
         ...     num_blocks_per_group=2,
         ...     use_stream=True,
-        ...     unpin_after_use=False,  # Set to True to reduce CPU memory usage
-        ...     max_pinned_groups=5,  # Limit to 5 pinned groups at a time
+        ...     unpin_after_use=False,  # Legacy parameter, no effect
         ... )
         ```
     """
@@ -465,11 +397,7 @@
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
 
-    # Create a pinned group manager if max_pinned_groups is set
-    pinned_group_manager = None
-    if max_pinned_groups is not None and stream is not None:
-        pinned_group_manager = PinnedGroupManager(max_pinned_groups)
-        logger.info(f"Using sliding window approach with maximum of {max_pinned_groups} pinned groups")
+    # We no longer need a pinned group manager as we're not using pinned memory
 
     if offload_type == "block_level":
         if num_blocks_per_group is None:
@@ -483,11 +411,10 @@
             non_blocking,
             stream,
             unpin_after_use,
-            pinned_group_manager,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream, unpin_after_use, pinned_group_manager
+            module, offload_device, onload_device, non_blocking, stream, unpin_after_use
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -501,7 +428,6 @@ def _apply_group_offloading_block_level(
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
     unpin_after_use: bool = False,
-    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -548,7 +474,6 @@
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
-            pinned_group_manager=pinned_group_manager,
         )
         matched_module_groups.append(group)
         for j in range(i, i + len(current_modules)):
@@ -586,7 +511,6 @@
         stream=None,
         cpu_param_dict=None,
         onload_self=True,
-        pinned_group_manager=pinned_group_manager,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
     _apply_group_offloading_hook(module, unmatched_group, next_group, unpin_after_use)
@@ -599,7 +523,6 @@ def _apply_group_offloading_leaf_level(
     non_blocking: bool,
     stream: Optional[torch.cuda.Stream] = None,
     unpin_after_use: bool = False,
-    pinned_group_manager: Optional[PinnedGroupManager] = None,
 ) -> None:
     r"""
     This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -641,7 +564,6 @@
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
-            pinned_group_manager=pinned_group_manager,
         )
         _apply_group_offloading_hook(submodule, group, None, unpin_after_use)
         modules_with_group_offloading.add(name)
@@ -687,7 +609,6 @@
             stream=stream,
             cpu_param_dict=cpu_param_dict,
             onload_self=True,
-            pinned_group_manager=pinned_group_manager,
         )
         _apply_group_offloading_hook(parent_module, group, None, unpin_after_use)
 
@@ -707,7 +628,6 @@
             stream=None,
             cpu_param_dict=None,
             onload_self=True,
-            pinned_group_manager=pinned_group_manager,
         )
     _apply_lazy_group_offloading_hook(module, unmatched_group, None)
 