@@ -81,6 +81,7 @@ def __init__(
8181 self ._is_offloaded_to_disk = False
8282
8383 if self .offload_to_disk_path :
84+ self .safetensors_file_path = os .path .join (self .offload_to_disk_path , f"group_{ id (self )} .safetensors" )
8485 all_tensors = []
8586 param_names = []
8687 for module in self .modules :
@@ -96,23 +97,13 @@ def __init__(
9697 self .tensor_to_key = {tensor : f"tensor_{ i } " for i , tensor in enumerate (all_tensors )}
9798 self .key_to_tensor = {v : k for k , v in self .tensor_to_key .items ()}
9899
99- group_id_key = "_" .join (sorted (param_names ))
100- self ._disk_offload_group_id = hashlib .md5 (group_id_key .encode ()).hexdigest ()[:8 ]
101-
102100 self .cpu_param_dict = {}
103101 else :
104102 self .cpu_param_dict = self ._init_cpu_param_dict ()
105103
106104 if self .stream is None and self .record_stream :
107105 raise ValueError ("`record_stream` cannot be True when `stream` is None." )
108106
109- @property
110- def _disk_offload_file_path (self ):
111- if self .offload_to_disk_path :
112- return os .path .join (self .offload_to_disk_path , f"group_{ self ._disk_offload_group_id } .safetensors" )
113-
114- return None
115-
116107 def _init_cpu_param_dict (self ):
117108 cpu_param_dict = {}
118109 if self .stream is None :
@@ -173,7 +164,7 @@ def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None)
173164
174165 def _onload_from_disk (self , current_stream ):
175166 if self .stream is not None :
176- loaded_cpu_tensors = safetensors .torch .load_file (self ._disk_offload_file_path , device = "cpu" )
167+ loaded_cpu_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = "cpu" )
177168
178169 for key , tensor_obj in self .key_to_tensor .items ():
179170 self .cpu_param_dict [tensor_obj ] = loaded_cpu_tensors [key ]
@@ -188,7 +179,7 @@ def _onload_from_disk(self, current_stream):
188179 onload_device = (
189180 self .onload_device .type if isinstance (self .onload_device , torch .device ) else self .onload_device
190181 )
191- loaded_tensors = safetensors .torch .load_file (self ._disk_offload_file_path , device = onload_device )
182+ loaded_tensors = safetensors .torch .load_file (self .safetensors_file_path , device = onload_device )
192183 for key , tensor_obj in self .key_to_tensor .items ():
193184 tensor_obj .data = loaded_tensors [key ]
194185
@@ -231,10 +222,10 @@ def _offload_to_disk(self):
231222 # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
232223 # we perform a write.
233224 # Check if the file has been saved in this session or if it already exists on disk.
234- if not self ._is_offloaded_to_disk and not os .path .exists (self ._disk_offload_file_path ):
235- os .makedirs (os .path .dirname (self ._disk_offload_file_path ), exist_ok = True )
225+ if not self ._is_offloaded_to_disk and not os .path .exists (self .safetensors_file_path ):
226+ os .makedirs (os .path .dirname (self .safetensors_file_path ), exist_ok = True )
236227 tensors_to_save = {key : tensor .data .to (self .offload_device ) for tensor , key in self .tensor_to_key .items ()}
237- safetensors .torch .save_file (tensors_to_save , self ._disk_offload_file_path )
228+ safetensors .torch .save_file (tensors_to_save , self .safetensors_file_path )
238229
239230 # The group is now considered offloaded to disk for the rest of the session.
240231 self ._is_offloaded_to_disk = True
0 commit comments