 # limitations under the License.
 
 from contextlib import nullcontext
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import torch
 from accelerate.utils import send_to_device
@@ -284,6 +284,8 @@ def apply_group_offloading(
         _apply_group_offloading_leaf_level(
             module, offload_device, onload_device, force_offload, non_blocking, stream=stream
         )
+    else:
+        raise ValueError(f"Unsupported offload_type: {offload_type}")
 
 
 def _apply_group_offloading_block_level(
@@ -325,12 +327,15 @@ def _apply_group_offloading_block_level(
         cpu_param_dict = {param: param.data for param in module.parameters()}
 
     # Create module groups for ModuleList and Sequential blocks
+    modules_with_group_offloading = set()
     unmatched_modules = []
     matched_module_groups = []
     for name, submodule in module.named_children():
         if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
             unmatched_modules.append((name, submodule))
+            modules_with_group_offloading.add(name)
             continue
+
         for i in range(0, len(submodule), num_blocks_per_group):
             current_modules = submodule[i : i + num_blocks_per_group]
             group = ModuleGroup(
@@ -345,6 +350,8 @@ def _apply_group_offloading_block_level(
                 onload_self=stream is None,
             )
             matched_module_groups.append(group)
+            for j in range(i, i + len(current_modules)):
+                modules_with_group_offloading.add(f"{name}.{j}")
 
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
@@ -359,15 +366,10 @@ def _apply_group_offloading_block_level(
     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
     # part of any group (as doing so would lead to no VRAM savings).
-    parameters = []
-    for name, parameter in module.named_parameters(recurse=False):
-        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules):
-            parameters.append(parameter)
-
-    buffers = []
-    for name, buffer in module.named_buffers(recurse=False):
-        if not any(name.startswith(unmatched_name) for unmatched_name, _ in unmatched_modules):
-            buffers.append(buffer)
+    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    parameters = [param for _, param in parameters]
+    buffers = [buffer for _, buffer in buffers]
 
     # Create a group for the unmatched submodules of the top-level module so that they are on the correct
     # device when the forward pass is called.
@@ -428,7 +430,8 @@ def _apply_group_offloading_leaf_level(
         cpu_param_dict = {param: param.data for param in module.parameters()}
 
     # Create module groups for leaf modules and apply group offloading hooks
-    for submodule in module.modules():
+    modules_with_group_offloading = set()
+    for name, submodule in module.named_modules():
         if not isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
             continue
         group = ModuleGroup(
@@ -443,38 +446,65 @@ def _apply_group_offloading_leaf_level(
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, True, None)
+        modules_with_group_offloading.add(name)
 
     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
     # of the module is called
-    parameters = []
-    buffers = []
     module_dict = dict(module.named_modules())
+    parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
+    buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
+
+    # Find closest module parent for each parameter and buffer, and attach group hooks
+    common_kwargs = {
+        "modules": [],
+        "offload_device": offload_device,
+        "onload_device": onload_device,
+        "non_blocking": non_blocking,
+        "stream": stream,
+        "cpu_param_dict": cpu_param_dict,
+        "onload_self": True,
+    }
+
+    for name, param in parameters:
+        parent_name = _find_parent_module_in_module_dict(name, module_dict)
+        parent_module = module_dict[parent_name]
+        logger.info(f"TODO: REMOVETHIS Found parameter {name} with parent module {parent_name}")
+        assert getattr(parent_module, "_diffusers_hook", None) is None
+        group = ModuleGroup(
+            offload_leader=parent_module,
+            onload_leader=parent_module,
+            parameters=[param],
+            buffers=None,
+            **common_kwargs,
+        )
+        _apply_group_offloading_hook(parent_module, group, True, None)
 
-    for name, parameter in module.named_parameters():
-        atoms = name.split(".")
-        parent_name = ".".join(atoms[:-1])
-        if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
-            continue
-        parameters.append(parameter)
-
-    for name, buffer in module.named_buffers():
-        atoms = name.split(".")
-        parent_name = ".".join(atoms[:-1])
-        if parent_name in module_dict and isinstance(module_dict[parent_name], _SUPPORTED_PYTORCH_LAYERS):
-            continue
-        buffers.append(buffer)
+    for name, buffer in buffers:
+        parent_name = _find_parent_module_in_module_dict(name, module_dict)
+        parent_module = module_dict[parent_name]
+        logger.info(f"TODO: REMOVETHIS Found buffer {name} with parent module {parent_name}")
+        assert getattr(parent_module, "_diffusers_hook", None) is None
+        group = ModuleGroup(
+            offload_leader=parent_module,
+            onload_leader=parent_module,
+            parameters=None,
+            buffers=[buffer],
+            **common_kwargs,
+        )
+        _apply_group_offloading_hook(parent_module, group, True, None)
 
+    # This is a dummy group that will handle lazy prefetching from the top-level module to the first leaf module
     unmatched_group = ModuleGroup(
         modules=[],
         offload_device=offload_device,
         onload_device=onload_device,
         offload_leader=module,
         onload_leader=module,
-        parameters=parameters,
-        buffers=buffers,
+        parameters=None,
+        buffers=None,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=cpu_param_dict,
+        cpu_param_dict=None,
         onload_self=True,
     )
 
@@ -509,3 +539,55 @@ def _apply_lazy_group_offloading_hook(
     registry = HookRegistry.check_if_exists_or_initialize(module)
     registry.register_hook(hook, _GROUP_OFFLOADING)
     registry.register_hook(lazy_prefetch_hook, _LAZY_PREFETCH_GROUP_OFFLOADING)
+
+
+def _gather_parameters_with_no_group_offloading_parent(
+    module: torch.nn.Module, modules_with_group_offloading: Set[str]
+) -> List[Tuple[str, torch.nn.Parameter]]:
+    parameters = []
+    for name, parameter in module.named_parameters():
+        has_parent_with_group_offloading = False
+        atoms = name.split(".")
+
+        while len(atoms) > 0:
+            parent_name = ".".join(atoms)
+            if parent_name in modules_with_group_offloading:
+                has_parent_with_group_offloading = True
+                break
+            atoms.pop()
+
+        if not has_parent_with_group_offloading:
+            logger.info(f"TODO: REMOVETHIS Found parameter {name}")
+            parameters.append((name, parameter))
+    return parameters
+
+
+def _gather_buffers_with_no_group_offloading_parent(
+    module: torch.nn.Module, modules_with_group_offloading: Set[str]
+) -> List[Tuple[str, torch.Tensor]]:
+    buffers = []
+    for name, buffer in module.named_buffers():
+        has_parent_with_group_offloading = False
+        atoms = name.split(".")
+
+        while len(atoms) > 0:
+            parent_name = ".".join(atoms)
+            if parent_name in modules_with_group_offloading:
+                has_parent_with_group_offloading = True
+                break
+            atoms.pop()
+
+        if not has_parent_with_group_offloading:
+            logger.info(f"TODO: REMOVETHIS Found buffer {name}")
+            buffers.append((name, buffer))
+    return buffers
+
+
+def _find_parent_module_in_module_dict(name: str, module_dict: Dict[str, torch.nn.Module]) -> str:
+    atoms = name.split(".")
+    while len(atoms) > 0:
+        parent_name = ".".join(atoms)
+        if parent_name in module_dict:
+            return parent_name
+        atoms.pop()
+    return ""
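
To make the new helper logic easier to follow, here is a minimal, self-contained sketch (not part of the diff above) of the dotted-name walks performed by _gather_parameters_with_no_group_offloading_parent and _find_parent_module_in_module_dict: a parameter or buffer is skipped when any dotted prefix of its name already belongs to a group-offloaded module, and otherwise its closest named parent module is resolved so a hook can be attached to it. The ToyModel, the offloaded set, and the helper names has_offloaded_parent / closest_parent are illustrative assumptions, not code from this PR.

import torch

class ToyModel(torch.nn.Module):
    # Hypothetical model: "blocks.*" stands in for submodules that already
    # received group-offloading hooks; "norm" and "scale" did not.
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4) for _ in range(2)])
        self.norm = torch.nn.LayerNorm(4)
        self.scale = torch.nn.Parameter(torch.ones(4))

model = ToyModel()
module_dict = dict(model.named_modules())  # includes "" for the root module
offloaded = {"blocks.0", "blocks.1"}       # pretend these already have group hooks

def has_offloaded_parent(name: str) -> bool:
    # Same prefix walk as the gather helpers: drop trailing atoms until a
    # dotted prefix matches a module that already has group offloading.
    atoms = name.split(".")
    while len(atoms) > 0:
        if ".".join(atoms) in offloaded:
            return True
        atoms.pop()
    return False

def closest_parent(name: str) -> str:
    # Same walk as _find_parent_module_in_module_dict: the longest dotted
    # prefix that names an actual module; "" means the root module itself.
    atoms = name.split(".")
    while len(atoms) > 0:
        parent_name = ".".join(atoms)
        if parent_name in module_dict:
            return parent_name
        atoms.pop()
    return ""

for name, _ in model.named_parameters():
    if not has_offloaded_parent(name):
        print(name, "->", closest_parent(name) or "<root>")
# Prints norm.weight/norm.bias with parent "norm" and scale with the root,
# while blocks.* parameters are skipped because their parent is group-offloaded.

In the leaf-level path of the diff, that resolved parent module is then used as both offload_leader and onload_leader of a single-parameter (or single-buffer) ModuleGroup, so the tensor moves devices together with its closest named parent.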