Skip to content

Commit 28849c5

Browse files
committed
test: fused dense
1 parent 72794b1 commit 28849c5

File tree

1 file changed

+56
-4
lines changed

1 file changed

+56
-4
lines changed

test/nn/luxlib.jl

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,59 @@
11
using LuxLib, Reactant, Enzyme, NNlib
22

3-
@testset "Fused Dense" begin end
3+
@testset "Fused Dense" begin
4+
sumabs2fuseddense(act, weight, x, bias) =
5+
sum(abs2, fused_dense_bias_activation(act, weight, x, bias))
6+
7+
function ∇fuseddense(act, weight, x, bias)
8+
dw = Enzyme.make_zero(weight)
9+
dx = Enzyme.make_zero(x)
10+
db = bias === nothing ? nothing : Enzyme.make_zero(bias)
11+
b_dup = bias === nothing ? Const(bias) : Duplicated(bias, db)
12+
Enzyme.autodiff(
13+
Reverse,
14+
sumabs2fuseddense,
15+
Active,
16+
Const(act),
17+
Duplicated(weight, dw),
18+
Duplicated(x, dx),
19+
b_dup,
20+
)
21+
return dw, dx, db
22+
end
23+
24+
@testset for act in (identity, relu, sigmoid, tanh, gelu), has_bias in (true, false)
25+
weight = randn(Float32, 9, 10)
26+
x = randn(Float32, 10, 12)
27+
bias = has_bias ? randn(Float32, 9) : nothing
28+
29+
weight_ra = Reactant.ConcreteRArray(weight)
30+
x_ra = Reactant.ConcreteRArray(x)
31+
bias_ra = Reactant.to_rarray(bias)
32+
33+
f_compile = Reactant.compile(
34+
fused_dense_bias_activation, (act, weight_ra, x_ra, bias_ra)
35+
)
36+
37+
y_res = fused_dense_bias_activation(act, weight, x, bias)
38+
y_compile = f_compile(act, weight_ra, x_ra, bias_ra)
39+
40+
@test y_res y_compile broken = (act === gelu)
41+
42+
@testset "Enzyme: fused_dense_bias_activation" begin
43+
dw, dx, db = ∇fuseddense(act, weight, x, bias)
44+
∇fuseddense_compiled = Reactant.compile(
45+
∇fuseddense, (act, weight_ra, x_ra, bias_ra)
46+
)
47+
dw_compile, dx_compile, db_compile = ∇fuseddense_compiled(
48+
act, weight_ra, x_ra, bias_ra
49+
)
50+
51+
@test dw dw_compile broken = (act === gelu)
52+
@test dx dx_compile broken = (act === gelu)
53+
has_bias && @test db db_compile broken = (act === gelu)
54+
end
55+
end
56+
end
457

558
@testset "Bias Activation" begin
659
biasact(act, x, b) = bias_activation(act, x, b)
@@ -54,7 +107,6 @@ using LuxLib, Reactant, Enzyme, NNlib
54107
@test y_simple y_compile broken = (act === gelu)
55108
@test y_simple!! y_compile!! broken = (act === gelu)
56109

57-
# FIXME: Seems broken currently
58110
@testset "Enzyme: bias_activation" begin
59111
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
60112
∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
@@ -106,8 +158,8 @@ end
106158
y_compile = f_compile(act, x_act_ca)
107159
y_compile!! = f_compile!!(act, x_act_ca)
108160

109-
@test y_simple y_compile
110-
@test y_simple!! y_compile!!
161+
@test y_simple y_compile broken = (act === gelu)
162+
@test y_simple!! y_compile!! broken = (act === gelu)
111163

112164
∂x_enz = Enzyme.make_zero(x_act)
113165
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))

0 commit comments

Comments
 (0)