Skip to content

Commit e6f38ee

Browse files
committed
add unit test for register_forward_pre_hook & register_forward_post_hook
1 parent 278a848 commit e6f38ee

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

test/legacy_test/test_imperative_hook_for_layer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,20 @@ def throw_hook(m, i, o):
304304
self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1])
305305

306306

307+
# make sure that always called forward hooks are properly removed
308+
forward_post_hook_handle.remove()
309+
forward_post_hook_handle2.remove()
310+
self.assertTrue(len(net._forward_post_hooks_always_called) == 0)
311+
312+
# make sure that always called forward hook is not run twice if it fails while running
313+
forward_post_hook_handle3 = net.register_forward_post_hook(
314+
ctx_shutdown_failure_hook, always_call=True
315+
)
316+
with self.assertRaisesRegex(RuntimeError, "failing in ctx setup"):
317+
net(x, fail=False)
318+
self.assertEqual(stack, [2, -1, 2, -1, 2, -1, 2, -1, 2, -1])
319+
320+
307321
class TestHookWithKWArgs(unittest.TestCase):
308322
def test_kwargs_hook(self):
309323
x = paddle.randn((2, 3))

0 commit comments

Comments
 (0)