Skip to content

Commit 72794b1

Browse files
committed
test: fused bias activation
1 parent 7361409 commit 72794b1

File tree

2 files changed

+74
-4
lines changed

2 files changed

+74
-4
lines changed

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989

9090
Base.size(x::TracedRArray) = x.shape
9191

92-
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray((), A.mlir_data, size(A))
92+
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
9393

9494
function Base.similar(x::TracedRArray{T,N}, ::Type{T2}) where {T,N,T2}
9595
return TracedRArray{T2,N}((), nothing, size(x))

test/nn/luxlib.jl

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,76 @@ using LuxLib, Reactant, Enzyme, NNlib
33
@testset "Fused Dense" begin end
44

55
@testset "Bias Activation" begin
6+
biasact(act, x, b) = bias_activation(act, x, b)
7+
sumabs2biasact(act, x, b) = sum(abs2, biasact(act, x, b))
8+
biasact!!(act, x, b) = bias_activation!!(act, copy(x), b)
9+
sumabs2biasact!!(act, x, b) = sum(abs2, biasact!!(act, x, b))
10+
11+
function ∇biasact(act, x, b)
12+
dx = Enzyme.make_zero(x)
13+
db = Enzyme.make_zero(b)
14+
Enzyme.autodiff(
15+
Reverse,
16+
sumabs2biasact,
17+
Active,
18+
Const(act),
19+
Duplicated(x, dx),
20+
Duplicated(b, db),
21+
)
22+
return dx, db
23+
end
24+
25+
function ∇biasact!!(act, x, b)
26+
dx = Enzyme.make_zero(x)
27+
db = Enzyme.make_zero(b)
28+
Enzyme.autodiff(
29+
Reverse,
30+
sumabs2biasact!!,
31+
Active,
32+
Const(act),
33+
Duplicated(x, dx),
34+
Duplicated(b, db),
35+
)
36+
return dx, db
37+
end
38+
39+
@testset for act in (identity, relu, sigmoid, tanh, gelu)
40+
x = randn(Float32, 10, 10)
41+
b = randn(Float32, 10)
42+
43+
x_ra = Reactant.ConcreteRArray(x)
44+
b_ra = Reactant.ConcreteRArray(b)
45+
46+
f_compile = Reactant.compile(biasact, (act, x_ra, b_ra))
47+
f_compile!! = Reactant.compile(biasact!!, (act, x_ra, b_ra))
48+
49+
y_simple = biasact(act, x, b)
50+
y_simple!! = biasact!!(act, x, b)
51+
y_compile = f_compile(act, x_ra, b_ra)
52+
y_compile!! = f_compile!!(act, x_ra, b_ra)
53+
54+
@test y_simple y_compile broken = (act === gelu)
55+
@test y_simple!! y_compile!! broken = (act === gelu)
56+
57+
# FIXME: Seems broken currently
58+
@testset "Enzyme: bias_activation" begin
59+
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
60+
∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
61+
∂x_compile, ∂b_compile = ∇biasact_compiled(act, x_ra, b_ra)
62+
63+
@test ∂x_enz ∂x_compile broken = (act === gelu)
64+
@test ∂b_enz ∂b_compile broken = (act === gelu)
65+
end
66+
67+
@testset "Enzyme: bias_activation!!" begin
68+
∂x_enz!!, ∂b_enz!! = ∇biasact!!(act, x, b)
69+
∇biasact!!_compiled = Reactant.compile(∇biasact!!, (act, x_ra, b_ra))
70+
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled(act, x_ra, b_ra)
71+
72+
@test ∂x_enz!! ∂x_compile!! broken = (act === gelu)
73+
@test ∂b_enz!! ∂b_compile!! broken = (act === gelu)
74+
end
75+
end
676
end
777

878
@testset "Fast Activation" begin
@@ -18,7 +88,7 @@ end
1888

1989
function ∇sumabs2!!(f, x)
2090
dx = Enzyme.make_zero(x)
21-
Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
91+
Enzyme.autodiff(Reverse, sumabs2!!, Active, Const(f), Duplicated(x, dx))
2292
return dx
2393
end
2494

@@ -28,8 +98,8 @@ end
2898
@testset "Activation: $act" for act in (
2999
identity, relu, sigmoid, tanh, tanh_fast, sigmoid_fast, gelu, abs2
30100
)
31-
f_compile = Reactant.compile(sumabs2, (act, x_act))
32-
f_compile!! = Reactant.compile(sumabs2!!, (act, x_act))
101+
f_compile = Reactant.compile(sumabs2, (act, x_act_ca))
102+
f_compile!! = Reactant.compile(sumabs2!!, (act, x_act_ca))
33103

34104
y_simple = sumabs2(act, x_act)
35105
y_simple!! = sumabs2!!(act, x_act)

0 commit comments

Comments
 (0)