|
1 | 1 | using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
|
2 | 2 |
|
| 3 | +# Lux.Experimental.TrainState is very specialized for Lux models, so we write out the |
| 4 | +# training loop manually: |
# Cross-entropy between a predicted distribution `ŷ` and a target `y`:
# the negated sum of y .* log(ŷ) over every entry (no normalization).
# NOTE(review): assumes all entries of `ŷ` are strictly positive; log(0)
# would produce -Inf — confirm the model output guarantees this.
function crossentropy(ŷ, y)
    return -sum(y .* log.(ŷ))
end
| 10 | + |
# Run the forward pass of `model` on input `x` with parameters `ps` and
# state `st`, discard the updated state, and score the prediction against
# the target `y` via `crossentropy`.
function loss_function(model, x, y, ps, st)
    prediction, _ = model(x, ps, st)
    return crossentropy(prediction, y)
end
| 16 | + |
# Reverse-mode gradient of `loss_function` with respect to the parameters.
# Returns `(res, grads)`: the primal loss value (as reported by
# `ReverseWithPrimal`) and the accumulated parameter gradient.
function gradient_loss_function(model, x, y, ps, st)
    # Gradient accumulator shaped like `ps`; Enzyme writes into it in place.
    grads = Enzyme.make_zero(ps)
    # Only `ps` is differentiated (Duplicated); model, data, and state are
    # held constant. `Active` marks the scalar loss as the active return.
    _, primal = Enzyme.autodiff(
        ReverseWithPrimal,
        loss_function,
        Active,
        Const(model),
        Const(x),
        Const(y),
        Duplicated(ps, grads),
        Const(st),
    )
    return primal, grads
end
| 31 | + |
3 | 32 | @testset "Lux.jl Integration" begin
|
4 | 33 | # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
|
5 | 34 | noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
|
@@ -33,36 +62,6 @@ using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
|
33 | 62 | ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
|
34 | 63 | # ctarget = Reactant.to_rarray(target)
|
35 | 64 |
|
36 |
| - # Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the |
37 |
| - # training loop manually: |
38 |
| - function crossentropy(ŷ, y) |
39 |
| - logŷ = log.(ŷ) |
40 |
| - result = y .* logŷ |
41 |
| - # result = ifelse.(y .== 0.0f0, zero.(result), result) |
42 |
| - return -sum(result) |
43 |
| - end |
44 |
| - |
45 |
| - function loss_function(model, x, y, ps, st) |
46 |
| - y_hat, _ = model(x, ps, st) |
47 |
| - # return CrossEntropyLoss()(y_hat, y) |
48 |
| - return crossentropy(y_hat, y) |
49 |
| - end |
50 |
| - |
51 |
| - function gradient_loss_function(model, x, y, ps, st) |
52 |
| - dps = Enzyme.make_zero(ps) |
53 |
| - _, res = Enzyme.autodiff( |
54 |
| - ReverseWithPrimal, |
55 |
| - loss_function, |
56 |
| - Active, |
57 |
| - Const(model), |
58 |
| - Const(x), |
59 |
| - Const(y), |
60 |
| - Duplicated(ps, dps), |
61 |
| - Const(st), |
62 |
| - ) |
63 |
| - return res, dps |
64 |
| - end |
65 |
| - |
66 | 65 | res, dps = gradient_loss_function(model, noisy, target, ps, st)
|
67 | 66 |
|
68 | 67 | compiled_gradient = Reactant.compile(
|
|
0 commit comments