Skip to content

Commit 4e4842f

Browse files
committed
check if safetensors already exist.
1 parent d8179b1 commit 4e4842f

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,20 @@ def onload_(self):
216216
def offload_(self):
217217
r"""Offloads the group of modules to the offload_device."""
218218
if self.offload_to_disk_path:
219-
if not self._is_offloaded_to_disk:
219+
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
220+
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
221+
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
222+
# we perform a write.
223+
# Check if the file has been saved in this session or if it already exists on disk.
224+
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
220225
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
221226
tensors_to_save = {
222227
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
223228
}
224229
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
225-
self._is_offloaded_to_disk = True
230+
231+
# The group is now considered offloaded to disk for the rest of the session.
232+
self._is_offloaded_to_disk = True
226233

227234
for tensor_obj in self.tensor_to_key.keys():
228235
tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

0 commit comments

Comments
 (0)