1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import hashlib
1516import os
1617from contextlib import contextmanager , nullcontext
1718from typing import Dict , List , Optional , Set , Tuple , Union
@@ -80,8 +81,6 @@ def __init__(
8081 self ._is_offloaded_to_disk = False
8182
8283 if self .offload_to_disk_path :
83- self .safetensors_file_path = os .path .join (self .offload_to_disk_path , f"group_{ id (self )} .safetensors" )
84-
8584 all_tensors = []
8685 for module in self .modules :
8786 all_tensors .extend (list (module .parameters ()))
@@ -92,13 +91,24 @@ def __init__(
9291
9392 self .tensor_to_key = {tensor : f"tensor_{ i } " for i , tensor in enumerate (all_tensors )}
9493 self .key_to_tensor = {v : k for k , v in self .tensor_to_key .items ()}
94+
95+ keys_str = "_" .join (sorted (self .key_to_tensor .keys ()))
96+ self ._disk_offload_group_id = hashlib .md5 (keys_str .encode ()).hexdigest ()[:8 ]
97+
9598 self .cpu_param_dict = {}
9699 else :
97100 self .cpu_param_dict = self ._init_cpu_param_dict ()
98101
99102 if self .stream is None and self .record_stream :
100103 raise ValueError ("`record_stream` cannot be True when `stream` is None." )
101104
# Read-only property: location of the safetensors file used for this group's
# disk offload. Returns the joined path when `offload_to_disk_path` is truthy,
# otherwise returns None.
105+ @property
106+ def _disk_offload_file_path (self ):
# Only build a path when an offload directory has been configured.
107+ if self .offload_to_disk_path :
# The file lives directly inside the offload directory and is named after
# `_disk_offload_group_id`, with a `.safetensors` suffix.
108+ return os .path .join (self .offload_to_disk_path , f"group_{ self ._disk_offload_group_id } .safetensors" )
109+
# No offload directory configured, so there is no file path to report.
110+ return None
111+
102112 def _init_cpu_param_dict (self ):
103113 cpu_param_dict = {}
104114 if self .stream is None :
@@ -159,7 +169,7 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
159169
160170 def _onload_from_disk (self , current_stream ):
161171 if self .stream is not None :
162- loaded_cpu_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = "cpu" )
172+ loaded_cpu_tensors = safetensors .torch .load_file (self ._disk_offload_file_path , device = "cpu" )
163173
164174 for key , tensor_obj in self .key_to_tensor .items ():
165175 self .cpu_param_dict [tensor_obj ] = loaded_cpu_tensors [key ]
@@ -174,7 +184,7 @@ def _onload_from_disk(self, current_stream):
174184 onload_device = (
175185 self .onload_device .type if isinstance (self .onload_device , torch .device ) else self .onload_device
176186 )
177- loaded_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = onload_device )
187+ loaded_tensors = safetensors .torch .load_file (self ._disk_offload_file_path , device = onload_device )
178188 for key , tensor_obj in self .key_to_tensor .items ():
179189 tensor_obj .data = loaded_tensors [key ]
180190
@@ -187,6 +197,11 @@ def _onload_from_memory(self, current_stream):
187197
188198 @torch .compiler .disable ()
189199 def onload_ (self ):
200+ # Generate disk offload group ID if needed and not already set
201+ if self .offload_to_disk_path and self ._disk_offload_group_id is None :
202+ keys_str = "_" .join (sorted (self .key_to_tensor .keys ()))
203+ self ._disk_offload_group_id = hashlib .md5 (keys_str .encode ()).hexdigest ()[:8 ]
204+
190205 torch_accelerator_module = (
191206 getattr (torch , torch .accelerator .current_accelerator ().type )
192207 if hasattr (torch , "accelerator" )
0 commit comments