-using Reactant, Lux, Random, Statistics
-using Enzyme
-using Test
-
-# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
-noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
-truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
-
-# Define our model, a multi-layer perceptron with one hidden layer of size 3:
-model = Lux.Chain(
-    Lux.Dense(2 => 3, tanh), # activation function inside layer
-    Lux.BatchNorm(3, gelu),
-    Lux.Dense(3 => 2),
-    softmax,
-)
-ps, st = Lux.setup(Xoshiro(123), model)
-
-using BenchmarkTools
-
-origout, _ = model(noisy, ps, st)
-@btime model($noisy, $ps, $st) # 68.444 μs (46 allocations: 45.88 KiB)
-
-cmodel = Reactant.to_rarray(model)
-cps = Reactant.to_rarray(ps)
-cst = Reactant.to_rarray(st)
-cnoisy = Reactant.ConcreteRArray(noisy)
-
-f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))
-
-# # using InteractiveUtils
-# # @show @code_typed f(cmodel,cnoisy)
-# # @show @code_llvm f(cmodel,cnoisy)
-comp = f(cmodel, cnoisy, cps, cst)
-# @btime f($cmodel, $cnoisy, $cps, $cst) # 21.790 μs (6 allocations: 224 bytes)
-
-@test comp ≈ origout atol = 1e-5 rtol = 1e-2
-
-# To train the model, we use batches of 64 samples, and one-hot encoding:
-
-using MLUtils, OneHotArrays, Optimisers
-
-target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
-ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
-loader = DataLoader((noisy, target); batchsize=64, shuffle=true);
-# # 16-element DataLoader with first element: (2×64 Matrix{Float32}, 2×64 OneHotMatrix)
-
-opt = Optimisers.Adam(0.01f0)
-losses = []
-
-# Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
-# training loop manually:
-function crossentropy(ŷ, y)
-    logŷ = log.(ŷ)
-    result = y .* logŷ
-    # result = ifelse.(y .== 0.0f0, zero.(result), result)
-    return -sum(result)
-end
-
-function loss_function(model, x, y, ps, st)
-    y_hat, _ = model(x, ps, st)
-    return crossentropy(y_hat, y)
-end
-
-function gradient_loss_function(model, x, y, ps, st)
-    dps = Enzyme.make_zero(ps)
-    _, res = Enzyme.autodiff(
-        ReverseWithPrimal,
-        loss_function,
-        Active,
-        Const(model),
-        Const(x),
-        Const(y),
-        Duplicated(ps, dps),
-        Const(st),
+using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
+
+@testset "Lux.jl Integration" begin
+    # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
+    noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
+    truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
+
+    # Define our model, a multi-layer perceptron with one hidden layer of size 3:
+    model = Lux.Chain(
+        Lux.Dense(2 => 3, tanh), # activation function inside layer
+        Lux.BatchNorm(3, sigmoid),
+        Lux.Dense(3 => 2),
+        softmax,
+    )
+    ps, st = Lux.setup(Xoshiro(123), model)
+
+    origout, _ = model(noisy, ps, Lux.testmode(st))
+
+    cmodel = Reactant.to_rarray(model)
+    cps = Reactant.to_rarray(ps)
+    # cst carries the test-mode state for the inference check; cst2 keeps the
+    # training-mode state for the gradient test further down.
+    cst = Reactant.to_rarray(Lux.testmode(st))
+    cst2 = Reactant.to_rarray(st)
+    cnoisy = Reactant.ConcreteRArray(noisy)
+
+    f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))
+
+    comp = f(cmodel, cnoisy, cps, cst)
+
+    @test comp ≈ origout atol = 1e-5 rtol = 1e-2
+
+    target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
+
+    ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
+    # ctarget = Reactant.to_rarray(target)
+
+    # Lux.Experimental.TrainState is very specialized for Lux models, so we write out the
+    # training loop manually:
+    function crossentropy(ŷ, y)
+        logŷ = log.(ŷ)
+        result = y .* logŷ
+        # result = ifelse.(y .== 0.0f0, zero.(result), result)
+        return -sum(result)
+    end
+
+    function loss_function(model, x, y, ps, st)
+        y_hat, _ = model(x, ps, st)
+        # return CrossEntropyLoss()(y_hat, y)
+        return crossentropy(y_hat, y)
+    end
+
+    function gradient_loss_function(model, x, y, ps, st)
+        dps = Enzyme.make_zero(ps)
+        _, res = Enzyme.autodiff(
+            ReverseWithPrimal,
+            loss_function,
+            Active,
+            Const(model),
+            Const(x),
+            Const(y),
+            Duplicated(ps, dps),
+            Const(st),
+        )
+        return res, dps
+    end
+
+    res, dps = gradient_loss_function(model, noisy, target, ps, st)
+
+    compiled_gradient = Reactant.compile(
+        gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst2)
     )
-    return res, dps
-end
-
-gradient_loss_function(model, noisy, target, ps, st)
-
-compiled_gradient = @compile gradient_loss_function(cmodel, cnoisy, ctarget, cps, cst)
-
-@test length(compiled_gradient(cmodel, cnoisy, ctarget, cps, cst)) == 2
-
-# # Training loop, using the whole data set 1000 times:
-# losses = []
-# for epoch in 1:1_000
-#     for (x, y) in loader
-#         loss, grads = Flux.withgradient(model) do m
-#             # Evaluate model and loss inside gradient context:
-#             y_hat = m(x)
-#             return Flux.crossentropy(y_hat, y)
-#         end
-#         Flux.update!(optim, model, grads[1])
-#         push!(losses, loss) # logging, outside gradient context
-#     end
-# end
 
-# optim # parameters, momenta and output have all changed
-# out2 = model(noisy) # first row is prob. of true, second row p(false)
+    res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2)
 
-# mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far!
+    @test res ≈ res_reactant
+    for (dps1, dps2) in zip(fleaves(dps), fleaves(dps_reactant))
+        @test dps1 ≈ dps2
+    end
+end
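
The previous version of this file ended with a commented-out Flux-style training loop, which this revision drops in favour of the gradient check above. For reference, a minimal sketch of how `gradient_loss_function` could drive a training loop with Optimisers.jl; this assumes plain (non-Reactant) arrays, and the Adam step size and epoch count are illustrative, not part of the test:

    using Optimisers

    # Illustrative helper: reuses `gradient_loss_function` from the testset above;
    # everything else is passed in explicitly. Returns the updated parameters and
    # the per-epoch losses.
    function train(model, ps, st, x, y; nepochs=100)
        opt_state = Optimisers.setup(Optimisers.Adam(0.01f0), ps)
        losses = Float32[]
        for _ in 1:nepochs
            loss, dps = gradient_loss_function(model, x, y, ps, st)
            push!(losses, loss)
            opt_state, ps = Optimisers.update!(opt_state, ps, dps)
        end
        return ps, losses
    end

    # e.g. ps_trained, losses = train(model, ps, st, noisy, target)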