1313# limitations under the License. 
1414
1515from  contextlib  import  nullcontext 
16- from  typing  import  Dict , List , Optional 
16+ from  typing  import  Dict , List , Optional ,  Tuple 
1717
1818import  torch 
1919from  accelerate .utils  import  send_to_device 
2525logger  =  get_logger (__name__ )  # pylint: disable=invalid-name 
2626
2727
# Registry keys under which the offloading-related hooks are registered on a module's
# HookRegistry. Kept as module-level constants so registration and removal use the
# same spelling everywhere.
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
31+ 
32+ 
2833class  ModuleGroup :
2934    def  __init__ (
3035        self ,
@@ -99,6 +104,8 @@ class GroupOffloadingHook(ModelHook):
99104    group is responsible for onloading the current module group. 
100105    """ 
101106
107+     _is_stateful  =  False 
108+ 
102109    def  __init__ (
103110        self ,
104111        group : ModuleGroup ,
@@ -132,6 +139,85 @@ def post_forward(self, module: torch.nn.Module, output):
132139        return  output 
133140
134141
class LazyPrefetchGroupOffloadingHook(ModelHook):
    r"""
    A hook, registered on the top-level module, that lazily discovers the true execution
    order of the group-offloaded submodules during the first forward pass.

    At initialization it attaches a `LayerExecutionTrackerHook` to every submodule that
    already carries a group offloading hook. After the first forward pass completes,
    the recorded execution order is used to chain each group's `next_group` pointer so
    that later forward passes can prefetch the next group's weights ahead of time. The
    tracker hooks and this hook itself are removed once the chain has been built.
    """

    _is_stateful = False

    def __init__(self):
        # (module name, module) pairs in the order they executed during the first forward.
        self.execution_order: List[Tuple[str, torch.nn.Module]] = []
        # Names of all submodules that received a tracker hook; compared against the
        # executed set to detect layers that never ran.
        self._layer_execution_tracker_module_names = set()

    def initialize_hook(self, module):
        # Factory hoisted out of the loop below (the original re-created it every
        # iteration). It exists to bind the *current* loop values — a plain closure over
        # the loop variables would late-bind and make every callback see the last layer.
        def make_execution_order_update_callback(current_name, current_submodule):
            def callback():
                logger.debug(f"Adding {current_name} to the execution order")
                self.execution_order.append((current_name, current_submodule))

            return callback

        for name, submodule in module.named_modules():
            # Skip the root module itself and anything that never got hooks attached.
            if name == "" or not hasattr(submodule, "_diffusers_hook"):
                continue

            registry = HookRegistry.check_if_exists_or_initialize(submodule)
            group_offloading_hook = registry.get_hook(_GROUP_OFFLOADING)

            if group_offloading_hook is not None:
                layer_tracker_hook = LayerExecutionTrackerHook(make_execution_order_update_callback(name, submodule))
                registry.register_hook(layer_tracker_hook, _LAYER_EXECUTION_TRACKER)
                self._layer_execution_tracker_module_names.add(name)

        return module

    def post_forward(self, module, output):
        # First forward pass is done: use the recorded execution order to wire up
        # prefetching, then tear down the tracking machinery.
        num_executed = len(self.execution_order)
        execution_order_module_names = {name for name, _ in self.execution_order}

        # Warn when some tracked layers never executed: a prefetch chain derived from an
        # incomplete trace can cause device-mismatch errors on later forward passes.
        if execution_order_module_names != self._layer_execution_tracker_module_names:
            unexecuted_layers = list(self._layer_execution_tracker_module_names - execution_order_module_names)
            logger.warning(
                "It seems like some layers were not executed during the forward pass. This may lead to problems when "
                "applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please "
                "make sure that all layers are executed during the forward pass. The following layers were not executed:\n"
                f"{unexecuted_layers=}"
            )

        base_module_registry = HookRegistry.check_if_exists_or_initialize(module)
        registries = [HookRegistry.check_if_exists_or_initialize(submodule) for _, submodule in self.execution_order]

        # Tracing only needs to happen once: drop the per-layer trackers and this hook.
        for i in range(num_executed):
            registries[i].remove_hook(_LAYER_EXECUTION_TRACKER)

        base_module_registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING)

        # Chain the groups in execution order: the base module prefetches the first
        # executed group, and each group prefetches its successor.
        group_offloading_hooks = [registry.get_hook(_GROUP_OFFLOADING) for registry in registries]
        if num_executed > 0:
            base_module_group_offloading_hook = base_module_registry.get_hook(_GROUP_OFFLOADING)
            base_module_group_offloading_hook.next_group = group_offloading_hooks[0].group
            base_module_group_offloading_hook.next_group.onload_self = False

        for i in range(num_executed - 1):
            name1, _ = self.execution_order[i]
            name2, _ = self.execution_order[i + 1]
            logger.debug(f"Applying lazy prefetch group offloading from {name1} to {name2}")
            group_offloading_hooks[i].next_group = group_offloading_hooks[i + 1].group
            group_offloading_hooks[i].next_group.onload_self = False

        return output
209+ 
class LayerExecutionTrackerHook(ModelHook):
    r"""
    A minimal hook that invokes a callback right before its module's forward runs.
    Used by `LazyPrefetchGroupOffloadingHook` to record the order in which layers
    execute during the first forward pass.
    """

    _is_stateful = False

    def __init__(self, execution_order_update_callback):
        # Zero-argument callable invoked on every pre_forward of the tracked module.
        self.execution_order_update_callback = execution_order_update_callback

    def pre_forward(self, module, *args, **kwargs):
        self.execution_order_update_callback()
        return args, kwargs
220+ 
135221def  apply_group_offloading (
136222    module : torch .nn .Module ,
137223    offload_type : str  =  "block_level" ,
@@ -156,10 +242,10 @@ def apply_group_offloading(
156242        _apply_group_offloading_block_level (
157243            module , num_blocks_per_group , offload_device , onload_device , force_offload , non_blocking , stream = stream 
158244        )
159-     #  elif offload_type == "leaf_level":
160-     #      _apply_group_offloading_leaf_level(
161-     #          module, offload_device, onload_device, force_offload, non_blocking, stream=stream
162-     #      )
245+     elif  offload_type  ==  "leaf_level" :
246+         _apply_group_offloading_leaf_level (
247+             module , offload_device , onload_device , force_offload , non_blocking , stream = stream 
248+         )
163249
164250
165251def  _apply_group_offloading_block_level (
@@ -205,12 +291,13 @@ def _apply_group_offloading_block_level(
205291            unmatched_modules .append ((name , submodule ))
206292            continue 
207293        for  i  in  range (0 , len (submodule ), num_blocks_per_group ):
294+             current_modules  =  submodule [i  : i  +  num_blocks_per_group ]
208295            group  =  ModuleGroup (
209296                modules = submodule [i  : i  +  num_blocks_per_group ],
210297                offload_device = offload_device ,
211298                onload_device = onload_device ,
212-                 offload_leader = submodule [ i ],
213-                 onload_leader = None ,
299+                 offload_leader = current_modules [ - 1 ],
300+                 onload_leader = current_modules [ 0 ] ,
214301                non_blocking = non_blocking ,
215302                stream = stream ,
216303                cpu_param_dict = cpu_param_dict ,
@@ -223,7 +310,9 @@ def _apply_group_offloading_block_level(
223310            matched_module_groups [i  +  1 ] if  i  +  1  <  len (matched_module_groups ) and  stream  is  not None  else  None 
224311        )
225312        should_offload  =  force_offload  or  i  >  0 
226-         _apply_group_offloading (group , should_offload , next_group )
313+ 
314+         for  group_module  in  group .modules :
315+             _apply_group_offloading_hook (group_module , group , should_offload , next_group )
227316
228317    parameters  =  []
229318    for  name , parameter  in  module .named_parameters (recurse = False ):
@@ -241,50 +330,121 @@ def _apply_group_offloading_block_level(
241330        offload_device = offload_device ,
242331        onload_device = onload_device ,
243332        offload_leader = module ,
244-         onload_leader = None ,
333+         onload_leader = module ,
334+         parameters = parameters ,
335+         buffers = buffers ,
336+         non_blocking = False ,
337+         stream = None ,
338+         cpu_param_dict = None ,
339+         onload_self = True ,
340+     )
341+     _apply_group_offloading_hook (module , unmatched_group , force_offload , matched_module_groups [0 ])
342+ 
343+ 
def _apply_group_offloading_leaf_level(
    module: torch.nn.Module,
    offload_device: torch.device,
    onload_device: torch.device,
    force_offload: bool,
    non_blocking: bool,
    stream: Optional[torch.cuda.Stream] = None,
) -> None:
    r"""
    This function applies offloading to groups of leaf modules in a torch.nn.Module.

    Args:
        module (`torch.nn.Module`):
            The module to which group offloading is applied.
        offload_device (`torch.device`):
            The device to which the group of modules are offloaded. This should typically be the CPU.
        onload_device (`torch.device`):
            The device to which the group of modules are onloaded.
        force_offload (`bool`):
            If True, all module groups are offloaded to the offload_device. If False, only layers that match
            `offload_group_patterns` are offloaded to the offload_device.
        non_blocking (`bool`):
            If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
            and data transfer.
        stream (`torch.cuda.Stream`, *optional*):
            If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
            for overlapping computation and data transfer.
    """

    # With a stream, parameters are pinned in CPU memory so async H2D/D2H copies work.
    cpu_param_dict = None
    if stream is not None:
        for param in module.parameters():
            param.data = param.data.cpu().pin_memory()
        cpu_param_dict = {param: param.data for param in module.parameters()}

    # Every leaf module (one with no children) becomes its own offloading group.
    for submodule in module.modules():
        if list(submodule.children()):
            continue
        leaf_group = ModuleGroup(
            modules=[submodule],
            offload_device=offload_device,
            onload_device=onload_device,
            offload_leader=submodule,
            onload_leader=submodule,
            non_blocking=non_blocking,
            stream=stream,
            cpu_param_dict=cpu_param_dict,
            onload_self=True,
        )
        _apply_group_offloading_hook(submodule, leaf_group, True, None)

    # Parameters/buffers owned directly by non-leaf modules are not covered by any leaf
    # group above. `module.modules()` walks the tree in the same pre-order as the
    # recursive gather in the previous implementation, so collection order is unchanged.
    parameters = []
    buffers = []
    for submodule in module.modules():
        if not list(submodule.children()):
            continue
        parameters.extend(submodule.parameters(recurse=False))
        buffers.extend(submodule.buffers(recurse=False))

    # A catch-all group anchored at the root holds those stray parameters and buffers.
    unmatched_group = ModuleGroup(
        modules=[],
        offload_device=offload_device,
        onload_device=onload_device,
        offload_leader=module,
        onload_leader=module,
        parameters=parameters,
        buffers=buffers,
        non_blocking=False,
        stream=None,
        cpu_param_dict=cpu_param_dict,
        onload_self=True,
    )

    # Without a stream there is nothing to prefetch; with one, lazily trace the
    # execution order on the first forward pass to build the prefetch chain.
    if stream is None:
        _apply_group_offloading_hook(module, unmatched_group, force_offload, None)
    else:
        _apply_lazy_group_offloading_hook(module, unmatched_group, force_offload, None)
def _apply_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    offload_on_init: bool,
    next_group: Optional[ModuleGroup] = None,
) -> None:
    """Register a `GroupOffloadingHook` for `group` on `module` under the standard key."""
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(GroupOffloadingHook(group, offload_on_init, next_group), _GROUP_OFFLOADING)
439+ 
def _apply_lazy_group_offloading_hook(
    module: torch.nn.Module,
    group: ModuleGroup,
    offload_on_init: bool,
    next_group: Optional[ModuleGroup] = None,
) -> None:
    """
    Register both a `GroupOffloadingHook` and a `LazyPrefetchGroupOffloadingHook` on
    `module`, so the prefetch chain can be traced during the first forward pass.
    """
    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(GroupOffloadingHook(group, offload_on_init, next_group), _GROUP_OFFLOADING)
    registry.register_hook(LazyPrefetchGroupOffloadingHook(), _LAZY_PREFETCH_GROUP_OFFLOADING)