 from huggingface_hub import model_info
 from huggingface_hub.constants import HF_HUB_OFFLINE

+from ..hooks.group_offloading import _is_group_offload_enabled, _maybe_remove_and_reapply_group_offloading
 from ..models.modeling_utils import ModelMixin, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
@@ -391,7 +392,9 @@ def _load_lora_into_text_encoder(
             adapter_name = get_adapter_name(text_encoder)

         # <Unsafe code
-        is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
+        is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
+            _pipeline
+        )
         # inject LoRA layers and load the state dict
         # in transformers we automatically check whether the adapter name is already in use or not
         text_encoder.load_adapter(
@@ -410,6 +413,10 @@ def _load_lora_into_text_encoder(
             _pipeline.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             _pipeline.enable_sequential_cpu_offload()
+        elif is_group_offload:
+            for component in _pipeline.components.values():
+                if isinstance(component, torch.nn.Module):
+                    _maybe_remove_and_reapply_group_offloading(component)
         # Unsafe code />

     if prefix is not None and not state_dict:
@@ -433,30 +440,36 @@ def _func_optionally_disable_offloading(_pipeline):

     Returns:
         tuple:
-            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+            A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
     """
     is_model_cpu_offload = False
     is_sequential_cpu_offload = False
+    is_group_offload = False

     if _pipeline is not None and _pipeline.hf_device_map is None:
         for _, component in _pipeline.components.items():
-            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
-                if not is_model_cpu_offload:
-                    is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
-                if not is_sequential_cpu_offload:
-                    is_sequential_cpu_offload = (
-                        isinstance(component._hf_hook, AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+            if not isinstance(component, nn.Module):
+                continue
+            is_group_offload = is_group_offload or _is_group_offload_enabled(component)
+            if not hasattr(component, "_hf_hook"):
+                continue
+            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
+            is_sequential_cpu_offload = is_sequential_cpu_offload or (
+                isinstance(component._hf_hook, AlignDevicesHook)
+                or hasattr(component._hf_hook, "hooks")
+                and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+            )

-                logger.info(
-                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                )
-                if is_sequential_cpu_offload or is_model_cpu_offload:
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+        if is_sequential_cpu_offload or is_model_cpu_offload:
+            logger.info(
+                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+            )
+            for _, component in _pipeline.components.items():
+                if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
+                    continue
+                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-    return (is_model_cpu_offload, is_sequential_cpu_offload)
+    return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)


 class LoraBaseMixin:
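
For context, a minimal usage sketch of the updated helper follows: it now returns three flags rather than two, and the caller restores whichever offloading scheme was active once loading finishes. The wrapper name `reload_lora_with_offload_handling` and the `load_fn` callback are illustrative assumptions, not part of this diff; the imported private helpers are the ones touched above and assume a diffusers version that includes this change.

# Minimal usage sketch (assumed caller code, not part of this diff). The wrapper
# name and `load_fn` callback are hypothetical; the imported helpers are the
# private functions referenced in the diff above.
import torch

from diffusers.hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
from diffusers.loaders.lora_base import _func_optionally_disable_offloading


def reload_lora_with_offload_handling(pipeline, load_fn):
    # Temporarily disable any active offloading; the helper now reports three flags.
    is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(pipeline)

    load_fn()  # e.g. inject LoRA layers and load the state dict

    # Restore whichever offloading scheme was active before loading.
    if is_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    elif is_sequential_cpu_offload:
        pipeline.enable_sequential_cpu_offload()
    elif is_group_offload:
        # Group offloading is removed and re-applied per nn.Module component of the pipeline.
        for component in pipeline.components.values():
            if isinstance(component, torch.nn.Module):
                _maybe_remove_and_reapply_group_offloading(component)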