@@ -181,6 +181,13 @@ def __init__(self):
181181 self ._layer_execution_tracker_module_names = set ()
182182
183183 def initialize_hook (self , module ):
184+ def make_execution_order_update_callback (current_name , current_submodule ):
185+ def callback ():
186+ logger .debug (f"Adding { current_name } to the execution order" )
187+ self .execution_order .append ((current_name , current_submodule ))
188+
189+ return callback
190+
184191 # To every submodule that contains a group offloading hook (at this point, no prefetching is enabled for any
185192 # of the groups), we add a layer execution tracker hook that will be used to determine the order in which the
186193 # layers are executed during the forward pass.
@@ -192,14 +199,8 @@ def initialize_hook(self, module):
192199 group_offloading_hook = registry .get_hook (_GROUP_OFFLOADING )
193200
194201 if group_offloading_hook is not None :
195-
196- def make_execution_order_update_callback (current_name , current_submodule ):
197- def callback ():
198- logger .debug (f"Adding { current_name } to the execution order" )
199- self .execution_order .append ((current_name , current_submodule ))
200-
201- return callback
202-
202+ # For the first forward pass, we have to load in a blocking manner
203+ group_offloading_hook .group .non_blocking = False
203204 layer_tracker_hook = LayerExecutionTrackerHook (make_execution_order_update_callback (name , submodule ))
204205 registry .register_hook (layer_tracker_hook , _LAYER_EXECUTION_TRACKER )
205206 self ._layer_execution_tracker_module_names .add (name )
@@ -229,15 +230,21 @@ def post_forward(self, module, output):
229230 # Remove the layer execution tracker hooks from the submodules
230231 base_module_registry = module ._diffusers_hook
231232 registries = [submodule ._diffusers_hook for _ , submodule in self .execution_order ]
233+ group_offloading_hooks = [registry .get_hook (_GROUP_OFFLOADING ) for registry in registries ]
232234
233235 for i in range (num_executed ):
234236 registries [i ].remove_hook (_LAYER_EXECUTION_TRACKER , recurse = False )
235237
236238 # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
237239 base_module_registry .remove_hook (_LAZY_PREFETCH_GROUP_OFFLOADING , recurse = False )
238240
239- # Apply lazy prefetching by setting required attributes
240- group_offloading_hooks = [registry .get_hook (_GROUP_OFFLOADING ) for registry in registries ]
241+ # LazyPrefetchGroupOffloadingHook is only used with streams, so we know that non_blocking should be True.
242+ # We disable non_blocking for the first forward pass, but need to enable it for the subsequent passes to
243+ # see the benefits of prefetching.
244+ for hook in group_offloading_hooks :
245+ hook .group .non_blocking = True
246+
247+ # Set required attributes for prefetching
241248 if num_executed > 0 :
242249 base_module_group_offloading_hook = base_module_registry .get_hook (_GROUP_OFFLOADING )
243250 base_module_group_offloading_hook .next_group = group_offloading_hooks [0 ].group
0 commit comments