@@ -455,41 +455,40 @@ def _apply_group_offloading_leaf_level(
455455    buffers  =  _gather_buffers_with_no_group_offloading_parent (module , modules_with_group_offloading )
456456
457457    # Find closest module parent for each parameter and buffer, and attach group hooks 
458-     common_kwargs  =  {
459-         "modules" : [],
460-         "offload_device" : offload_device ,
461-         "onload_device" : onload_device ,
462-         "non_blocking" : non_blocking ,
463-         "stream" : stream ,
464-         "cpu_param_dict" : cpu_param_dict ,
465-         "onload_self" : True ,
466-     }
467- 
458+     parent_to_parameters  =  {}
468459    for  name , param  in  parameters :
469460        parent_name  =  _find_parent_module_in_module_dict (name , module_dict )
470-         parent_module  =  module_dict [parent_name ]
471-         logger .info (f"TODO: REMOVETHIS Found parameter { name } { parent_name }  )
472-         assert  getattr (parent_module , "_diffusers_hook" , None ) is  None 
473-         group  =  ModuleGroup (
474-             offload_leader = parent_module ,
475-             onload_leader = parent_module ,
476-             parameters = [param ],
477-             buffers = None ,
478-             ** common_kwargs ,
479-         )
480-         _apply_group_offloading_hook (parent_module , group , True , None )
461+         if  parent_name  in  parent_to_parameters :
462+             parent_to_parameters [parent_name ].append (param )
463+         else :
464+             parent_to_parameters [parent_name ] =  [param ]
481465
466+     parent_to_buffers  =  {}
482467    for  name , buffer  in  buffers :
483468        parent_name  =  _find_parent_module_in_module_dict (name , module_dict )
484-         parent_module  =  module_dict [parent_name ]
485-         logger .info (f"TODO: REMOVETHIS Found buffer { name } { parent_name }  )
469+         if  parent_name  in  parent_to_buffers :
470+             parent_to_buffers [parent_name ].append (buffer )
471+         else :
472+             parent_to_buffers [parent_name ] =  [buffer ]
473+ 
474+     parent_names  =  set (parent_to_parameters .keys ()) |  set (parent_to_buffers .keys ())
475+     for  name  in  parent_names :
476+         parameters  =  parent_to_parameters .get (name , [])
477+         buffers  =  parent_to_buffers .get (name , [])
478+         parent_module  =  module_dict [name ]
486479        assert  getattr (parent_module , "_diffusers_hook" , None ) is  None 
487480        group  =  ModuleGroup (
481+             modules = [],
482+             offload_device = offload_device ,
483+             onload_device = onload_device ,
488484            offload_leader = parent_module ,
489485            onload_leader = parent_module ,
490-             parameters = None ,
491-             buffers = [buffer ],
492-             ** common_kwargs ,
486+             parameters = parameters ,
487+             buffers = buffers ,
488+             non_blocking = non_blocking ,
489+             stream = stream ,
490+             cpu_param_dict = cpu_param_dict ,
491+             onload_self = True ,
493492        )
494493        _apply_group_offloading_hook (parent_module , group , True , None )
495494
@@ -557,7 +556,6 @@ def _gather_parameters_with_no_group_offloading_parent(
557556            atoms .pop ()
558557
559558        if  not  has_parent_with_group_offloading :
560-             logger .info (f"TODO: REMOVETHIS Found parameter { name }  )
561559            parameters .append ((name , parameter ))
562560    return  parameters 
563561
@@ -578,7 +576,6 @@ def _gather_buffers_with_no_group_offloading_parent(
578576            atoms .pop ()
579577
580578        if  not  has_parent_with_group_offloading :
581-             logger .info (f"TODO: REMOVETHIS Found buffer { name }  )
582579            buffers .append ((name , buffer ))
583580    return  buffers 
584581
0 commit comments