2121from  .hooks  import  HookRegistry , ModelHook 
2222
2323
24- logger  =  get_logger (__name__ ) # pylint: disable=invalid-name 
24+ logger  =  get_logger (__name__ )   # pylint: disable=invalid-name 
2525
2626
2727class  ModuleGroup :
@@ -32,12 +32,16 @@ def __init__(
3232        onload_device : torch .device ,
3333        offload_leader : torch .nn .Module ,
3434        onload_leader : Optional [torch .nn .Module ] =  None ,
35+         parameters : Optional [List [torch .nn .Parameter ]] =  None ,
36+         buffers : Optional [List [torch .Tensor ]] =  None ,
3537    ) ->  None :
3638        self .modules  =  modules 
3739        self .offload_device  =  offload_device 
3840        self .onload_device  =  onload_device 
3941        self .offload_leader  =  offload_leader 
4042        self .onload_leader  =  onload_leader 
43+         self .parameters  =  parameters 
44+         self .buffers  =  buffers 
4145
4246
4347class  GroupOffloadingHook (ModelHook ):
@@ -64,13 +68,15 @@ def __init__(
6468        stream : Optional [torch .cuda .Stream ] =  None ,
6569        next_group : Optional [ModuleGroup ] =  None ,
6670        cpu_param_dict : Optional [Dict [torch .nn .Parameter , torch .Tensor ]] =  None ,
71+         onload_self : bool  =  False ,
6772    ) ->  None :
6873        self .group  =  group 
6974        self .offload_on_init  =  offload_on_init 
7075        self .non_blocking  =  non_blocking 
7176        self .stream  =  stream 
7277        self .next_group  =  next_group 
7378        self .cpu_param_dict  =  cpu_param_dict 
79+         self .onload_self  =  onload_self 
7480
7581    def  initialize_hook (self , module : torch .nn .Module ) ->  torch .nn .Module :
7682        if  self .offload_on_init :
@@ -100,9 +106,16 @@ def onload_(self, module: torch.nn.Module) -> None:
100106                with  torch .cuda .stream (self .stream ):
101107                    for  group_module  in  self .next_group .modules :
102108                        group_module .to (self .next_group .onload_device , non_blocking = True )
103-             else :
109+ 
110+             if  self .stream  is  None  or  self .onload_self :
104111                for  group_module  in  self .group .modules :
105112                    group_module .to (self .group .onload_device , non_blocking = self .non_blocking )
113+                 if  self .group .parameters  is  not None :
114+                     for  param  in  self .group .parameters :
115+                         param .data  =  param .data .to (self .group .onload_device , non_blocking = self .non_blocking )
116+                 if  self .group .buffers  is  not None :
117+                     for  buffer  in  self .group .buffers :
118+                         buffer .data  =  buffer .data .to (self .group .onload_device , non_blocking = self .non_blocking )
106119
107120    def  offload_ (self , module : torch .nn .Module ) ->  None :
108121        if  self .group .offload_leader  ==  module :
@@ -113,6 +126,13 @@ def offload_(self, module: torch.nn.Module) -> None:
113126            else :
114127                for  group_module  in  self .group .modules :
115128                    group_module .to (self .group .offload_device , non_blocking = self .non_blocking )
129+                 if  self .group .parameters  is  not None :
130+                     for  param  in  self .group .parameters :
131+                         param .data  =  param .data .to (self .group .offload_device , non_blocking = self .non_blocking )
132+                 if  self .group .buffers  is  not None :
133+                     for  buffer  in  self .group .buffers :
134+                         buffer .data  =  buffer .data .to (self .group .offload_device , non_blocking = self .non_blocking )
135+ 
116136                # TODO: do we need to sync here because of GPU->CPU transfer? 
117137                if  self .non_blocking  and  self .group .offload_device .type  ==  "cpu" :
118138                    torch .cpu .synchronize ()
@@ -128,9 +148,9 @@ def apply_group_offloading(
128148    non_blocking : bool  =  False ,
129149    cuda_stream : bool  =  False ,
130150) ->  None :
131-     #  stream = None
132-     #  if cuda_stream:
133-     #      stream = torch.cuda.Stream()
151+     stream  =  None 
152+     if  cuda_stream :
153+         stream  =  torch .cuda .Stream ()
134154    if  offload_group_patterns  ==  "modulelist_or_sequential" :
135155        if  num_blocks_per_group  is  None :
136156            raise  ValueError (
@@ -148,7 +168,7 @@ def apply_group_offloading(
148168        offload_group_patterns  =  _get_modulelist_or_sequential_group_patterns (module , num_blocks_per_group )
149169
150170    _apply_group_offloading_group_patterns (
151-         module , offload_group_patterns , offload_device , onload_device , force_offload , non_blocking 
171+         module , offload_group_patterns , offload_device , onload_device , force_offload , non_blocking ,  stream = stream 
152172    )
153173
154174
@@ -231,6 +251,7 @@ def _apply_group_offloading_group_patterns(
231251    onload_device : torch .device ,
232252    force_offload : bool ,
233253    non_blocking : bool ,
254+     stream : Optional [torch .cuda .Stream ] =  None ,
234255) ->  None :
235256    r""" 
236257    This function applies offloading to groups of modules based on the provided regex patterns. Each group of modules 
@@ -269,8 +290,17 @@ def _apply_group_offloading_group_patterns(
269290        non_blocking (`bool`): 
270291            If True, offloading and onloading are done asynchronously. This can be useful for overlapping computation 
271292            and data transfer. 
293+         stream (`torch.cuda.Stream`, *optional*): 
294+             If provided, offloading and onloading are done asynchronously using the provided stream. This can be useful 
295+             for overlapping computation and data transfer. 
272296    """ 
273297
298+     cpu_param_dict  =  None 
299+     if  stream  is  not None :
300+         for  param  in  module .parameters ():
301+             param .data  =  param .data .cpu ().pin_memory ()
302+         cpu_param_dict  =  {param : param .data  for  param  in  module .parameters ()}
303+ 
274304    per_group_modules  =  [[] for  _  in  range (len (offload_group_patterns ))]
275305    per_group_offload_leaders  =  [None ] *  len (offload_group_patterns )
276306    per_group_onload_leaders  =  [None ] *  len (offload_group_patterns )
@@ -280,20 +310,20 @@ def _apply_group_offloading_group_patterns(
280310    offload_leader_patterns  =  [pattern [1 ] for  pattern  in  offload_group_patterns ]
281311    onload_leader_patterns  =  [pattern [2 ] for  pattern  in  offload_group_patterns ]
282312
283-     for  name , module  in  module .named_modules ():
284-         if  name .count ("." ) >  1 :
313+     for  name , submodule  in  module .named_modules ():
314+         if  name   ==   ""   or   name .count ("." ) >  1 :
285315            # We only want the layers that are top-level in the module (encompass all the other submodules) 
286316            # for enabling offloading. This method is specifically targeted for diffusers format models, 
287317            # so we can ignore submodules. 
288318            # TODO(aryan): This is not the case and is just a workaround to make the benchmark code work 
289319            # for now. We need to support the arbitrary nesting of modules here. 
290320            continue 
291-         num_matches  =  0 
292321
293322        # Check if the module matches any of the offload group patterns 
323+         num_matches  =  0 
294324        for  i , pattern  in  enumerate (group_patterns ):
295325            if  re .search (pattern , name ) is  not None :
296-                 per_group_modules [i ].append (module )
326+                 per_group_modules [i ].append (submodule )
297327                num_matches  +=  1 
298328
299329        # Check if the module matches any of the offload leader patterns 
@@ -303,7 +333,7 @@ def _apply_group_offloading_group_patterns(
303333                    raise  ValueError (
304334                        f"Module { name }  
305335                    )
306-                 per_group_offload_leaders [i ] =  module 
336+                 per_group_offload_leaders [i ] =  submodule 
307337
308338        # Check if the module matches any of the onload leader patterns 
309339        for  i , pattern  in  enumerate (onload_leader_patterns ):
@@ -314,16 +344,17 @@ def _apply_group_offloading_group_patterns(
314344                    raise  ValueError (
315345                        f"Module { name }  
316346                    )
317-                 per_group_onload_leaders [i ] =  module 
347+                 per_group_onload_leaders [i ] =  submodule 
318348
319349        if  num_matches  ==  0 :
320-             unmatched_group_modules .append (module )
350+             unmatched_group_modules .append (( name ,  submodule ) )
321351        elif  num_matches  >  1 :
322352            raise  ValueError (
323353                f"Module { name }  
324354            )
325355
326356    # Handle modules that matched patterns 
357+     groups  =  []
327358    for  i  in  range (len (per_group_modules )):
328359        if  per_group_offload_leaders [i ] is  None :
329360            raise  ValueError (
@@ -336,21 +367,40 @@ def _apply_group_offloading_group_patterns(
336367            offload_leader = per_group_offload_leaders [i ],
337368            onload_leader = per_group_onload_leaders [i ],
338369        )
339-         _apply_group_offloading (group , force_offload , non_blocking )
340- 
341-     # Handle modules that did not match patterns 
342-     for  module  in  unmatched_group_modules :
343-         group  =  ModuleGroup ([module ], offload_device , onload_device , offload_leader = module , onload_leader = module )
344-         _apply_group_offloading (group , force_offload , non_blocking )
345- 
346-     # TODO(aryan): When you add stream support, this may need to be put in an if-branch 
347-     # Always keep parameters and buffers on onload_device 
348-     for  name , param  in  module .named_parameters (recurse = False ):
349-         if  torch .is_tensor (param .data ):
350-             param .data  =  param .data .to (onload_device )
370+         groups .append (group )
371+ 
372+     for  i  in  range (len (groups )):
373+         next_group  =  groups [i  +  1 ] if  i  +  1  <  len (groups ) and  stream  is  not None  else  None 
374+         should_offload  =  force_offload  or  i  >  0 
375+         _apply_group_offloading (
376+             groups [i ], should_offload , non_blocking , stream , next_group , cpu_param_dict , onload_self = False 
377+         )
378+ 
379+     # Ignore parameters/buffers if they're already accounted for in unmatched_group_modules (for example, an nn.Linear 
380+     # in the top-level module will also be present in the named_parameters iterator) 
381+     parameters  =  []
382+     for  name , parameter  in  module .named_parameters (recurse = False ):
383+         if  not  any (name .startswith (unmatched_name ) for  unmatched_name , _  in  unmatched_group_modules ):
384+             parameters .append (parameter )
385+ 
386+     buffers  =  []
351387    for  name , buffer  in  module .named_buffers (recurse = False ):
352-         if  torch .is_tensor (buffer .data ):
353-             buffer .data  =  buffer .data .to (onload_device )
388+         if  not  any (name .startswith (unmatched_name ) for  unmatched_name , _  in  unmatched_group_modules ):
389+             buffers .append (buffer )
390+ 
391+     unmatched_modules  =  [module  for  _ , module  in  unmatched_group_modules ]
392+     unmatched_group  =  ModuleGroup (
393+         unmatched_modules ,
394+         offload_device ,
395+         onload_device ,
396+         offload_leader = module ,
397+         onload_leader = None ,
398+         parameters = parameters ,
399+         buffers = buffers ,
400+     )
401+     _apply_group_offloading (
402+         unmatched_group , force_offload , non_blocking , stream , groups [0 ], cpu_param_dict , onload_self = True 
403+     )
354404
355405
356406def  _apply_group_offloading (
@@ -360,9 +410,12 @@ def _apply_group_offloading(
360410    stream : Optional [torch .cuda .Stream ] =  None ,
361411    next_group : Optional [ModuleGroup ] =  None ,
362412    cpu_param_dict : Optional [Dict [torch .nn .Parameter , torch .Tensor ]] =  None ,
413+     onload_self : bool  =  False ,
363414) ->  None :
364415    for  module  in  group .modules :
365-         hook  =  GroupOffloadingHook (group , offload_on_init , non_blocking , stream , next_group , cpu_param_dict )
416+         hook  =  GroupOffloadingHook (
417+             group , offload_on_init , non_blocking , stream , next_group , cpu_param_dict , onload_self 
418+         )
366419        registry  =  HookRegistry .check_if_exists_or_initialize (module )
367420        registry .register_hook (hook , "group_offloading" )
368421
@@ -375,11 +428,11 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
375428    blocks. The generated patterns can be used to create ModuleGroup objects which are offloaded and onloaded together. 
376429    """ 
377430    group_patterns  =  []
378-      
431+ 
379432    # We only want the layers that are top-level in the module (encompass all the other submodules) 
380433    # for enabling offloading. This method is specifically targeted for diffusers format models, 
381434    # so we can ignore everything but the children of this module. 
382-     for  name , submodule  in  module .children ():
435+     for  name , submodule  in  module .named_children ():
383436        if  not  isinstance (submodule , (torch .nn .ModuleList , torch .nn .Sequential )):
384437            continue 
385438        for  i  in  range (0 , len (submodule ), num_blocks_per_group ):
@@ -389,6 +442,6 @@ def _get_modulelist_or_sequential_group_patterns(module: torch.nn.Module, num_bl
389442            onload_leader_pattern  =  rf"{ name } { i }  
390443            offload_leader_pattern  =  rf"{ name } { i  +  num_modules  -  1 }  
391444            group_patterns .append ((pattern , offload_leader_pattern , onload_leader_pattern ))
392-      
445+ 
393446    logger .debug (f"Generated group patterns for apply_groupwise_offloading: { group_patterns }  )
394447    return  group_patterns 
0 commit comments