Commit e123bbc

Commit message: memmap
1 parent b3fa8c6 commit e123bbc

File tree: 1 file changed (+47, -13 lines)


src/diffusers/hooks/group_offloading.py

Lines changed: 47 additions & 13 deletions
@@ -37,6 +37,8 @@
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
 
+# Always use memory-efficient CPU offloading to minimize RAM usage
+
 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
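
For context, `_SUPPORTED_PYTORCH_LAYERS` is the allow-list of leaf layer types that group offloading knows how to handle, and the natural way to consume it is a plain `isinstance` check when deciding which modules can be grouped. The sketch below is illustrative only: the `is_offloadable` helper and the `torch.nn.Linear` entry are assumptions, not part of this diff.

```python
import torch

# Truncated stand-in for the tuple shown in the hunk above (entries assumed).
_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)


def is_offloadable(module: torch.nn.Module) -> bool:
    # Hypothetical helper: True if a leaf module is of a type the hook can group.
    return isinstance(module, _SUPPORTED_PYTORCH_LAYERS)


print(is_offloadable(torch.nn.Linear(8, 8)))  # True
print(is_offloadable(torch.nn.LayerNorm(8)))  # False with this truncated tuple
```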
@@ -80,30 +82,61 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
+            # Only transfer parameters that aren't already on the target device
             for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+                if any(p.device != self.onload_device for p in group_module.parameters()):
+                    group_module.to(self.onload_device, non_blocking=self.non_blocking)
+
             if self.parameters is not None:
                 for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if param.device != self.onload_device:
+                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+
             if self.buffers is not None:
                 for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if buffer.device != self.onload_device:
+                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
 
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         # Synchronize if using stream
         if self.stream is not None:
             torch.cuda.current_stream().synchronize()
 
-        # Use regular to() method for all cases - much simpler!
-        for group_module in self.modules:
-            group_module.to(self.offload_device, non_blocking=self.non_blocking)
-        if self.parameters is not None:
-            for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
-        if self.buffers is not None:
-            for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+        # For CPU offloading, use a method that preserves memory mapping benefits
+        if self.offload_device.type == 'cpu':
+            # Empty GPU cache before offloading to reduce memory fragmentation
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            # Use to() method directly on modules
+            for group_module in self.modules:
+                # Don't make copies if already on CPU
+                if any(p.device.type != 'cpu' for p in group_module.parameters()):
+                    group_module.to(self.offload_device, non_blocking=self.non_blocking)
+
+            # Handle explicit parameters - avoid copies when already on CPU
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param.device.type != 'cpu':
+                        # Let PyTorch handle the transfer which can preserve memory mapping
+                        param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+
+            # Handle buffers - avoid copies when already on CPU
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    if buffer.device.type != 'cpu':
+                        buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+        else:
+            # For non-CPU offloading, use the regular approach
+            for group_module in self.modules:
+                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+            if self.parameters is not None:
+                for param in self.parameters:
+                    param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
         # After offloading, we can unpin the memory if configured to do so
        # We'll keep it pinned by default for better performance
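
Both `onload_` and `offload_` now apply the same guard: check a tensor's (or a module's parameters') device before calling `.to()`, and skip the transfer when nothing needs to move. The commit's own comments tie this to preserving memory-mapped ("memmap") CPU storage and avoiding unnecessary copies. Below is a minimal, self-contained sketch of that guard; the `move_if_needed` helper is illustrative and not part of the diff.

```python
import torch


def move_if_needed(t: torch.Tensor, device: torch.device, non_blocking: bool = False) -> torch.Tensor:
    # Same idea as the guards above: only trigger a transfer when the tensor is
    # not already on the target device. For weights loaded with mmap-backed CPU
    # storage, never re-copying them "to CPU" leaves the original storage in place.
    if t.device != device:
        return t.to(device, non_blocking=non_blocking)
    return t


cpu = torch.device("cpu")
w = torch.randn(4, 4)                # already on CPU
assert move_if_needed(w, cpu) is w   # guard short-circuits; no transfer attempted

if torch.cuda.is_available():
    g = w.to("cuda")
    back = move_if_needed(g, cpu)    # device differs, so a real copy happens
    assert back.device.type == "cpu"
```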
@@ -314,7 +347,8 @@ def apply_group_offloading(
             If True, offloading and onloading is done with non-blocking data transfer.
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
-            overlapping computation and data transfer.
+            overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
+            to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
 
     Example:
         ```python
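
Because the docstring's own example is truncated in this diff, here is a hedged usage sketch of `apply_group_offloading`. Only `non_blocking` and `use_stream` appear in the hunk above; `onload_device`, `offload_device`, and `offload_type="leaf_level"` are assumed from the upstream diffusers signature, and the tiny `Sequential` model is a stand-in. It assumes a CUDA machine.

```python
import torch
from diffusers.hooks.group_offloading import apply_group_offloading

# Tiny stand-in model; any torch.nn.Module built from supported leaf layers would do.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.Linear(64, 64),
)

# Assumed call shape (upstream diffusers API); only non_blocking/use_stream are
# documented in the hunk above.
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    use_stream=True,       # overlap compute and Host<->Device transfers
    non_blocking=True,     # pair with use_stream for asynchronous copies
)

# Each supported leaf layer is onloaded to the GPU just before it runs and
# offloaded back to the CPU afterwards; per this commit, tensors already
# resident on the CPU are not copied again during offload.
with torch.no_grad():
    out = model(torch.randn(1, 64, device="cuda"))
```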
