1 file changed: +9 -2 lines changed
@@ -216,13 +216,20 @@ def onload_(self):
     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
         if self.offload_to_disk_path:
-            if not self._is_offloaded_to_disk:
+            # TODO: we can potentially optimize this code path by checking if _all_ the desired
+            # safetensor files exist on the disk and, if so, skip this step entirely, reducing IO
+            # overhead. Currently, we just check whether the given `safetensors_file_path` exists
+            # and, if not, we perform a write.
+            # Check if the file has been saved in this session or if it already exists on disk.
+            if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
                 os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
                 tensors_to_save = {
                     key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
                 }
                 safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
-                self._is_offloaded_to_disk = True
+
+            # The group is now considered offloaded to disk for the rest of the session.
+            self._is_offloaded_to_disk = True

             for tensor_obj in self.tensor_to_key.keys():
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
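
As a side note, the TODO above suggests a further optimization: skip the write step entirely when every expected safetensors file is already on disk. Below is a minimal sketch of that idea, not code from this PR; the helper name and the explicit list of expected file paths are assumptions for illustration.

import os

def all_group_files_exist(safetensors_file_paths):
    # Hypothetical helper (name and argument assumed, not part of the PR):
    # returns True only when every safetensors file the offloaded groups expect
    # is already present on disk, so a caller could skip the save step and its
    # per-file IO checks in a single pass.
    return all(os.path.exists(path) for path in safetensors_file_paths)

# Example usage with hypothetical paths:
# if all_group_files_exist(["offload/block_0.safetensors", "offload/block_1.safetensors"]):
#     pass  # nothing to write; the groups are already offloaded to disk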