
Commit d6392b4

update
1 parent 1475026 commit d6392b4

File tree

1 file changed (+25 −15 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 25 additions & 15 deletions
@@ -35,36 +35,37 @@ class PinnedGroupManager:
     Only keeps up to max_pinned_groups pinned at any given time, unpinning oldest
     ones when the limit is reached.
     """
+
     def __init__(self, max_pinned_groups: Optional[int] = None):
         self.max_pinned_groups = max_pinned_groups
         self.pinned_groups = []
-
+
     def register_group(self, group: "ModuleGroup") -> None:
         """Register a group with the manager"""
         if self.max_pinned_groups is not None:
             group._pinned_group_manager = self
-
+
     def on_group_pinned(self, group: "ModuleGroup") -> None:
         """Called when a group is pinned, manages the sliding window"""
         if self.max_pinned_groups is None:
             return
-
+
         # Add the newly pinned group to our tracking list
         if group not in self.pinned_groups:
             self.pinned_groups.append(group)
-
+
         # If we've exceeded our limit, unpin the oldest group(s)
         while len(self.pinned_groups) > self.max_pinned_groups:
             oldest_group = self.pinned_groups.pop(0)
             # Only unpin if it's not the group we just pinned
             if oldest_group != group and oldest_group.pinned_memory:
                 oldest_group.unpin_memory_()
-
+
     def on_group_unpinned(self, group: "ModuleGroup") -> None:
         """Called when a group is manually unpinned"""
         if self.max_pinned_groups is None:
             return
-
+
         # Remove the group from our tracking list
         if group in self.pinned_groups:
             self.pinned_groups.remove(group)
@@ -138,7 +139,7 @@ def pin_memory_(self):
             self.cpu_param_dict[param] = pinned_data
             param.data = pinned_data
         self.pinned_memory = True
-
+
         # Notify the manager that this group has been pinned
         if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
             self._pinned_group_manager.on_group_pinned(self)
@@ -160,7 +161,7 @@ def unpin_memory_(self):
         # Clear the CPU param dict
         self.cpu_param_dict = {}
         self.pinned_memory = False
-
+
         # Notify the manager that this group has been unpinned
         if hasattr(self, "_pinned_group_manager") and self._pinned_group_manager is not None:
             self._pinned_group_manager.on_group_unpinned(self)
@@ -170,7 +171,7 @@ def onload_(self):
         # Pin memory before onloading
         if self.stream is not None and not self.pinned_memory:
             self.pin_memory_()
-
+
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
         if self.stream is not None:
             # Wait for previous Host->Device transfer to complete
@@ -202,7 +203,7 @@ def offload_(self):
         if self.buffers is not None:
             for buffer in self.buffers:
                 buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
-
+
         # After offloading, we can unpin the memory if configured to do so
         # We'll keep it pinned by default for better performance
@@ -229,6 +230,10 @@ def __init__(
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
+            # Make sure we pin memory first (if using streams) before offloading
+            if self.group.stream is not None and not self.group.pinned_memory:
+                self.group.pin_memory_()
+            # Now it's safe to offload
             self.group.offload_()
         return module
 
@@ -459,7 +464,7 @@ def apply_group_offloading(
         raise ValueError("Using streams for data transfer requires a CUDA device.")
 
     _raise_error_if_accelerate_model_or_sequential_hook_present(module)
-
+
     # Create a pinned group manager if max_pinned_groups is set
     pinned_group_manager = None
     if max_pinned_groups is not None and stream is not None:
@@ -471,13 +476,18 @@ def apply_group_offloading(
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
 
         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking,
-            stream, unpin_after_use, pinned_group_manager
+            module,
+            num_blocks_per_group,
+            offload_device,
+            onload_device,
+            non_blocking,
+            stream,
+            unpin_after_use,
+            pinned_group_manager,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking,
-            stream, unpin_after_use, pinned_group_manager
+            module, offload_device, onload_device, non_blocking, stream, unpin_after_use, pinned_group_manager
         )
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")
