Skip to content

Commit 720be2b

Browse files
committed
update
1 parent e74b782 commit 720be2b

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def unpin_memory_(self):
108108

109109
def onload_(self):
110110
r"""Onloads the group of modules to the onload_device."""
111-
# Pin memory before onloading
112-
if self.stream is not None and not self.pinned_memory:
113-
self.pin_memory_()
111+
# Prepare CPU dict before onloading
112+
if self.stream is not None and not self.cpu_dict_prepared:
113+
self.pin_memory_() # This now just prepares the CPU dict
114114

115115
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
116116
if self.stream is not None:
@@ -170,9 +170,9 @@ def __init__(
170170

171171
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
172172
if self.group.offload_leader == module:
173-
# Make sure we pin memory first (if using streams) before offloading
174-
if self.group.stream is not None and not self.group.pinned_memory:
175-
self.group.pin_memory_()
173+
# Make sure we prepare CPU dict first (if using streams) before offloading
174+
if self.group.stream is not None and not self.group.cpu_dict_prepared:
175+
self.group.pin_memory_() # This now just prepares the CPU dict
176176
# Now it's safe to offload
177177
self.group.offload_()
178178
return module
@@ -199,9 +199,8 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
199199
def post_forward(self, module: torch.nn.Module, output):
200200
if self.group.offload_leader == module:
201201
self.group.offload_()
202-
# After offloading, we can optionally unpin memory to free up CPU RAM
203-
# This is most useful for large models where CPU RAM is limited
204-
if self.unpin_after_use and self.group.pinned_memory:
202+
# This is now a no-op but kept for API compatibility
203+
if self.unpin_after_use and self.group.cpu_dict_prepared:
205204
self.group.unpin_memory_()
206205
return output
207206

0 commit comments

Comments
 (0)