Skip to content

Commit 4854309

Browse files
committed
update
1 parent a9e9ef5 commit 4854309

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(
8181
self._is_offloaded_to_disk = False
8282

8383
if self.offload_to_disk_path:
84+
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
8485
all_tensors = []
8586
param_names = []
8687
for module in self.modules:
@@ -96,23 +97,13 @@ def __init__(
9697
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
9798
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
9899

99-
group_id_key = "_".join(sorted(param_names))
100-
self._disk_offload_group_id = hashlib.md5(group_id_key.encode()).hexdigest()[:8]
101-
102100
self.cpu_param_dict = {}
103101
else:
104102
self.cpu_param_dict = self._init_cpu_param_dict()
105103

106104
if self.stream is None and self.record_stream:
107105
raise ValueError("`record_stream` cannot be True when `stream` is None.")
108106

109-
@property
110-
def _disk_offload_file_path(self):
111-
if self.offload_to_disk_path:
112-
return os.path.join(self.offload_to_disk_path, f"group_{self._disk_offload_group_id}.safetensors")
113-
114-
return None
115-
116107
def _init_cpu_param_dict(self):
117108
cpu_param_dict = {}
118109
if self.stream is None:
@@ -173,7 +164,7 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
173164

174165
def _onload_from_disk(self, current_stream):
175166
if self.stream is not None:
176-
loaded_cpu_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device="cpu")
167+
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
177168

178169
for key, tensor_obj in self.key_to_tensor.items():
179170
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
@@ -188,7 +179,7 @@ def _onload_from_disk(self, current_stream):
188179
onload_device = (
189180
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
190181
)
191-
loaded_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device=onload_device)
182+
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
192183
for key, tensor_obj in self.key_to_tensor.items():
193184
tensor_obj.data = loaded_tensors[key]
194185

@@ -231,10 +222,10 @@ def _offload_to_disk(self):
231222
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
232223
# we perform a write.
233224
# Check if the file has been saved in this session or if it already exists on disk.
234-
if not self._is_offloaded_to_disk and not os.path.exists(self._disk_offload_file_path):
235-
os.makedirs(os.path.dirname(self._disk_offload_file_path), exist_ok=True)
225+
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
226+
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
236227
tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
237-
safetensors.torch.save_file(tensors_to_save, self._disk_offload_file_path)
228+
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
238229

239230
# The group is now considered offloaded to disk for the rest of the session.
240231
self._is_offloaded_to_disk = True

0 commit comments

Comments
 (0)