
Commit 904f24d

update
1 parent e123bbc commit 904f24d

File tree

1 file changed: +58 -27 lines changed


src/diffusers/hooks/group_offloading.py

Lines changed: 58 additions & 27 deletions
@@ -82,52 +82,85 @@ def onload_(self):
             self.stream.synchronize()

         with context:
-            # Only transfer parameters that aren't already on the target device
+            # Use direct per-parameter transfers rather than module-level transfers
+            # This gives us more control and potentially better memory management
             for group_module in self.modules:
+                # Check if any parameter needs moving
                 if any(p.device != self.onload_device for p in group_module.parameters()):
-                    group_module.to(self.onload_device, non_blocking=self.non_blocking)
-
+                    for param in group_module.parameters():
+                        if param.device != self.onload_device:
+                            # Use direct CUDA transfer for each parameter
+                            if self.onload_device.type == "cuda":
+                                param.data = param.data.cuda(self.onload_device.index,
+                                                             non_blocking=self.non_blocking)
+                            else:
+                                param.data = param.data.to(self.onload_device,
+                                                           non_blocking=self.non_blocking)
+
+            # Handle explicit parameters
             if self.parameters is not None:
                 for param in self.parameters:
                     if param.device != self.onload_device:
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
+                        # Use direct CUDA transfer for each parameter
+                        if self.onload_device.type == "cuda":
+                            param.data = param.data.cuda(self.onload_device.index,
+                                                         non_blocking=self.non_blocking)
+                        else:
+                            param.data = param.data.to(self.onload_device,
+                                                       non_blocking=self.non_blocking)
+
+            # Handle buffers
             if self.buffers is not None:
                 for buffer in self.buffers:
                     if buffer.device != self.onload_device:
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                        # Use direct CUDA transfer for each buffer
+                        if self.onload_device.type == "cuda":
+                            buffer.data = buffer.data.cuda(self.onload_device.index,
+                                                           non_blocking=self.non_blocking)
+                        else:
+                            buffer.data = buffer.data.to(self.onload_device,
+                                                         non_blocking=self.non_blocking)

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        # Synchronize if using stream
-        if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
-
-        # For CPU offloading, use a method that preserves memory mapping benefits
-        if self.offload_device.type == 'cpu':
+        # For CPU offloading, use the most memory-efficient approach possible
+        if self.offload_device.type == "cpu":
+            # Synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
             # Empty GPU cache before offloading to reduce memory fragmentation
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()

-            # Use to() method directly on modules
+            # Instead of using to() on the whole module which might create copies,
+            # directly move each parameter's data to CPU with cpu() which uses
+            # the memory-optimized path
             for group_module in self.modules:
-                # Don't make copies if already on CPU
-                if any(p.device.type != 'cpu' for p in group_module.parameters()):
-                    group_module.to(self.offload_device, non_blocking=self.non_blocking)
-
-            # Handle explicit parameters - avoid copies when already on CPU
+                # Check if any parameter needs moving
+                if any(p.device.type != "cpu" for p in group_module.parameters()):
+                    for param in group_module.parameters():
+                        if param.device.type != "cpu":
+                            # Use direct cpu() method which is more memory-efficient than to()
+                            param.data = param.data.cpu()
+
+            # Handle explicit parameters - move directly to CPU
             if self.parameters is not None:
                 for param in self.parameters:
-                    if param.device.type != 'cpu':
-                        # Let PyTorch handle the transfer which can preserve memory mapping
-                        param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                    if param.device.type != "cpu":
+                        # Direct CPU transfer with cpu() method
+                        param.data = param.data.cpu()

-            # Handle buffers - avoid copies when already on CPU
+            # Handle buffers - move directly to CPU
             if self.buffers is not None:
                 for buffer in self.buffers:
-                    if buffer.device.type != 'cpu':
-                        buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                    if buffer.device.type != "cpu":
+                        buffer.data = buffer.data.cpu()
         else:
+            # For non-CPU offloading, synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
             # For non-CPU offloading, use the regular approach
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
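As a standalone illustration of the pattern this hunk adopts, the sketch below moves each parameter tensor individually with cuda()/cpu() instead of calling to() on the whole module. It is a minimal sketch, not code from diffusers; move_params is a hypothetical helper written only for this example.

import torch
import torch.nn as nn


def move_params(module: nn.Module, device: torch.device, non_blocking: bool = False) -> None:
    # Move each parameter's underlying tensor one at a time, skipping tensors
    # that already live on the target device (mirrors the per-parameter checks in the diff).
    for param in module.parameters():
        if param.device == device:
            continue
        if device.type == "cuda":
            # Direct per-tensor CUDA transfer, like the diff's cuda() onload path
            param.data = param.data.cuda(device.index, non_blocking=non_blocking)
        elif device.type == "cpu":
            # cpu() mirrors the diff's CPU offload path
            param.data = param.data.cpu()
        else:
            param.data = param.data.to(device, non_blocking=non_blocking)


if __name__ == "__main__":
    layer = nn.Linear(8, 8)
    if torch.cuda.is_available():
        move_params(layer, torch.device("cuda", 0), non_blocking=True)
        # Non-blocking copies are only safe to consume after the stream has finished
        torch.cuda.current_stream().synchronize()
        move_params(layer, torch.device("cpu"))
        # Release cached GPU blocks after offloading, as the diff does
        torch.cuda.empty_cache()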
@@ -394,9 +427,7 @@ def apply_group_offloading(
             stream,
         )
     elif offload_type == "leaf_level":
-        _apply_group_offloading_leaf_level(
-            module, offload_device, onload_device, non_blocking, stream
-        )
+        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
     else:
         raise ValueError(f"Unsupported offload_type: {offload_type}")

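For context, a hedged usage sketch of the public entry point touched by this hunk. The hunk shows apply_group_offloading dispatching on offload_type and forwarding module, offload_device, onload_device, non_blocking, and a stream to the leaf-level helper; the keyword names below follow that, but the exact public signature at this commit (in particular how the stream is requested) is not visible in the diff, so treat this as an approximation rather than the documented API.

import torch
from diffusers import AutoencoderKL
from diffusers.hooks.group_offloading import apply_group_offloading

# Example model; any torch.nn.Module would be handled the same way.
vae = AutoencoderKL()

apply_group_offloading(
    vae,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",  # the branch simplified in this hunk
    non_blocking=True,
)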
