Skip to content

Commit c90a0e2

Browse files
committed
update
1 parent c53a4ab commit c90a0e2

File tree

1 file changed

+1
-3
lines changed

1 file changed

+1
-3
lines changed

tests/hooks/test_hooks.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def new_forward(self, module, *args, **kwargs):
126126
logger.debug("SkipLayerHook new_forward")
127127
if self.skip_layer:
128128
return args[0]
129-
return self.fn_ref.overwritten_forward(*args, **kwargs)
129+
return self.fn_ref.original_forward(*args, **kwargs)
130130

131131
def post_forward(self, module, output):
132132
logger.debug("SkipLayerHook post_forward")
@@ -174,14 +174,12 @@ def test_hook_registry(self):
174174

175175
self.assertEqual(len(registry.hooks), 2)
176176
self.assertEqual(registry._hook_order, ["add_hook", "multiply_hook"])
177-
self.assertEqual(len(registry._fn_refs), 2)
178177
self.assertEqual(registry_repr, expected_repr)
179178

180179
registry.remove_hook("add_hook")
181180

182181
self.assertEqual(len(registry.hooks), 1)
183182
self.assertEqual(registry._hook_order, ["multiply_hook"])
184-
self.assertEqual(len(registry._fn_refs), 1)
185183

186184
def test_stateful_hook(self):
187185
registry = HookRegistry.check_if_exists_or_initialize(self.model)

0 commit comments

Comments
 (0)