Skip to content

Commit 216fafe

Browse files
authored
Add layer name to hook (#536)
1 parent b93e87f commit 216fafe

File tree

2 files changed

+31
-11
lines changed

2 files changed

+31
-11
lines changed

lib/axon/compiler.ex

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,8 @@ defmodule Axon.Compiler do
501501

502502
res =
503503
value
504-
|> apply_hooks(:forward, mode, hooks)
505-
|> apply_hooks(:backward, mode, hooks)
504+
|> apply_hooks(name, :forward, mode, hooks)
505+
|> apply_hooks(name, :backward, mode, hooks)
506506
|> maybe_print_values(name, print_values)
507507

508508
{res, {state, result_cache}}
@@ -975,7 +975,7 @@ defmodule Axon.Compiler do
975975
layer_input =
976976
layer_input
977977
|> safe_policy_cast(policy, :compute)
978-
|> apply_hooks(:pre_forward, mode, hooks)
978+
|> apply_hooks(name, :pre_forward, mode, hooks)
979979

980980
{layer_input, {state, result_cache, none?}}
981981
end
@@ -1051,8 +1051,8 @@ defmodule Axon.Compiler do
10511051
%StatefulOutput{output: out, state: out_state} ->
10521052
new_out =
10531053
out
1054-
|> apply_hooks(:forward, mode, hooks)
1055-
|> apply_hooks(:backward, mode, hooks)
1054+
|> apply_hooks(name, :forward, mode, hooks)
1055+
|> apply_hooks(name, :backward, mode, hooks)
10561056
|> safe_policy_cast(policy, :output)
10571057

10581058
new_state = Map.put(state, name, out_state)
@@ -1061,8 +1061,8 @@ defmodule Axon.Compiler do
10611061
out ->
10621062
new_out =
10631063
out
1064-
|> apply_hooks(:forward, mode, hooks)
1065-
|> apply_hooks(:backward, mode, hooks)
1064+
|> apply_hooks(name, :forward, mode, hooks)
1065+
|> apply_hooks(name, :backward, mode, hooks)
10661066
|> safe_policy_cast(policy, :output)
10671067

10681068
{new_out, state}
@@ -1169,7 +1169,7 @@ defmodule Axon.Compiler do
11691169
init_param(layer_id, param, layer_params, parent_templates, dtype, keys)
11701170
end)
11711171

1172-
layer_params = apply_hooks(layer_params, :initialize, nil, hooks)
1172+
layer_params = apply_hooks(layer_params, name, :initialize, nil, hooks)
11731173

11741174
params =
11751175
if layer_params == %{} do
@@ -1228,7 +1228,7 @@ defmodule Axon.Compiler do
12281228

12291229
defp maybe_print_values(value, _, _), do: value
12301230

1231-
defp apply_hooks(res, event, mode, hooks) do
1231+
defp apply_hooks(res, layer_name, event, mode, hooks) do
12321232
hooks
12331233
|> Enum.reverse()
12341234
|> Enum.reduce(res, fn {on_event, on_mode, hook_fn}, expr ->
@@ -1238,11 +1238,11 @@ defmodule Axon.Compiler do
12381238
if event? and mode? do
12391239
if on_event == :backward do
12401240
Nx.Defn.Kernel.custom_grad(expr, [expr], fn g ->
1241-
hooked_g = Nx.Defn.Kernel.hook(g, hook_fn)
1241+
hooked_g = Nx.Defn.Kernel.hook(g, String.to_atom(layer_name), hook_fn)
12421242
[hooked_g]
12431243
end)
12441244
else
1245-
Nx.Defn.Kernel.hook(expr, hook_fn)
1245+
Nx.Defn.Kernel.hook(expr, String.to_atom(layer_name), hook_fn)
12461246
end
12471247
else
12481248
expr

test/axon/compiler_test.exs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4726,6 +4726,26 @@ defmodule CompilerTest do
47264726
assert_receive {%Nx.Tensor{}, :from_relu}
47274727
assert_receive {%Nx.Tensor{}, :from_sigmoid}
47284728
end
4729+
4730+
test "can be overriden at jit-time with layer name", config do
4731+
model =
4732+
Axon.input("input_0", shape: {nil, 1})
4733+
|> Axon.attach_hook(fn x -> send(config.test, {x, :from_input}) end, on: :forward)
4734+
|> Axon.relu()
4735+
4736+
inp = Nx.tensor([[1.0]])
4737+
{_, predict_fn} = Axon.build(model)
4738+
4739+
hook = fn val -> send(config.test, {val, :overridden}) end
4740+
4741+
fun = Nx.Defn.jit(predict_fn, hooks: %{input_0: hook})
4742+
apply(fun, [ModelState.empty(), inp])
4743+
4744+
assert_receive {from_inp, :overridden}
4745+
refute_receive {_, :from_input}
4746+
4747+
assert_equal(from_inp, inp)
4748+
end
47294749
end
47304750

47314751
describe "integrated models" do

0 commit comments

Comments
 (0)