
Commit 4cfda51 ("update"), 1 parent: 85a916b

1 file changed: +82 -58 lines

src/diffusers/hooks/group_offloading.py: 82 additions & 58 deletions
```diff
@@ -135,9 +135,58 @@ def _pinned_memory_tensors(self):
         finally:
             pinned_dict = None
 
+    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+        if self.record_stream and current_stream is not None:
+            tensor.data.record_stream(current_stream)
+
+    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+        for group_module in self.modules:
+            for param in group_module.parameters():
+                source = pinned_memory[param] if pinned_memory else param.data
+                self._transfer_tensor_to_device(param, source, current_stream)
+            for buffer in group_module.buffers():
+                source = pinned_memory[buffer] if pinned_memory else buffer.data
+                self._transfer_tensor_to_device(buffer, source, current_stream)
+
+        for param in self.parameters:
+            source = pinned_memory[param] if pinned_memory else param.data
+            self._transfer_tensor_to_device(param, source, current_stream)
+
+        for buffer in self.buffers:
+            source = pinned_memory[buffer] if pinned_memory else buffer.data
+            self._transfer_tensor_to_device(buffer, source, current_stream)
+
+    def _onload_from_disk(self, current_stream):
+        if self.stream is not None:
+            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
+
+            for key, tensor_obj in self.key_to_tensor.items():
+                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
+
+            with self._pinned_memory_tensors() as pinned_memory:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
+
+            self.cpu_param_dict.clear()
+
+        else:
+            onload_device = (
+                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+            )
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+            for key, tensor_obj in self.key_to_tensor.items():
+                tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self, current_stream):
+        if self.stream is not None:
+            with self._pinned_memory_tensors() as pinned_memory:
+                self._process_tensors_from_modules(pinned_memory, current_stream)
+        else:
+            self._process_tensors_from_modules(None, current_stream)
+
     @torch.compiler.disable()
     def onload_(self):
-        r"""Onloads the group of modules to the onload_device."""
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
```
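For reference, the transfer pattern these new helpers factor out (`_transfer_tensor_to_device` in particular) is the usual PyTorch overlap recipe: an asynchronous host-to-device copy from pinned memory, optionally followed by `Tensor.record_stream` so the caching allocator knows the result is consumed on another stream. A minimal, self-contained sketch, assuming a CUDA device (the diffusers code resolves the accelerator module dynamically):

```python
import torch

# Sketch of the async onload pattern, assuming CUDA is available.
if torch.cuda.is_available():
    host = torch.randn(1024, 1024).pin_memory()  # pinned source: required for a truly async copy
    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.current_stream()

    with torch.cuda.stream(transfer_stream):
        # The copy is issued on transfer_stream and returns immediately.
        device_tensor = host.to("cuda", non_blocking=True)
        # Mark the tensor as in use on the compute stream so the caching
        # allocator does not recycle its memory while the copy is in flight.
        device_tensor.record_stream(compute_stream)

    compute_stream.wait_stream(transfer_stream)  # order subsequent compute after the copy
```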
```diff
@@ -175,67 +224,32 @@ def onload_(self):
             self.stream.synchronize()
 
         with context:
-            if self.stream is not None:
-                with self._pinned_memory_tensors() as pinned_memory:
-                    for group_module in self.modules:
-                        for param in group_module.parameters():
-                            param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                param.data.record_stream(current_stream)
-                        for buffer in group_module.buffers():
-                            buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                            if self.record_stream:
-                                buffer.data.record_stream(current_stream)
-
-                    for param in self.parameters:
-                        param.data = pinned_memory[param].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            param.data.record_stream(current_stream)
-
-                    for buffer in self.buffers:
-                        buffer.data = pinned_memory[buffer].to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            buffer.data.record_stream(current_stream)
-
+            if self.offload_to_disk_path:
+                self._onload_from_disk(current_stream)
             else:
-                for group_module in self.modules:
-                    for param in group_module.parameters():
-                        param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    for buffer in group_module.buffers():
-                        buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
-
-                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        buffer.data.record_stream(current_stream)
+                self._onload_from_memory(current_stream)
 
     @torch.compiler.disable()
-    def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
-        if self.offload_to_disk_path:
-            # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-            # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-            # we perform a write.
-            # Check if the file has been saved in this session or if it already exists on disk.
-            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
-                os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
-                tensors_to_save = {
-                    key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
-                }
-                safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-
-            # The group is now considered offloaded to disk for the rest of the session.
-            self._is_offloaded_to_disk = True
-
-            # We do this to free up the RAM which is still holding the up tensor data.
-            for tensor_obj in self.tensor_to_key.keys():
-                tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
-            return
+    def _offload_to_disk(self):
+        # TODO: we can potentially optimize this code path by checking if _all_ the desired
+        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
+        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
+        # we perform a write.
+        # Check if the file has been saved in this session or if it already exists on disk.
+        if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
+            os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
+            tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
+            safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
+
+        # The group is now considered offloaded to disk for the rest of the session.
+        self._is_offloaded_to_disk = True
+
+        # We do this to free up the RAM that is still holding the tensor data.
+        for tensor_obj in self.tensor_to_key.keys():
+            tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
+    @torch.compiler.disable()
+    def _offload_to_memory(self):
         torch_accelerator_module = (
             getattr(torch, torch.accelerator.current_accelerator().type)
             if hasattr(torch, "accelerator")
```
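The TODO that moves into `_offload_to_disk` describes a possible optimization: skip the save entirely when every expected safetensors file is already on disk. A hypothetical sketch of that check; the helper name and the `expected_paths` argument are illustrative, not part of this commit:

```python
import os

def _all_group_files_exist(expected_paths):
    # Hypothetical helper: returns True when every file this group would
    # write is already present, so the save (and its IO cost) can be skipped.
    return all(os.path.exists(path) for path in expected_paths)
```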
```diff
@@ -260,6 +274,14 @@ def offload_(self):
         for buffer in self.buffers:
             buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
 
+    @torch.compiler.disable()
+    def offload_(self):
+        r"""Offloads the group of modules to the offload_device."""
+        if self.offload_to_disk_path:
+            self._offload_to_disk()
+        else:
+            self._offload_to_memory()
+
 
 class GroupOffloadingHook(ModelHook):
     r"""
```
````diff
@@ -484,6 +506,8 @@ def apply_group_offloading(
             option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
             the CPU memory is a bottleneck but may counteract the benefits of using streams.
 
+    (TODO: include example with `offload_to_disk_path`)
+
     Example:
         ```python
         >>> from diffusers import CogVideoXTransformer3DModel
````
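As for the docstring TODO above, a hedged sketch of what an `offload_to_disk_path` example could look like, extrapolated from the existing docstring example; the exact keyword set should be checked against the final `apply_group_offloading` signature:

```python
>>> import torch
>>> from diffusers import CogVideoXTransformer3DModel
>>> from diffusers.hooks import apply_group_offloading

>>> transformer = CogVideoXTransformer3DModel.from_pretrained(
...     "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
... )
>>> apply_group_offloading(
...     transformer,
...     onload_device=torch.device("cuda"),
...     offload_device=torch.device("cpu"),
...     offload_type="block_level",
...     num_blocks_per_group=2,
...     offload_to_disk_path="offload_dir",  # groups are written here as .safetensors files
... )
```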
