67
67
68
68
_ForwardPreHook = Union [
69
69
Callable [["Layer" , Tensor ], Tensor ], # (layer, input) -> transformed_input
70
- Callable [["Layer" , Tensor , dict [str , Any ]], tuple [Tensor , dict [str , Any ]]]
70
+ Callable [["Layer" , Tensor , dict [str , Any ]], tuple [Tensor , dict [str , Any ]]],
71
71
]
72
72
_ForwardPostHook = Union [
73
73
Callable [
74
74
["Layer" , Tensor , Tensor ], Tensor
75
75
], # (layer, input, output) -> transformed_output
76
- Callable [["Layer" , Tensor , dict [str , Any ], Tensor ], Tensor ]
76
+ Callable [["Layer" , Tensor , dict [str , Any ], Tensor ], Tensor ],
77
77
]
78
78
_StateDict = Union [dict [str , Tensor ], typing .OrderedDict [str , Tensor ]]
79
79
_StateDictHook = Callable [[_StateDict ], None ]
@@ -739,8 +739,8 @@ def register_forward_post_hook(
739
739
self ._forward_post_hooks ,
740
740
extra_hook_dict = [
741
741
self ._forward_post_hooks_with_kwargs_flag ,
742
- self ._forward_post_hooks_always_called
743
- ]
742
+ self ._forward_post_hooks_always_called ,
743
+ ],
744
744
)
745
745
self ._forward_post_hooks [hook_remove_helper ._hook_id ] = hook
746
746
if with_kwargs :
@@ -1625,7 +1625,9 @@ def inner():
1625
1625
called_always_called_hooks .add (hook_id )
1626
1626
1627
1627
if hook_id in self ._forward_post_hooks_with_kwargs_flag :
1628
- hook_result = forward_post_hook (self , inputs , kwargs , outputs )
1628
+ hook_result = forward_post_hook (
1629
+ self , inputs , kwargs , outputs
1630
+ )
1629
1631
else :
1630
1632
hook_result = forward_post_hook (self , inputs , outputs )
1631
1633
@@ -1639,20 +1641,25 @@ def inner():
1639
1641
except Exception :
1640
1642
for hook_id , forward_post_hook in self ._forward_post_hooks .items ():
1641
1643
if (
1642
- (hook_id in self ._forward_post_hooks_always_called )
1643
- and hook_id not in called_always_called_hooks
1644
- ):
1644
+ hook_id in self ._forward_post_hooks_always_called
1645
+ ) and hook_id not in called_always_called_hooks :
1645
1646
try :
1646
1647
if hook_id in self ._forward_post_hooks_with_kwargs_flag :
1647
- hook_result = forward_post_hook (self , inputs , kwargs , outputs )
1648
+ hook_result = forward_post_hook (
1649
+ self , inputs , kwargs , outputs
1650
+ )
1648
1651
else :
1649
- hook_result = forward_post_hook (self , inputs , outputs )
1652
+ hook_result = forward_post_hook (
1653
+ self , inputs , outputs
1654
+ )
1650
1655
1651
1656
if hook_result is not None :
1652
1657
outputs = hook_result
1653
1658
except Exception as e :
1654
- warnings .warn ("forward hook with ``always_call=True`` raised an exception "
1655
- f"that was silenced as another error was raised in forward: { str (e )} " )
1659
+ warnings .warn (
1660
+ "forward hook with ``always_call=True`` raised an exception "
1661
+ f"that was silenced as another error was raised in forward: { e !s} "
1662
+ )
1656
1663
continue
1657
1664
# raise exception raised in try block
1658
1665
raise
0 commit comments