
Commit 438905d

update
1 parent: 904f24d

File tree: 1 file changed (+84 additions, -28 deletions)

src/diffusers/hooks/group_offloading.py

Lines changed: 84 additions & 28 deletions
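
For context before the diff itself: the hooks modified in this file are normally attached to a model through diffusers' apply_group_offloading helper, which lives in this same module. The snippet below is an illustrative sketch based on the documented diffusers usage, not part of this commit; the checkpoint is arbitrary, and keyword names such as offload_type, num_blocks_per_group, and non_blocking are assumptions that may differ between diffusers versions.

import torch
from diffusers import AutoencoderKL
from diffusers.hooks import apply_group_offloading

# Any diffusers model (a torch.nn.Module) works here; a small VAE keeps the example cheap.
model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)

# Attach the onload/offload hooks whose internals the diff below modifies.
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",   # groups nn.ModuleList / nn.Sequential blocks
    num_blocks_per_group=2,
    non_blocking=True,            # corresponds to self.non_blocking inside the hooks
)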
@@ -82,38 +82,65 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            # Use direct per-parameter transfers rather than module-level transfers
-            # This gives us more control and potentially better memory management
-            for group_module in self.modules:
-                # Check if any parameter needs moving
-                if any(p.device != self.onload_device for p in group_module.parameters()):
-                    for param in group_module.parameters():
-                        if param.device != self.onload_device:
-                            # Use direct CUDA transfer for each parameter
-                            if self.onload_device.type == "cuda":
-                                param.data = param.data.cuda(self.onload_device.index,
-                                                             non_blocking=self.non_blocking)
+            # Use the most efficient module-level transfer when possible
+            # This approach mirrors how PyTorch handles full model transfers
+            if self.modules:
+                for group_module in self.modules:
+                    # Only onload if some parameters are not on the target device
+                    if any(p.device != self.onload_device for p in group_module.parameters()):
+                        try:
+                            # Try the most efficient approach using _apply
+                            if hasattr(group_module, "_apply"):
+                                # This is what module.to() uses internally
+                                def to_device(t):
+                                    if t.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            return t.cuda(self.onload_device.index,
+                                                          non_blocking=self.non_blocking)
+                                        else:
+                                            return t.to(self.onload_device,
+                                                        non_blocking=self.non_blocking)
+                                    return t
+
+                                # Apply to all tensors without unnecessary copies
+                                group_module._apply(to_device)
                             else:
-                                param.data = param.data.to(self.onload_device,
-                                                           non_blocking=self.non_blocking)
-
+                                # Fallback to direct parameter transfer
+                                for param in group_module.parameters():
+                                    if param.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            param.data = param.data.cuda(self.onload_device.index,
+                                                                         non_blocking=self.non_blocking)
+                                        else:
+                                            param.data = param.data.to(self.onload_device,
+                                                                       non_blocking=self.non_blocking)
+                        except Exception as e:
+                            # If optimization fails, fall back to direct parameter transfer
+                            logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
+                            for param in group_module.parameters():
+                                if param.device != self.onload_device:
+                                    if self.onload_device.type == "cuda":
+                                        param.data = param.data.cuda(self.onload_device.index,
+                                                                     non_blocking=self.non_blocking)
+                                    else:
+                                        param.data = param.data.to(self.onload_device,
+                                                                   non_blocking=self.non_blocking)
+
             # Handle explicit parameters
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device != self.onload_device:
-                        # Use direct CUDA transfer for each parameter
                         if self.onload_device.type == "cuda":
                             param.data = param.data.cuda(self.onload_device.index,
                                                          non_blocking=self.non_blocking)
                         else:
                             param.data = param.data.to(self.onload_device,
                                                        non_blocking=self.non_blocking)
-
+
             # Handle buffers
             if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device != self.onload_device:
-                        # Use direct CUDA transfer for each buffer
                         if self.onload_device.type == "cuda":
                             buffer.data = buffer.data.cuda(self.onload_device.index,
                                                            non_blocking=self.non_blocking)
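
The rewritten onload path above hinges on torch.nn.Module._apply, the private PyTorch hook that Module.to() itself uses: it walks a module's parameters, their gradients, and buffers recursively and replaces each tensor with whatever the callback returns. Below is a minimal standalone sketch of that pattern; the helper name onload_module is hypothetical and not part of the diff.

import torch
import torch.nn as nn


def onload_module(module: nn.Module, device: torch.device, non_blocking: bool = False) -> nn.Module:
    # Callback applied to every parameter and buffer tensor by _apply.
    def to_device(t: torch.Tensor) -> torch.Tensor:
        if t.device != device:
            if device.type == "cuda":
                return t.cuda(device.index, non_blocking=non_blocking)
            return t.to(device, non_blocking=non_blocking)
        return t

    # _apply is a private API, but it is what Module.to() calls internally.
    return module._apply(to_device)


if __name__ == "__main__":
    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 8))
    if torch.cuda.is_available():
        onload_module(block, torch.device("cuda", 0), non_blocking=True)
        print(next(block.parameters()).device)  # cuda:0

One caveat worth keeping in mind: non_blocking host-to-device copies only overlap with compute when the source CPU tensors live in pinned memory; otherwise the copy behaves like a synchronous transfer.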
@@ -133,29 +160,58 @@ def offload_(self):
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
-            # Instead of using to() on the whole module which might create copies,
-            # directly move each parameter's data to CPU with cpu() which uses
-            # the memory-optimized path
-            for group_module in self.modules:
-                # Check if any parameter needs moving
-                if any(p.device.type != "cpu" for p in group_module.parameters()):
-                    for param in group_module.parameters():
-                        if param.device.type != "cpu":
-                            # Use direct cpu() method which is more memory-efficient than to()
-                            param.data = param.data.cpu()
+            # For most memory-efficient CPU offloading, let's use a special approach
+            # that simulates a full model device transfer:
+            # 1. We'll minimize RAM usage by avoiding both unnecessary copies and
+            #    the accumulation of wasted memory over time
+
+            # First, for module groups, look for the highest-level module and offload at that level
+            if self.modules:
+                # For each root module in the group
+                for group_module in self.modules:
+                    # Only offload if some parameters are not on CPU
+                    if any(p.device.type != "cpu" for p in group_module.parameters()):
+                        try:
+                            # Try the lowest possible CPU memory approach - this works like model.to("cpu")
+                            # but at the module level
+                            if hasattr(group_module, "_apply"):
+                                # This internal PyTorch method is what to() uses but with less overhead
+                                def cpu_tensor(t):
+                                    if t.device.type != "cpu":
+                                        return t.cpu()
+                                    return t
+
+                                # Apply to all tensors in the module without unnecessary copies
+                                group_module._apply(cpu_tensor)
+                            else:
+                                # Fallback to the direct method
+                                for param in group_module.parameters():
+                                    if param.device.type != "cpu":
+                                        param.data = param.data.cpu()
+                        except Exception as e:
+                            # If for any reason the optimized approach fails, fall back to direct method
+                            logger.warning(f"Optimized CPU offloading failed: {e}, falling back to direct method")
+                            for param in group_module.parameters():
+                                if param.device.type != "cpu":
+                                    param.data = param.data.cpu()
 
             # Handle explicit parameters - move directly to CPU
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device.type != "cpu":
-                        # Direct CPU transfer with cpu() method
+                        # Direct CPU transfer
                         param.data = param.data.cpu()
 
             # Handle buffers - move directly to CPU
             if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device.type != "cpu":
                         buffer.data = buffer.data.cpu()
+
+            # Force garbage collection to clean up any released memory
+            import gc
+            gc.collect()
+
         else:
             # For non-CPU offloading, synchronize if using stream
             if self.stream is not None:
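
The offload hunk applies the same _apply pattern in the CPU direction, empties the CUDA cache up front, and runs gc.collect() at the end so released tensors are actually reclaimed. The sketch below condenses that round trip into one free function for illustration; offload_module is a hypothetical name and not part of the diff.

import gc

import torch
import torch.nn as nn


def offload_module(module: nn.Module) -> nn.Module:
    # Move every parameter and buffer tensor back to CPU via _apply.
    def cpu_tensor(t: torch.Tensor) -> torch.Tensor:
        if t.device.type != "cpu":
            return t.cpu()
        return t

    module._apply(cpu_tensor)

    # Drop lingering Python references, then ask the caching allocator to
    # return freed CUDA blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return module


if __name__ == "__main__":
    if torch.cuda.is_available():
        block = nn.Linear(1024, 1024).cuda()
        before = torch.cuda.memory_allocated()
        offload_module(block)
        print(f"{before} -> {torch.cuda.memory_allocated()} bytes allocated")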
