@@ -5,6 +5,45 @@ defmodule Axon.IntegrationTest do
5
5
6
6
@ moduletag :integration
7
7
8
+ test "bce with simple xor model" do
9
+ x1_input = Axon . input ( "x1" , shape: { nil , 1 } )
10
+ x2_input = Axon . input ( "x2" , shape: { nil , 1 } )
11
+
12
+ model =
13
+ x1_input
14
+ |> Axon . concatenate ( x2_input )
15
+ |> Axon . dense ( 8 , activation: :tanh )
16
+ |> Axon . dense ( 1 , activation: :sigmoid )
17
+
18
+ batch_size = 32
19
+
20
+ data =
21
+ Stream . unfold ( Nx.Random . key ( 42 ) , fn key ->
22
+ { x1 , key } = Nx.Random . uniform ( key , 0 , 1 , shape: { batch_size , 1 } )
23
+ { x2 , key } = Nx.Random . uniform ( key , 0 , 1 , shape: { batch_size , 1 } )
24
+
25
+ { x1 , x2 } = { Nx . round ( x1 ) , Nx . round ( x2 ) }
26
+ y = Nx . logical_xor ( x1 , x2 )
27
+
28
+ { { % { "x1" => x1 , "x2" => x2 } , y } , key }
29
+ end )
30
+
31
+ ExUnit.CaptureIO . capture_io ( fn ->
32
+ model_state =
33
+ model
34
+ |> Axon.Loop . trainer ( :binary_cross_entropy , :sgd )
35
+ |> Axon.Loop . run ( data , Axon.ModelState . empty ( ) , iterations: 100 , epochs: 10 )
36
+
37
+ eval_results =
38
+ model
39
+ |> Axon.Loop . evaluator ( )
40
+ |> Axon.Loop . metric ( :accuracy )
41
+ |> Axon.Loop . run ( data , model_state , iterations: 100 )
42
+
43
+ assert_greater_equal ( get_in ( eval_results , [ 0 , "accuracy" ] ) , 0.9 )
44
+ end )
45
+ end
46
+
8
47
test "vector classification test" do
9
48
{ train , _test } = get_test_data ( 100 , 0 , 10 , { 10 } , 2 , 1337 )
10
49
0 commit comments