Skip to content

Commit a89ba0c

Browse files
authored
Merge pull request #110 from EnzymeAD/ap/nn_tests
tests: more comprehensive NN testing
2 parents ea88cb0 + 61b0c6d commit a89ba0c

File tree

14 files changed

+452
-273
lines changed

14 files changed

+452
-273
lines changed

.github/workflows/benchmark_pr.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ jobs:
3636
echo $PATH
3737
ls -l ~/.julia/bin
3838
mkdir results
39-
benchpkg ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.repository.default_branch}}" --output-dir=results/ -s="benchmark/benchmarks.jl" --tune --add="Enzyme,Lux,Boltz,Random"
40-
env:
41-
JULIA_PKG_SERVER: ""
39+
benchpkg ${{ steps.extract-package-name.outputs.package_name }} --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" --url=${{ github.event.repository.clone_url }} --bench-on="${{github.event.pull_request.head.sha}}" --output-dir=results/ --tune --exeflags="-O3 --threads=auto"
4240
- name: Create plots from benchmarks
4341
run: |
4442
mkdir -p plots

benchmark/Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[deps]
2+
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
3+
Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
4+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
5+
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
6+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7+
8+
[compat]
9+
BenchmarkTools = "1.5"
10+
Boltz = "1"
11+
Enzyme = "0.13"
12+
Lux = "1.1"
13+
Random = "1.10"
14+
julia = "1.10"

benchmark/benchmarks.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
# To run:
2-
# using PkgBenchmark, Reactant
3-
# result = benchmarkpkg(KernelAbstractions)
4-
# export_markdown("benchmark/perf.md", result)
5-
6-
# Note: if you change this file you will need to delete an regenerate tune.json
7-
# Your "v1.x" environment needs to have BenchmarkTools and PkgBenchmark installed.
8-
91
using BenchmarkTools
102
using Reactant
113
using Enzyme

ext/ReactantNNlibExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ end
2121

2222
NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
2323

24-
NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T} = x * sigmoid(T(1.702) * x)
24+
function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
25+
α = T(0.044715)
26+
λλ = T((8 / π))
27+
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
28+
end
2529

2630
# TODO handle non finite cases
2731
function NNlib.softmax!(

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989

9090
Base.size(x::TracedRArray) = x.shape
9191

92-
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray((), A.mlir_data, size(A))
92+
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
9393

9494
function Base.similar(x::TracedRArray{T,N}, ::Type{T2}) where {T,N,T2}
9595
return TracedRArray{T2,N}((), nothing, size(x))

test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
33
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
44
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
5+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
56
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
67
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
8+
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
79
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
810
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
911
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
1012
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
1113
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1214
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
15+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
1316
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1417
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/bcast.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,37 +57,6 @@ function test()
5757
end
5858
test()
5959

60-
@testset "Activation Functions" begin
61-
sumabs2(f, x) = sum(abs2, f.(x))
62-
63-
function ∇sumabs2(f, x)
64-
dx = Enzyme.make_zero(x)
65-
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
66-
return dx
67-
end
68-
69-
x_act = randn(Float32, 10, 10)
70-
x_act_ca = Reactant.ConcreteRArray(x_act)
71-
72-
@testset "Activation: $act" for act in (
73-
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
74-
)
75-
f_compile = @compile sumabs2(act, x_act)
76-
77-
y_simple = sumabs2(act, x_act)
78-
y_compile = f_compile(act, x_act_ca)
79-
80-
∂x_enz = Enzyme.make_zero(x_act)
81-
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
82-
83-
∇sumabs2_compiled = @compile ∇sumabs2(act, x_act_ca)
84-
85-
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
86-
87-
@test y_simple y_compile
88-
end
89-
end
90-
9160
@testset "ConcreteRArray broadcasting" begin
9261
x = ones(10, 10)
9362
y = ones(10, 10)

test/nn.jl

Lines changed: 0 additions & 116 deletions
This file was deleted.

test/nn/flux.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using Reactant, Flux
2+
3+
@testset "Flux.jl Integration" begin
4+
# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
5+
noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
6+
truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
7+
8+
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
9+
model = Chain(
10+
Dense(2 => 3, tanh), # activation function inside layer
11+
BatchNorm(3),
12+
Dense(3 => 2),
13+
softmax,
14+
)
15+
16+
origout = model(noisy)
17+
18+
cmodel = Reactant.to_rarray(model)
19+
cnoisy = Reactant.ConcreteRArray(noisy)
20+
21+
f = Reactant.compile((a, b) -> a(b), (cmodel, cnoisy))
22+
23+
comp = f(cmodel, cnoisy)
24+
@test origout comp
25+
end

test/nn/lux.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using Reactant, Lux, Random, Statistics, Enzyme, Functors, OneHotArrays
2+
3+
function crossentropy(ŷ, y)
4+
logŷ = log.(ŷ)
5+
result = y .* logŷ
6+
return -sum(result)
7+
end
8+
9+
function loss_function(model, x, y, ps, st)
10+
y_hat, _ = model(x, ps, st)
11+
# return CrossEntropyLoss()(y_hat, y) # <-- needs handling of xlogx xlogy from LuxOps
12+
return crossentropy(y_hat, y)
13+
end
14+
15+
function gradient_loss_function(model, x, y, ps, st)
16+
dps = Enzyme.make_zero(ps)
17+
_, res = Enzyme.autodiff(
18+
ReverseWithPrimal,
19+
loss_function,
20+
Active,
21+
Const(model),
22+
Const(x),
23+
Const(y),
24+
Duplicated(ps, dps),
25+
Const(st),
26+
)
27+
return res, dps
28+
end
29+
30+
@testset "Lux.jl Integration" begin
31+
# Generate some data for the XOR problem: vectors of length 2, as columns of a matrix:
32+
noisy = rand(Float32, 2, 1000) # 2×1000 Matrix{Float32}
33+
truth = [xor(col[1] > 0.5, col[2] > 0.5) for col in eachcol(noisy)] # 1000-element Vector{Bool}
34+
35+
# Define our model, a multi-layer perceptron with one hidden layer of size 3:
36+
model = Lux.Chain(
37+
Lux.Dense(2 => 3, tanh), # activation function inside layer
38+
Lux.BatchNorm(3, sigmoid),
39+
Lux.Dense(3 => 2),
40+
softmax,
41+
)
42+
ps, st = Lux.setup(Xoshiro(123), model)
43+
44+
origout, _ = model(noisy, ps, Lux.testmode(st))
45+
46+
cmodel = Reactant.to_rarray(model)
47+
cps = Reactant.to_rarray(ps)
48+
cst = Reactant.to_rarray(Lux.testmode(st))
49+
cst2 = Reactant.to_rarray(st)
50+
cnoisy = Reactant.ConcreteRArray(noisy)
51+
52+
f = Reactant.compile((a, b, c, d) -> first(a(b, c, d)), (cmodel, cnoisy, cps, cst))
53+
54+
comp = f(cmodel, cnoisy, cps, cst)
55+
56+
@test comp origout atol = 1e-5 rtol = 1e-2
57+
58+
target = onehotbatch(truth, [true, false]) # 2×1000 OneHotMatrix
59+
60+
ctarget = Reactant.ConcreteRArray(Array{Float32}(target))
61+
# ctarget = Reactant.to_rarray(target)
62+
63+
res, dps = gradient_loss_function(model, noisy, target, ps, st)
64+
65+
compiled_gradient = Reactant.compile(
66+
gradient_loss_function, (cmodel, cnoisy, ctarget, cps, cst2)
67+
)
68+
69+
res_reactant, dps_reactant = compiled_gradient(cmodel, cnoisy, ctarget, cps, cst2)
70+
71+
@test res res_reactant atol = 1e-5 rtol = 1e-2
72+
for (dps1, dps2) in zip(fleaves(dps), fleaves(dps_reactant))
73+
@test dps1 dps2 atol = 1e-5 rtol = 1e-2
74+
end
75+
end

0 commit comments

Comments
 (0)