Skip to content

Commit 831de55

Browse files
committed
Proper bidirectional implementation
1 parent 1ccbeba commit 831de55

File tree

4 files changed: +56 -9 lines changed

lib/axon.ex

Lines changed: 19 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2357,17 +2357,31 @@ defmodule Axon do
23572357
"""
23582358
def bidirectional(%Axon{} = input, forward_fun, merge_fun, opts \\ [])
23592359
when is_function(forward_fun, 1) and is_function(merge_fun, 2) do
2360-
opts = Keyword.validate!(opts, axis: 1)
2360+
opts = Keyword.validate!(opts, [:name, axis: 1])
23612361

2362-
forward_out = forward_fun.(input)
2362+
fun =
2363+
Axon.block(
2364+
fn x ->
2365+
Axon.container(forward_fun.(x))
2366+
end,
2367+
name: opts[:name]
2368+
)
2369+
2370+
forward_out = fun.(input)
23632371

23642372
backward_out =
23652373
input
23662374
|> Axon.nx(&Nx.reverse(&1, axes: [opts[:axis]]))
2367-
|> forward_fun.()
2368-
|> deep_new(&Axon.nx(&1, fn x -> Nx.reverse(x, axes: [opts[:axis]]) end))
2375+
|> fun.()
2376+
|> Axon.nx(fn x ->
2377+
deep_new(x, &Nx.reverse(&1, axes: [opts[:axis]]))
2378+
end)
23692379

2370-
deep_merge(forward_out, backward_out, merge_fun)
2380+
{forward_out, backward_out}
2381+
|> Axon.container()
2382+
|> Axon.nx(fn {forward, backward} ->
2383+
deep_merge(forward, backward, merge_fun)
2384+
end)
23712385
end
23722386

23732387
@doc """

lib/axon/model_state.ex

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ defmodule Axon.ModelState do
2323
} = model_state,
2424
updated_parameters,
2525
updated_state \\ %{}
26-
) do
26+
) do
2727
updated_state =
2828
state
2929
|> tree_diff(frozen)
@@ -215,7 +215,7 @@ defmodule Axon.ModelState do
215215
Enum.reduce(access, %{}, &Map.put(&2, &1, Map.fetch!(data, &1)))
216216
end
217217

218-
defp tree_get(data, access) when is_map(access) do
218+
defp tree_get(data, access) when is_map(access) do
219219
Enum.reduce(access, %{}, fn {key, value}, acc ->
220220
tree = tree_get(data[key], value)
221221
Map.put(acc, key, tree)

test/axon/compiler_test.exs

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5650,4 +5650,33 @@ defmodule CompilerTest do
56505650
predict_fn.(params, input)
56515651
end
56525652
end
5653+
5654+
describe "bidirectional" do
5655+
test "works properly with LSTMs" do
5656+
input = Axon.input("input")
5657+
5658+
model =
5659+
input
5660+
|> Axon.embedding(10, 16)
5661+
|> Axon.bidirectional(
5662+
&Axon.lstm(&1, 32, name: "lstm"),
5663+
&Nx.concatenate([&1, &2], axis: 1),
5664+
name: "bidirectional"
5665+
)
5666+
|> Axon.nx(&elem(&1, 0))
5667+
5668+
{init_fn, predict_fn} = Axon.build(model)
5669+
5670+
input = Nx.broadcast(1, {1, 10})
5671+
5672+
assert %ModelState{
5673+
data: %{
5674+
"bidirectional" => %{"lstm" => _}
5675+
}
5676+
} = params = init_fn.(input, ModelState.empty())
5677+
5678+
out = predict_fn.(params, input)
5679+
assert Nx.shape(out) == {1, 20, 32}
5680+
end
5681+
end
56535682
end

test/axon/integration_test.exs

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -485,7 +485,9 @@ defmodule Axon.IntegrationTest do
485485
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
486486
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
487487
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
488-
assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params
488+
489+
assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
490+
unquote(Macro.escape(policy)).params
489491
end)
490492
end
491493

@@ -536,7 +538,9 @@ defmodule Axon.IntegrationTest do
536538
assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.60)
537539
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
538540
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
539-
assert Nx.type(model_state.data["dense_0"]["kernel"]) == unquote(Macro.escape(policy)).params
541+
542+
assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
543+
unquote(Macro.escape(policy)).params
540544
end)
541545
end
542546
end

0 commit comments

Comments
 (0)