@@ -218,10 +218,10 @@ def post_forward(self, module, output):
218218 registries = [submodule ._diffusers_hook for _ , submodule in self .execution_order ]
219219
220220 for i in range (num_executed ):
221- registries [i ].remove_hook (_LAYER_EXECUTION_TRACKER )
221+ registries [i ].remove_hook (_LAYER_EXECUTION_TRACKER , recurse = False )
222222
223223 # Remove the current lazy prefetch group offloading hook so that it doesn't interfere with the next forward pass
224- base_module_registry .remove_hook (_LAZY_PREFETCH_GROUP_OFFLOADING )
224+ base_module_registry .remove_hook (_LAZY_PREFETCH_GROUP_OFFLOADING , recurse = False )
225225
226226 # Apply lazy prefetching by setting required attributes
227227 group_offloading_hooks = [registry .get_hook (_GROUP_OFFLOADING ) for registry in registries ]
@@ -536,7 +536,10 @@ def _apply_lazy_group_offloading_hook(
536536 hook = GroupOffloadingHook (group , offload_on_init , next_group )
537537 lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook ()
538538 registry = HookRegistry .check_if_exists_or_initialize (module )
539- registry .register_hook (hook , _GROUP_OFFLOADING )
539+ # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
540+ # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
541+ if registry .get_hook (_GROUP_OFFLOADING ) is None :
542+ registry .register_hook (hook , _GROUP_OFFLOADING )
540543 registry .register_hook (lazy_prefetch_hook , _LAZY_PREFETCH_GROUP_OFFLOADING )
541544
542545
0 commit comments