
Commit d5c5cf1

chore: apply formatting suggestion
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent: ae27041

2 files changed: +29 -31 lines changed


ext/ReactantNNlibExt.jl

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@ function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
     return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
 end
 
-
 # TODO handle non finite cases
 function NNlib.softmax!(
     out::Reactant.TracedRArray{T,N}, x::AbstractArray; dims=1
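
For reference, the traced gelu overload in this hunk has the shape of NNlib's tanh-free gelu approximation, x * sigmoid(λλ * x * (1 + α * x²)). The plain-Julia check below is illustrative only (not part of the commit); the constant values for α and λλ are assumptions chosen to match that formula.

using NNlib

x = 0.7f0
α = 0.044715f0              # assumed gelu coefficient α
λλ = Float32(sqrt(8 / π))   # assumed λλ = √(8/π)
manual = x * NNlib.sigmoid(λλ * x * muladd(x^2, α, one(x)))
# `manual` should agree with NNlib.gelu(x) for this tanh-approximation form of gelu.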

test/nn/lux.jl

Lines changed: 29 additions & 30 deletions
@@ -1,5 +1,34 @@
 using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
 
+# Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
+# training loop manually:
+function crossentropy(ŷ, y)
+    logŷ = log.(ŷ)
+    result = y .* logŷ
+    return -sum(result)
+end
+
+function loss_function(model, x, y, ps, st)
+    y_hat, _ = model(x, ps, st)
+    # return CrossEntropyLoss()(y_hat, y)
+    return crossentropy(y_hat, y)
+end
+
+function gradient_loss_function(model, x, y, ps, st)
+    dps = Enzyme.make_zero(ps)
+    _, res = Enzyme.autodiff(
+        ReverseWithPrimal,
+        loss_function,
+        Active,
+        Const(model),
+        Const(x),
+        Const(y),
+        Duplicated(ps, dps),
+        Const(st),
+    )
+    return res, dps
+end
+
 @testset "Lux.jl Integration" begin
     # Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
     noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
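
The `gradient_loss_function` helper added above follows Enzyme's ReverseWithPrimal calling convention: the primal loss value is returned alongside the derivatives, and the gradient is accumulated in place into the buffer wrapped by `Duplicated`. The standalone sketch below illustrates that same pattern on a toy function; `square_sum`, `w`, and `dw` are illustrative names, not from the commit.

using Enzyme

square_sum(w, x) = sum(abs2, w .* x)

w = rand(Float32, 3)
dw = Enzyme.make_zero(w)    # zeroed gradient buffer, filled in place by autodiff
x = rand(Float32, 3)

_, primal = Enzyme.autodiff(
    ReverseWithPrimal,      # return the primal value together with the derivatives
    square_sum,
    Active,                 # activity of the (scalar) return value
    Duplicated(w, dw),      # differentiate with respect to w; the gradient lands in dw
    Const(x),               # x is held constant
)
# `primal` equals square_sum(w, x); `dw` now holds ∂square_sum/∂w.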
@@ -33,36 +62,6 @@ using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
     ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
     # ctarget = Reactant.to_rarray(target)
 
-    # Lux.Exprimental.TrainState is very specialized for Lux models, so we write out the
-    # training loop manually:
-    function crossentropy(ŷ, y)
-        logŷ = log.(ŷ)
-        result = y .* logŷ
-        # result = ifelse.(y .== 0.0f0, zero.(result), result)
-        return -sum(result)
-    end
-
-    function loss_function(model, x, y, ps, st)
-        y_hat, _ = model(x, ps, st)
-        # return CrossEntropyLoss()(y_hat, y)
-        return crossentropy(y_hat, y)
-    end
-
-    function gradient_loss_function(model, x, y, ps, st)
-        dps = Enzyme.make_zero(ps)
-        _, res = Enzyme.autodiff(
-            ReverseWithPrimal,
-            loss_function,
-            Active,
-            Const(model),
-            Const(x),
-            Const(y),
-            Duplicated(ps, dps),
-            Const(st),
-        )
-        return res, dps
-    end
-
     res, dps = gradient_loss_function(model, noisy, target, ps, st)
 
     compiled_gradient = Reactant.compile(
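
The tail of this hunk is where the test computes the eager Enzyme gradient and then compiles `gradient_loss_function` with Reactant. A hedged sketch of that flow is below; the model layout, RNG seed, parameter conversion via `Functors.fmap`, and the exact `Reactant.compile(f, args_tuple)` call form are assumptions, not taken verbatim from the commit.

using Reactant, Lux, Random, Enzyme, Functors, NNlib

model = Lux.Chain(Lux.Dense(2 => 8, tanh), Lux.Dense(8 => 2), NNlib.softmax)  # assumed toy model
ps, st = Lux.setup(Random.Xoshiro(0), model)

x = rand(Float32, 2, 32)
y = rand(Float32, 2, 32)

# Eager reference values, as the test computes before compiling:
res, dps = gradient_loss_function(model, x, y, ps, st)

# Move inputs and parameters onto Reactant arrays (conversion approach is assumed):
cx = Reactant.ConcreteRArray(x)
cy = Reactant.ConcreteRArray(y)
cps = Functors.fmap(Reactant.ConcreteRArray, ps)

# Compile once, then call the compiled gradient on the concrete arrays:
compiled_gradient = Reactant.compile(gradient_loss_function, (model, cx, cy, cps, st))
cres, cdps = compiled_gradient(model, cx, cy, cps, st)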
