 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 from contextlib import contextmanager, nullcontext
 from typing import Dict, List, Optional, Set, Tuple, Union
-import os
 
-import torch
 import safetensors.torch
+import torch
+
 from ..utils import get_logger, is_accelerate_available
 from .hooks import HookRegistry, ModelHook
 
@@ -165,9 +166,10 @@ def onload_(self):
                             tensor_obj.data.record_stream(current_stream)
                 else:
                     # Load directly to the target device (synchronous)
-                    loaded_tensors = safetensors.torch.load_file(
-                        self.safetensors_file_path, device=self.onload_device
+                    onload_device = (
+                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
                     )
+                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
                     for key, tensor_obj in self.key_to_tensor.items():
                         tensor_obj.data = loaded_tensors[key]
             return
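This hunk converts `self.onload_device` (a `torch.device`) into its `.type` string before handing it to `safetensors.torch.load_file`, which expects a device string. Below is a minimal sketch of that round-trip, with a hypothetical file name and tensor, using only the public safetensors API:

# Minimal sketch: save a tensor group and reload it with a device string,
# mirroring the `.type` conversion made in the hunk above. The file name
# and tensor contents are hypothetical.
import safetensors.torch
import torch

tensors = {"weight": torch.randn(4, 4)}
safetensors.torch.save_file(tensors, "group_0.safetensors")

onload_device = torch.device("cpu")
# load_file takes a device string such as "cpu" or "cuda"; passing the
# torch.device's `.type` avoids handing it a device object directly.
loaded = safetensors.torch.load_file("group_0.safetensors", device=onload_device.type)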
@@ -265,16 +267,12 @@ class GroupOffloadingHook(ModelHook):
 
     _is_stateful = False
 
-    def __init__(
-        self,
-        group: ModuleGroup,
-        next_group: Optional[ModuleGroup] = None
-    ) -> None:
+    def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
         self.group = group
         self.next_group = next_group
         # map param/buffer name -> file path
-        self.param_to_path: Dict[str,str] = {}
-        self.buffer_to_path: Dict[str,str] = {}
+        self.param_to_path: Dict[str, str] = {}
+        self.buffer_to_path: Dict[str, str] = {}
 
     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
@@ -516,7 +514,6 @@ def apply_group_offloading(
             stream = torch.Stream()
         else:
             raise ValueError("Using streams for data transfer requires a CUDA device, or an Intel XPU device.")
-
     if offload_to_disk and offload_path is None:
         raise ValueError("`offload_path` must be set when `offload_to_disk=True`.")
 
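For context, a rough usage sketch of disk-backed group offloading, assuming the `offload_to_disk` and `offload_path` parameters validated above (their names come from this diff and may differ in the released API) alongside the existing `apply_group_offloading` arguments; the tiny model is a placeholder:

# Hypothetical usage sketch; `offload_to_disk` / `offload_path` are the
# parameters checked in the hunk above, not necessarily the released API.
import torch
from diffusers.hooks import apply_group_offloading


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level offloading groups ModuleList/Sequential children
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,        # requires CUDA or XPU, per the check above
    offload_to_disk=True,   # parameter name taken from this diff
    offload_path="offloaded_groups",
)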
@@ -899,4 +896,4 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
     for submodule in module.modules():
         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
             return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
-    raise ValueError("Group offloading is not enabled for the provided module.")
+    raise ValueError("Group offloading is not enabled for the provided module.")