@@ -501,8 +501,8 @@ defmodule Axon.Compiler do
501
501
502
502
res =
503
503
value
504
- |> apply_hooks ( :forward , mode , hooks )
505
- |> apply_hooks ( :backward , mode , hooks )
504
+ |> apply_hooks ( name , :forward , mode , hooks )
505
+ |> apply_hooks ( name , :backward , mode , hooks )
506
506
|> maybe_print_values ( name , print_values )
507
507
508
508
{ res , { state , result_cache } }
@@ -975,7 +975,7 @@ defmodule Axon.Compiler do
975
975
layer_input =
976
976
layer_input
977
977
|> safe_policy_cast ( policy , :compute )
978
- |> apply_hooks ( :pre_forward , mode , hooks )
978
+ |> apply_hooks ( name , :pre_forward , mode , hooks )
979
979
980
980
{ layer_input , { state , result_cache , none? } }
981
981
end
@@ -1051,8 +1051,8 @@ defmodule Axon.Compiler do
1051
1051
% StatefulOutput { output: out , state: out_state } ->
1052
1052
new_out =
1053
1053
out
1054
- |> apply_hooks ( :forward , mode , hooks )
1055
- |> apply_hooks ( :backward , mode , hooks )
1054
+ |> apply_hooks ( name , :forward , mode , hooks )
1055
+ |> apply_hooks ( name , :backward , mode , hooks )
1056
1056
|> safe_policy_cast ( policy , :output )
1057
1057
1058
1058
new_state = Map . put ( state , name , out_state )
@@ -1061,8 +1061,8 @@ defmodule Axon.Compiler do
1061
1061
out ->
1062
1062
new_out =
1063
1063
out
1064
- |> apply_hooks ( :forward , mode , hooks )
1065
- |> apply_hooks ( :backward , mode , hooks )
1064
+ |> apply_hooks ( name , :forward , mode , hooks )
1065
+ |> apply_hooks ( name , :backward , mode , hooks )
1066
1066
|> safe_policy_cast ( policy , :output )
1067
1067
1068
1068
{ new_out , state }
@@ -1169,7 +1169,7 @@ defmodule Axon.Compiler do
1169
1169
init_param ( layer_id , param , layer_params , parent_templates , dtype , keys )
1170
1170
end )
1171
1171
1172
- layer_params = apply_hooks ( layer_params , :initialize , nil , hooks )
1172
+ layer_params = apply_hooks ( layer_params , name , :initialize , nil , hooks )
1173
1173
1174
1174
params =
1175
1175
if layer_params == % { } do
@@ -1228,7 +1228,7 @@ defmodule Axon.Compiler do
1228
1228
1229
1229
defp maybe_print_values ( value , _ , _ ) , do: value
1230
1230
1231
- defp apply_hooks ( res , event , mode , hooks ) do
1231
+ defp apply_hooks ( res , layer_name , event , mode , hooks ) do
1232
1232
hooks
1233
1233
|> Enum . reverse ( )
1234
1234
|> Enum . reduce ( res , fn { on_event , on_mode , hook_fn } , expr ->
@@ -1238,11 +1238,11 @@ defmodule Axon.Compiler do
1238
1238
if event? and mode? do
1239
1239
if on_event == :backward do
1240
1240
Nx.Defn.Kernel . custom_grad ( expr , [ expr ] , fn g ->
1241
- hooked_g = Nx.Defn.Kernel . hook ( g , hook_fn )
1241
+ hooked_g = Nx.Defn.Kernel . hook ( g , String . to_atom ( layer_name ) , hook_fn )
1242
1242
[ hooked_g ]
1243
1243
end )
1244
1244
else
1245
- Nx.Defn.Kernel . hook ( expr , hook_fn )
1245
+ Nx.Defn.Kernel . hook ( expr , String . to_atom ( layer_name ) , hook_fn )
1246
1246
end
1247
1247
else
1248
1248
expr
0 commit comments