Skip to content

Commit 336e939

Browse files
committed
test: fast_activation and fast_activation!!
1 parent 7390d5b commit 336e939

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

test/nn/luxlib.jl

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,58 @@
1-
using LuxLib, Reactant, Enzyme
1+
using LuxLib, Reactant, Enzyme, NNlib
22

33
@testset "Fused Dense" begin end
44

55
@testset "Bias Activation" begin end
66

7-
@testset "Fast Activation" begin end
7+
@testset "Fast Activation" begin
8+
# Here we are testing that fast_activation doesn't switch to the faster versions
9+
sumabs2(f, x) = sum(abs2, fast_activation(f, x))
10+
sumabs2!!(f, x) = sum(abs2, fast_activation!!(f, copy(x)))
11+
12+
function ∇sumabs2(f, x)
13+
dx = Enzyme.make_zero(x)
14+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
15+
return dx
16+
end
17+
18+
function ∇sumabs2!!(f, x)
19+
dx = Enzyme.make_zero(x)
20+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
21+
return dx
22+
end
23+
24+
x_act = randn(Float32, 10, 10)
25+
x_act_ca = Reactant.ConcreteRArray(x_act)
26+
27+
@testset "Activation: $act" for act in (
28+
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
29+
)
30+
f_compile = Reactant.compile(sumabs2, (act, x_act))
31+
f_compile!! = Reactant.compile(sumabs2!!, (act, x_act))
32+
33+
y_simple = sumabs2(act, x_act)
34+
y_simple!! = sumabs2!!(act, x_act)
35+
y_compile = f_compile(act, x_act_ca)
36+
y_compile!! = f_compile!!(act, x_act_ca)
37+
38+
@test y_simple y_compile
39+
@test y_simple!! y_compile!!
40+
41+
∂x_enz = Enzyme.make_zero(x_act)
42+
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
43+
44+
∂x_enz!! = Enzyme.make_zero(x_act)
45+
Enzyme.autodiff(Reverse, sumabs2!!, Active, Const(act), Duplicated(x_act, ∂x_enz!!))
46+
47+
∇sumabs2_compiled = Reactant.compile(∇sumabs2, (act, x_act_ca))
48+
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
49+
50+
∇sumabs2!!_compiled = Reactant.compile(∇sumabs2!!, (act, x_act_ca))
51+
∂x_compile!! = ∇sumabs2!!_compiled(act, x_act_ca)
52+
53+
@test ∂x_enz ∂x_compile broken=(act === gelu)
54+
@test ∂x_enz!! ∂x_compile!! broken=(act === gelu)
55+
end
56+
end
857

958
@testset "Fused Conv" begin end

test/nn/nnlib.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ using NNlib, Reactant, Enzyme
2828
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
2929

3030
@test y_simple y_compile
31+
# Mathematically the gelu definition here is slightly different from the one in NNlib
32+
@test ∂x_enz ∂x_compile broken=(act === gelu)
3133
end
3234
end
3335

0 commit comments

Comments
 (0)