Skip to content

Commit ac3c512

Browse files
committed
test: move the tests around a bit
1 parent da4e5fd commit ac3c512

File tree

3 files changed

+51
-31
lines changed

3 files changed

+51
-31
lines changed

test/bcast.jl

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -57,37 +57,6 @@ function test()
5757
end
5858
test()
5959

60-
@testset "Activation Functions" begin
61-
sumabs2(f, x) = sum(abs2, f.(x))
62-
63-
function ∇sumabs2(f, x)
64-
dx = Enzyme.make_zero(x)
65-
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
66-
return dx
67-
end
68-
69-
x_act = randn(Float32, 10, 10)
70-
x_act_ca = Reactant.ConcreteRArray(x_act)
71-
72-
@testset "Activation: $act" for act in (
73-
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
74-
)
75-
f_compile = @compile sumabs2(act, x_act)
76-
77-
y_simple = sumabs2(act, x_act)
78-
y_compile = f_compile(act, x_act_ca)
79-
80-
∂x_enz = Enzyme.make_zero(x_act)
81-
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
82-
83-
∇sumabs2_compiled = @compile ∇sumabs2(act, x_act_ca)
84-
85-
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
86-
87-
@test y_simple y_compile
88-
end
89-
end
90-
9160
@testset "ConcreteRArray broadcasting" begin
9261
x = ones(10, 10)
9362
y = ones(10, 10)

test/nn/luxlib.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using LuxLib, Reactant, Enzyme
2+
3+
@testset "Fused Dense" begin
4+
end
5+
6+
@testset "Bias Activation" begin
7+
end
8+
9+
@testset "Fast Activation" begin
10+
end
11+
12+
@testset "Fused Conv" begin
13+
end

test/nn/nnlib.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using NNlib, Reactant, Enzyme
2+
3+
@testset "Activation Functions" begin
4+
sumabs2(f, x) = sum(abs2, f.(x))
5+
6+
function ∇sumabs2(f, x)
7+
dx = Enzyme.make_zero(x)
8+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
9+
return dx
10+
end
11+
12+
x_act = randn(Float32, 10, 10)
13+
x_act_ca = Reactant.ConcreteRArray(x_act)
14+
15+
@testset "Activation: $act" for act in (
16+
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
17+
)
18+
f_compile = Reactant.compile(sumabs2, (act, x_act))
19+
20+
y_simple = sumabs2(act, x_act)
21+
y_compile = f_compile(act, x_act_ca)
22+
23+
∂x_enz = Enzyme.make_zero(x_act)
24+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
25+
26+
∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
27+
28+
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
29+
30+
@test y_simple y_compile
31+
end
32+
end
33+
34+
@testset "Pooling" begin
35+
end
36+
37+
@testset "Convolution" begin
38+
end

0 commit comments

Comments
 (0)