@@ -450,6 +450,31 @@ def test_enable_disable_hook(self):
450450        self .assertNotEqual (output1 , output2 )
451451        self .assertEqual (output1 , output3 )
452452
453+     def  test_enable_disable_hook_containing_new_forward (self ):
454+         registry  =  HookRegistry .check_if_exists_or_initialize (self .model )
455+         registry .register_hook (AddHook (1 ), "add_hook" )
456+         for  block  in  self .model .blocks :
457+             block_registry  =  HookRegistry .check_if_exists_or_initialize (block )
458+             block_registry .register_hook (SkipLayerHook (skip_layer = True ), "skip_layer_hook" )
459+         registry .register_hook (MultiplyHook (2 ), "multiply_hook" )
460+ 
461+         input  =  torch .randn (1 , 4 , device = torch_device , generator = self .get_generator ())
462+         output1  =  self .model (input ).mean ().detach ().cpu ().item ()
463+ 
464+         self .model ._disable_hook ("skip_layer_hook" )
465+         output2  =  self .model (input ).mean ().detach ().cpu ().item ()
466+ 
467+         self .model ._disable_hook ("add_hook" )
468+         output3  =  self .model (input ).mean ().detach ().cpu ().item ()
469+ 
470+         self .model ._enable_hook ("skip_layer_hook" )
471+         self .model ._enable_hook ("add_hook" )
472+         output4  =  self .model (input ).mean ().detach ().cpu ().item ()
473+ 
474+         self .assertNotEqual (output1 , output2 )
475+         self .assertNotEqual (output2 , output3 )
476+         self .assertEqual (output1 , output4 )
477+ 
453478    def  test_remove_all_hooks (self ):
454479        registry  =  HookRegistry .check_if_exists_or_initialize (self .model )
455480        registry .register_hook (AddHook (1 ), "add_hook" )
0 commit comments