Skip to content

Commit 19dabfa

Browse files
committed
[API Compatibility]Add prepend for register_forward_pre_hook, add prepend with_kwargs always_call for register_forward_post_hook
1 parent 0ff7d40 commit 19dabfa

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

python/paddle/nn/layer/layers.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,13 @@
6767

6868
_ForwardPreHook = Union[
6969
Callable[["Layer", Tensor], Tensor], # (layer, input) -> transformed_input
70-
Callable[["Layer", Tensor, dict[str, Any]], tuple[Tensor, dict[str, Any]]]
70+
Callable[["Layer", Tensor, dict[str, Any]], tuple[Tensor, dict[str, Any]]],
7171
]
7272
_ForwardPostHook = Union[
7373
Callable[
7474
["Layer", Tensor, Tensor], Tensor
7575
], # (layer, input, output) -> transformed_output
76-
Callable[["Layer", Tensor, dict[str, Any], Tensor], Tensor]
76+
Callable[["Layer", Tensor, dict[str, Any], Tensor], Tensor],
7777
]
7878
_StateDict = Union[dict[str, Tensor], typing.OrderedDict[str, Tensor]]
7979
_StateDictHook = Callable[[_StateDict], None]
@@ -739,8 +739,8 @@ def register_forward_post_hook(
739739
self._forward_post_hooks,
740740
extra_hook_dict=[
741741
self._forward_post_hooks_with_kwargs_flag,
742-
self._forward_post_hooks_always_called
743-
]
742+
self._forward_post_hooks_always_called,
743+
],
744744
)
745745
self._forward_post_hooks[hook_remove_helper._hook_id] = hook
746746
if with_kwargs:
@@ -1625,7 +1625,9 @@ def inner():
16251625
called_always_called_hooks.add(hook_id)
16261626

16271627
if hook_id in self._forward_post_hooks_with_kwargs_flag:
1628-
hook_result = forward_post_hook(self, inputs, kwargs, outputs)
1628+
hook_result = forward_post_hook(
1629+
self, inputs, kwargs, outputs
1630+
)
16291631
else:
16301632
hook_result = forward_post_hook(self, inputs, outputs)
16311633

@@ -1639,20 +1641,25 @@ def inner():
16391641
except Exception:
16401642
for hook_id, forward_post_hook in self._forward_post_hooks.items():
16411643
if (
1642-
(hook_id in self._forward_post_hooks_always_called)
1643-
and hook_id not in called_always_called_hooks
1644-
):
1644+
hook_id in self._forward_post_hooks_always_called
1645+
) and hook_id not in called_always_called_hooks:
16451646
try:
16461647
if hook_id in self._forward_post_hooks_with_kwargs_flag:
1647-
hook_result = forward_post_hook(self, inputs, kwargs, outputs)
1648+
hook_result = forward_post_hook(
1649+
self, inputs, kwargs, outputs
1650+
)
16481651
else:
1649-
hook_result = forward_post_hook(self, inputs, outputs)
1652+
hook_result = forward_post_hook(
1653+
self, inputs, outputs
1654+
)
16501655

16511656
if hook_result is not None:
16521657
outputs = hook_result
16531658
except Exception as e:
1654-
warnings.warn("forward hook with ``always_call=True`` raised an exception "
1655-
f"that was silenced as another error was raised in forward: {str(e)}")
1659+
warnings.warn(
1660+
"forward hook with ``always_call=True`` raised an exception "
1661+
f"that was silenced as another error was raised in forward: {e!s}"
1662+
)
16561663
continue
16571664
# raise exception raised in try block
16581665
raise

0 commit comments

Comments
 (0)