2020import  torch 
2121
2222from  ..utils  import  get_logger , is_accelerate_available 
23+ from  ..utils .import_utils  import  is_deepspeed_available , is_deepspeed_version 
2324from  .hooks  import  HookRegistry , ModelHook 
2425
2526
2627if  is_accelerate_available ():
2728    from  accelerate .hooks  import  AlignDevicesHook , CpuOffload 
2829    from  accelerate .utils  import  send_to_device 
2930
31+ if  is_deepspeed_available () and  is_deepspeed_version (">=" , "0.16" ):
32+     from  ..utils .state_dict_utils  import  _fast_aio_save 
3033
3134logger  =  get_logger (__name__ )  # pylint: disable=invalid-name 
3235
@@ -62,6 +65,7 @@ def __init__(
6265        low_cpu_mem_usage : bool  =  False ,
6366        onload_self : bool  =  True ,
6467        offload_to_disk_path : Optional [str ] =  None ,
68+         _enable_deepnvme_disk_offloading : Optional [bool ] =  False 
6569    ) ->  None :
6670        self .modules  =  modules 
6771        self .offload_device  =  offload_device 
@@ -80,7 +84,9 @@ def __init__(
8084        self ._is_offloaded_to_disk  =  False 
8185
8286        if  self .offload_to_disk_path :
83-             self .safetensors_file_path  =  os .path .join (self .offload_to_disk_path , f"group_{ id (self )}  )
87+             self ._enable_deepnvme_disk_offloading  =  _enable_deepnvme_disk_offloading 
88+             ext  =  ".pt"  if  _enable_deepnvme_disk_offloading  else  ".safetensors" 
89+             self .param_file_path  =  os .path .join (self .offload_to_disk_path , f"group_{ id (self )} { ext }  )
8490
8591            all_tensors  =  []
8692            for  module  in  self .modules :
@@ -153,8 +159,8 @@ def onload_(self):
153159
154160            with  context :
155161                if  self .stream  is  not None :
156-                     # Load to CPU, pin, and async copy to device for overlapping transfer and compute 
157-                     loaded_cpu_tensors  =  safetensors .torch .load_file (self .safetensors_file_path , device = "cpu" )
162+                     # Load to CPU from disk , pin, and async copy to device for overlapping transfer and compute 
163+                     loaded_cpu_tensors  =  safetensors .torch .load_file (self .param_file_path , device = "cpu" )
158164                    for  key , tensor_obj  in  self .key_to_tensor .items ():
159165                        pinned_tensor  =  loaded_cpu_tensors [key ].pin_memory ()
160166                        tensor_obj .data  =  pinned_tensor .to (self .onload_device , non_blocking = self .non_blocking )
@@ -165,7 +171,7 @@ def onload_(self):
165171                    onload_device  =  (
166172                        self .onload_device .type  if  isinstance (self .onload_device , torch .device ) else  self .onload_device 
167173                    )
168-                     loaded_tensors  =  safetensors .torch .load_file (self .safetensors_file_path , device = onload_device )
174+                     loaded_tensors  =  safetensors .torch .load_file (self .param_file_path , device = onload_device )
169175                    for  key , tensor_obj  in  self .key_to_tensor .items ():
170176                        tensor_obj .data  =  loaded_tensors [key ]
171177            return 
@@ -218,15 +224,18 @@ def offload_(self):
218224        if  self .offload_to_disk_path :
219225            # TODO: we can potentially optimize this code path by checking if the _all_ the desired 
220226            # safetensor files exist on the disk and if so, skip this step entirely, reducing IO 
221-             # overhead. Currently, we just check if the given `safetensors_file_path ` exists and if not 
227+             # overhead. Currently, we just check if the given `param_file_path ` exists and if not 
222228            # we perform a write. 
223229            # Check if the file has been saved in this session or if it already exists on disk. 
224-             if  not  self ._is_offloaded_to_disk  and  not  os .path .exists (self .safetensors_file_path ):
225-                 os .makedirs (os .path .dirname (self .safetensors_file_path ), exist_ok = True )
230+             if  not  self ._is_offloaded_to_disk  and  not  os .path .exists (self .param_file_path ):
231+                 os .makedirs (os .path .dirname (self .param_file_path ), exist_ok = True )
226232                tensors_to_save  =  {
227233                    key : tensor .data .to (self .offload_device ) for  tensor , key  in  self .tensor_to_key .items ()
228234                }
229-                 safetensors .torch .save_file (tensors_to_save , self .safetensors_file_path )
235+                 if  not  self ._enable_deepnvme_disk_offloading :
236+                     safetensors .torch .save_file (tensors_to_save , self .param_file_path )
237+                 else :
238+                     _fast_aio_save (tensors_to_save , self .param_file_path )
230239
231240            # The group is now considered offloaded to disk for the rest of the session. 
232241            self ._is_offloaded_to_disk  =  True 
@@ -426,6 +435,7 @@ def apply_group_offloading(
426435    record_stream : bool  =  False ,
427436    low_cpu_mem_usage : bool  =  False ,
428437    offload_to_disk_path : Optional [str ] =  None ,
438+     _enable_deepnvme_disk_offloading : Optional [bool ] =  False 
429439) ->  None :
430440    r""" 
431441    Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and 
@@ -531,6 +541,7 @@ def apply_group_offloading(
531541            stream = stream ,
532542            record_stream = record_stream ,
533543            low_cpu_mem_usage = low_cpu_mem_usage ,
544+             _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
534545        )
535546    elif  offload_type  ==  "leaf_level" :
536547        _apply_group_offloading_leaf_level (
@@ -542,6 +553,7 @@ def apply_group_offloading(
542553            stream = stream ,
543554            record_stream = record_stream ,
544555            low_cpu_mem_usage = low_cpu_mem_usage ,
556+             _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
545557        )
546558    else :
547559        raise  ValueError (f"Unsupported offload_type: { offload_type }  )
@@ -557,6 +569,7 @@ def _apply_group_offloading_block_level(
557569    record_stream : Optional [bool ] =  False ,
558570    low_cpu_mem_usage : bool  =  False ,
559571    offload_to_disk_path : Optional [str ] =  None ,
572+     _enable_deepnvme_disk_offloading : Optional [bool ] =  False 
560573) ->  None :
561574    r""" 
562575    This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to 
@@ -617,6 +630,7 @@ def _apply_group_offloading_block_level(
617630                record_stream = record_stream ,
618631                low_cpu_mem_usage = low_cpu_mem_usage ,
619632                onload_self = True ,
633+                 _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
620634            )
621635            matched_module_groups .append (group )
622636            for  j  in  range (i , i  +  len (current_modules )):
@@ -651,6 +665,7 @@ def _apply_group_offloading_block_level(
651665        stream = None ,
652666        record_stream = False ,
653667        onload_self = True ,
668+         _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading ,
654669    )
655670    if  stream  is  None :
656671        _apply_group_offloading_hook (module , unmatched_group , None )
@@ -667,6 +682,7 @@ def _apply_group_offloading_leaf_level(
667682    record_stream : Optional [bool ] =  False ,
668683    low_cpu_mem_usage : bool  =  False ,
669684    offload_to_disk_path : Optional [str ] =  None ,
685+     _enable_deepnvme_disk_offloading : Optional [bool ] =  False 
670686) ->  None :
671687    r""" 
672688    This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory 
@@ -717,6 +733,7 @@ def _apply_group_offloading_leaf_level(
717733            record_stream = record_stream ,
718734            low_cpu_mem_usage = low_cpu_mem_usage ,
719735            onload_self = True ,
736+             _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
720737        )
721738        _apply_group_offloading_hook (submodule , group , None )
722739        modules_with_group_offloading .add (name )
@@ -764,6 +781,7 @@ def _apply_group_offloading_leaf_level(
764781            record_stream = record_stream ,
765782            low_cpu_mem_usage = low_cpu_mem_usage ,
766783            onload_self = True ,
784+             _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
767785        )
768786        _apply_group_offloading_hook (parent_module , group , None )
769787
@@ -785,6 +803,7 @@ def _apply_group_offloading_leaf_level(
785803            record_stream = False ,
786804            low_cpu_mem_usage = low_cpu_mem_usage ,
787805            onload_self = True ,
806+             _enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading 
788807        )
789808        _apply_lazy_group_offloading_hook (module , unmatched_group , None )
790809
0 commit comments