Skip to content

Commit 1d38a30

Browse files
committed
update
1 parent 1359348 commit 1d38a30

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import hashlib
1516
import os
1617
from contextlib import contextmanager, nullcontext
1718
from typing import Dict, List, Optional, Set, Tuple, Union
@@ -80,8 +81,6 @@ def __init__(
8081
self._is_offloaded_to_disk = False
8182

8283
if self.offload_to_disk_path:
83-
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
84-
8584
all_tensors = []
8685
for module in self.modules:
8786
all_tensors.extend(list(module.parameters()))
@@ -92,13 +91,24 @@ def __init__(
9291

9392
self.tensor_to_key = {tensor: f"tensor_{i}" for i, tensor in enumerate(all_tensors)}
9493
self.key_to_tensor = {v: k for k, v in self.tensor_to_key.items()}
94+
95+
keys_str = "_".join(sorted(self.key_to_tensor.keys()))
96+
self._disk_offload_group_id = hashlib.md5(keys_str.encode()).hexdigest()[:8]
97+
9598
self.cpu_param_dict = {}
9699
else:
97100
self.cpu_param_dict = self._init_cpu_param_dict()
98101

99102
if self.stream is None and self.record_stream:
100103
raise ValueError("`record_stream` cannot be True when `stream` is None.")
101104

105+
@property
106+
def _disk_offload_file_path(self):
107+
if self.offload_to_disk_path:
108+
return os.path.join(self.offload_to_disk_path, f"group_{self._disk_offload_group_id}.safetensors")
109+
110+
return None
111+
102112
def _init_cpu_param_dict(self):
103113
cpu_param_dict = {}
104114
if self.stream is None:
@@ -159,7 +169,7 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
159169

160170
def _onload_from_disk(self, current_stream):
161171
if self.stream is not None:
162-
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
172+
loaded_cpu_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device="cpu")
163173

164174
for key, tensor_obj in self.key_to_tensor.items():
165175
self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
@@ -174,7 +184,7 @@ def _onload_from_disk(self, current_stream):
174184
onload_device = (
175185
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
176186
)
177-
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
187+
loaded_tensors = safetensors.torch.load_file(self._disk_offload_file_path, device=onload_device)
178188
for key, tensor_obj in self.key_to_tensor.items():
179189
tensor_obj.data = loaded_tensors[key]
180190

@@ -187,6 +197,11 @@ def _onload_from_memory(self, current_stream):
187197

188198
@torch.compiler.disable()
189199
def onload_(self):
200+
# Generate disk offload group ID if needed and not already set
201+
if self.offload_to_disk_path and self._disk_offload_group_id is None:
202+
keys_str = "_".join(sorted(self.key_to_tensor.keys()))
203+
self._disk_offload_group_id = hashlib.md5(keys_str.encode()).hexdigest()[:8]
204+
190205
torch_accelerator_module = (
191206
getattr(torch, torch.accelerator.current_accelerator().type)
192207
if hasattr(torch, "accelerator")

0 commit comments

Comments
 (0)