Skip to content

Commit 5f9e7bc

Browse files
committed
Add BCE integration test
1 parent 603818f commit 5f9e7bc

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

test/axon/integration_test.exs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,45 @@ defmodule Axon.IntegrationTest do
55

66
@moduletag :integration
77

8+
test "bce with simple xor model" do
9+
x1_input = Axon.input("x1", shape: {nil, 1})
10+
x2_input = Axon.input("x2", shape: {nil, 1})
11+
12+
model =
13+
x1_input
14+
|> Axon.concatenate(x2_input)
15+
|> Axon.dense(8, activation: :tanh)
16+
|> Axon.dense(1, activation: :sigmoid)
17+
18+
batch_size = 32
19+
20+
data =
21+
Stream.unfold(Nx.Random.key(42), fn key ->
22+
{x1, key} = Nx.Random.uniform(key, 0, 1, shape: {batch_size, 1})
23+
{x2, key} = Nx.Random.uniform(key, 0, 1, shape: {batch_size, 1})
24+
25+
{x1, x2} = {Nx.round(x1), Nx.round(x2)}
26+
y = Nx.logical_xor(x1, x2)
27+
28+
{{%{"x1" => x1, "x2" => x2}, y}, key}
29+
end)
30+
31+
ExUnit.CaptureIO.capture_io(fn ->
32+
model_state =
33+
model
34+
|> Axon.Loop.trainer(:binary_cross_entropy, :sgd)
35+
|> Axon.Loop.run(data, Axon.ModelState.empty(), iterations: 100, epochs: 10)
36+
37+
eval_results =
38+
model
39+
|> Axon.Loop.evaluator()
40+
|> Axon.Loop.metric(:accuracy)
41+
|> Axon.Loop.run(data, model_state, iterations: 100)
42+
43+
assert_greater_equal(get_in(eval_results, [0, "accuracy"]), 0.9)
44+
end)
45+
end
46+
847
test "vector classification test" do
948
{train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337)
1049

0 commit comments

Comments
 (0)