Skip to content

Commit ae27041

Browse files
committed
fix: gelu implementation
1 parent 13f740b commit ae27041

File tree

3 files changed

+22
-18
lines changed

3 files changed

+22
-18
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,12 @@ end
2121

2222
NNlib.relu(x::Reactant.TracedRArray{T,0}) where {T} = max(x, zero(T))
2323

24-
NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T} = x * sigmoid(T(1.702) * x)
24+
function NNlib.gelu(x::Reactant.TracedRArray{T,0}) where {T}
25+
α = T(0.044715)
26+
λλ = T((8 / π))
27+
return x * sigmoid(λλ * x * muladd(x^2, α, one(T)))
28+
end
29+
2530

2631
# TODO handle non finite cases
2732
function NNlib.softmax!(

test/nn/luxlib.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using LuxLib, Reactant, Enzyme, NNlib
3737
y_res = fused_dense_bias_activation(act, weight, x, bias)
3838
y_compile = f_compile(act, weight_ra, x_ra, bias_ra)
3939

40-
@test y_res y_compile broken = (act === gelu)
40+
@test y_res y_compile
4141

4242
@testset "Enzyme: fused_dense_bias_activation" begin
4343
dw, dx, db = ∇fuseddense(act, weight, x, bias)
@@ -48,9 +48,9 @@ using LuxLib, Reactant, Enzyme, NNlib
4848
act, weight_ra, x_ra, bias_ra
4949
)
5050

51-
@test dw dw_compile broken = (act === gelu)
52-
@test dx dx_compile broken = (act === gelu)
53-
has_bias && @test db db_compile broken = (act === gelu)
51+
@test dw dw_compile
52+
@test dx dx_compile
53+
has_bias && @test db db_compile
5454
end
5555
end
5656
end
@@ -104,25 +104,25 @@ end
104104
y_compile = f_compile(act, x_ra, b_ra)
105105
y_compile!! = f_compile!!(act, x_ra, b_ra)
106106

107-
@test y_simple y_compile broken = (act === gelu)
108-
@test y_simple!! y_compile!! broken = (act === gelu)
107+
@test y_simple y_compile
108+
@test y_simple!! y_compile!!
109109

110110
@testset "Enzyme: bias_activation" begin
111111
∂x_enz, ∂b_enz = ∇biasact(act, x, b)
112112
∇biasact_compiled = Reactant.compile(∇biasact, (act, x_ra, b_ra))
113113
∂x_compile, ∂b_compile = ∇biasact_compiled(act, x_ra, b_ra)
114114

115-
@test ∂x_enz ∂x_compile broken = (act === gelu)
116-
@test ∂b_enz ∂b_compile broken = (act === gelu)
115+
@test ∂x_enz ∂x_compile
116+
@test ∂b_enz ∂b_compile
117117
end
118118

119119
@testset "Enzyme: bias_activation!!" begin
120120
∂x_enz!!, ∂b_enz!! = ∇biasact!!(act, x, b)
121121
∇biasact!!_compiled = Reactant.compile(∇biasact!!, (act, x_ra, b_ra))
122122
∂x_compile!!, ∂b_compile!! = ∇biasact!!_compiled(act, x_ra, b_ra)
123123

124-
@test ∂x_enz!! ∂x_compile!! broken = (act === gelu)
125-
@test ∂b_enz!! ∂b_compile!! broken = (act === gelu)
124+
@test ∂x_enz!! ∂x_compile!!
125+
@test ∂b_enz!! ∂b_compile!!
126126
end
127127
end
128128
end
@@ -158,8 +158,8 @@ end
158158
y_compile = f_compile(act, x_act_ca)
159159
y_compile!! = f_compile!!(act, x_act_ca)
160160

161-
@test y_simple y_compile broken = (act === gelu)
162-
@test y_simple!! y_compile!! broken = (act === gelu)
161+
@test y_simple y_compile
162+
@test y_simple!! y_compile!!
163163

164164
∂x_enz = Enzyme.make_zero(x_act)
165165
Enzyme.autodiff(Reverse, sumabs2, Active, Const(act), Duplicated(x_act, ∂x_enz))
@@ -173,8 +173,8 @@ end
173173
∇sumabs2!!_compiled = Reactant.compile(∇sumabs2!!, (act, x_act_ca))
174174
∂x_compile!! = ∇sumabs2!!_compiled(act, x_act_ca)
175175

176-
@test ∂x_enz ∂x_compile broken = (act === gelu)
177-
@test ∂x_enz!! ∂x_compile!! broken = (act === gelu)
176+
@test ∂x_enz ∂x_compile
177+
@test ∂x_enz!! ∂x_compile!!
178178
end
179179
end
180180

@@ -207,7 +207,7 @@ end
207207
)
208208
luxlib_res = fused_conv_bias_activation(act, weight, x, bias, conv_dims)
209209

210-
@test reactant_res luxlib_res broken = (act === gelu)
210+
@test reactant_res luxlib_res
211211
end
212212

213213
# TODO: test for gradients

test/nn/nnlib.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ using NNlib, Reactant, Enzyme
2828
∂x_compile = ∇sumabs2_compiled(act, x_act_ca)
2929

3030
@test y_simple y_compile
31-
# Mathematically the gelu definition here is slightly different from the one in NNlib
32-
@test ∂x_enz ∂x_compile broken = (act === gelu)
31+
@test ∂x_enz ∂x_compile
3332
end
3433
end
3534

0 commit comments

Comments
 (0)