
Commit d9915a7

committed: update

1 parent b7a795d commit d9915a7

File tree: 1 file changed, +19 / -34 lines

src/diffusers/hooks/group_offloading.py

Lines changed: 19 additions & 34 deletions
@@ -150,7 +150,7 @@ def to_device(t):
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        # For CPU offloading, use the most memory-efficient approach possible
+        # For CPU offloading
         if self.offload_device.type == "cpu":
             # Synchronize if using stream
             if self.stream is not None:
@@ -160,53 +160,38 @@ def offload_(self):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
-            # For most memory-efficient CPU offloading, let's use a special approach
-            # that simulates a full model device transfer:
-            # 1. We'll minimize RAM usage by avoiding both unnecessary copies and
-            #    the accumulation of wasted memory over time
-
-            # First, for module groups, look for the highest-level module and offload at that level
+            # For module groups, use a single, unified approach that is closest to
+            # the behavior of model.to("cpu")
             if self.modules:
-                # For each root module in the group
                 for group_module in self.modules:
-                    # Only offload if some parameters are not on CPU
+                    # Check if we need to offload this module
                     if any(p.device.type != "cpu" for p in group_module.parameters()):
+                        # Use PyTorch's built-in to() method directly, which preserves
+                        # memory mapping when moving to CPU
                         try:
-                            # Try the lowest possible CPU memory approach - this works like model.to("cpu")
-                            # but at the module level
-                            if hasattr(group_module, "_apply"):
-                                # This internal PyTorch method is what to() uses but with less overhead
-                                def cpu_tensor(t):
-                                    if t.device.type != "cpu":
-                                        return t.cpu()
-                                    return t
-
-                                # Apply to all tensors in the module without unnecessary copies
-                                group_module._apply(cpu_tensor)
-                            else:
-                                # Fallback to the direct method
-                                for param in group_module.parameters():
-                                    if param.device.type != "cpu":
-                                        param.data = param.data.cpu()
+                            # Non-blocking=False for CPU transfers, as it ensures memory is
+                            # immediately available and potentially preserves memory mapping
+                            group_module.to("cpu", non_blocking=False)
                         except Exception as e:
-                            # If for any reason the optimized approach fails, fall back to direct method
-                            logger.warning(f"Optimized CPU offloading failed: {e}, falling back to direct method")
+                            # If there's any error, fall back to parameter-level offloading
+                            logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
                             for param in group_module.parameters():
                                 if param.device.type != "cpu":
-                                    param.data = param.data.cpu()
-
-            # Handle explicit parameters - move directly to CPU
+                                    param.data = param.data.to("cpu", non_blocking=False)
+
+            # Handle explicit parameters - move directly to CPU with non-blocking=False
+            # which can preserve memory mapping in some PyTorch versions
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device.type != "cpu":
-                        # Direct CPU transfer
-                        param.data = param.data.cpu()
+                        param.data = param.data.to("cpu", non_blocking=False)
 
-            # Handle buffers - move directly to CPU
+            # Handle buffers
            if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device.type != "cpu":
-                        buffer.data = buffer.data.cpu()
+                        buffer.data = buffer.data.to("cpu", non_blocking=False)
+
             # Let Python's normal reference counting handle cleanup
             # We don't force garbage collection to avoid slowing down inference
 
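
Aside (not part of the commit): the new code path leans on PyTorch's standard nn.Module.to("cpu"), which walks the module's parameters and buffers internally, and only falls back to per-parameter moves on error. A minimal sketch of the same pattern, assuming a toy nn.Linear stands in for a real group module:

# Minimal sketch: module-level offload with a parameter-level fallback.
# Assumption: nn.Linear(4, 4) is only a stand-in for a group's submodule.
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
if torch.cuda.is_available():
    module = module.to("cuda")

# Module-level offload, as in the new try-branch:
# to() moves every parameter and buffer of the module in one call.
if any(p.device.type != "cpu" for p in module.parameters()):
    module.to("cpu", non_blocking=False)

# Parameter-level fallback, as in the except-branch of the diff.
for param in module.parameters():
    if param.device.type != "cpu":
        param.data = param.data.to("cpu", non_blocking=False)

assert all(p.device.type == "cpu" for p in module.parameters())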
