Commit 7fcb9df

feat: use ml.gelu
1 parent dbbcb0f · commit 7fcb9df

File tree: 4 files changed, +36 −10 lines changed


ext/ReactantNNlibExt/Implementations.jl

Lines changed: 4 additions & 8 deletions

@@ -10,18 +10,14 @@ end
 # Without this we will never fuse the gelu into gemm
 if isdefined(NNlib, :gelu_tanh)
     function NNlib.gelu_tanh(x::TracedRNumber)
-        α = NNlib.oftf(x, 0.044715)
-        half = NNlib.oftf(x, 0.5)
-        λ = sqrt(NNlib.oftf(x, 2 / pi))
-        return x * (half * (1 + tanh(λ * (x + α * x^3))))
+        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
     end
+
+    NNlib.gelu_erf(x::TracedRNumber) = Reactant.Ops.gelu(x, "NONE")
 else
     # Older versions of NNlib do not have gelu_tanh (gelu refers to the tanh version)
     function NNlib.gelu(x::TracedRNumber)
-        α = NNlib.oftf(x, 0.044715)
-        half = NNlib.oftf(x, 0.5)
-        λ = sqrt(NNlib.oftf(x, 2 / pi))
-        return x * (half * (1 + tanh(λ * (x + α * x^3))))
+        return Reactant.Ops.gelu(x, Reactant.NNLIB_GELU_APPROXIMATION[])
     end
 end
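Note: replacing the hand-written tanh polynomial with a single Reactant.Ops.gelu call exposes gelu as one op to the compiler, which is what allows it to fuse into an adjacent gemm (per the comment above). A minimal usage sketch, assuming a recent NNlib that defines gelu_tanh and using Reactant's to_rarray/@jit helpers:

    using Reactant, NNlib

    x = Reactant.to_rarray(randn(Float32, 4, 4))
    f(x) = NNlib.gelu_tanh.(x)
    y = Reactant.@jit f(x)   # gelu now traces through Reactant.Ops.gelu ("SIGMOID" by default)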

src/Compiler.jl

Lines changed: 7 additions & 0 deletions

@@ -1623,6 +1623,7 @@ function compile_mlir!(
     blas_int_width = sizeof(BLAS.BlasInt) * 8
     lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
                                    blas_int_width=$blas_int_width}"
+    lower_enzymexla_ml_pass = "lower-enzymexla-ml"

     if compile_options.optimization_passes === :all
         run_pass_pipeline!(
@@ -1650,6 +1651,7 @@ function compile_mlir!(
                 )...,
                 opt_passes2,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1674,6 +1676,7 @@ function compile_mlir!(
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
@@ -1863,6 +1866,7 @@ function compile_mlir!(
                 )...,
                 opt_passes2,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1884,6 +1888,7 @@ function compile_mlir!(
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
@@ -1906,6 +1911,7 @@ function compile_mlir!(
                 enzyme_pass,
                 "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         else
@@ -1919,6 +1925,7 @@ function compile_mlir!(
                 kern,
                 raise_passes,
                 lower_enzymexla_linalg_pass,
+                lower_enzymexla_ml_pass,
                 jit,
             ]
         end,
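The new pass string is appended immediately after the linalg lowering pass in every pipeline variant above. A small illustrative sketch of how the pass strings compose (placeholder backend/blas_int_width values, not the actual compile_mlir! context; the assumption is that "lower-enzymexla-ml" lowers the enzymexla ml ops such as the gelu added in src/Ops.jl):

    backend = "cpu"       # placeholder value for illustration
    blas_int_width = 64   # placeholder value for illustration
    lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend blas_int_width=$blas_int_width}"
    lower_enzymexla_ml_pass = "lower-enzymexla-ml"   # new in this commit
    pipeline = join([lower_enzymexla_linalg_pass, lower_enzymexla_ml_pass], ",")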

src/Configuration.jl

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,8 @@ scope will use the provided values.
   `ApproxTopK` for TPUs unless `fallback_approx_top_k_lowering` is set to `true`.
 - `fallback_approx_top_k_lowering`: Whether to lower `Ops.approx_top_k` to
   `stablehlo.top_k` if the XLA backend doesn't support `ApproxTopK`. Defaults to `true`.
+- `nnlib_gelu_approximation`: Controls the approximation used for `NNlib.gelu_tanh`. Can
+  be `"TANH"` or `"SIGMOID"`. Defaults to `"SIGMOID"`.

 ### DotGeneral

@@ -38,6 +40,7 @@ function with_config(
     convolution_precision=missing,
     lower_partialsort_to_approx_top_k=missing,
     fallback_approx_top_k_lowering=missing,
+    nnlib_gelu_approximation=missing,
 )
     config_vars = ()
     dot_general_algorithm !== missing &&
@@ -58,13 +61,18 @@ function with_config(
             FALLBACK_APPROX_TOP_K_LOWERING => fallback_approx_top_k_lowering,
         )
     )
+    if nnlib_gelu_approximation !== missing
+        @assert nnlib_gelu_approximation in ("TANH", "SIGMOID") "Invalid nnlib_gelu_approximation: $nnlib_gelu_approximation. Expected \"TANH\" or \"SIGMOID\"."
+        config_vars = (config_vars..., NNLIB_GELU_APPROXIMATION => nnlib_gelu_approximation)
+    end

     return ScopedValues.with(f, config_vars...)
 end

 # Lower to ApproxTopK
 const LOWER_PARTIALSORT_TO_APPROX_TOP_K = ScopedValue(false)
 const FALLBACK_APPROX_TOP_K_LOWERING = ScopedValue(true)
+const NNLIB_GELU_APPROXIMATION = ScopedValue("SIGMOID")

 # DotGeneral Attributes Configuration
 """

src/Ops.jl

Lines changed: 17 additions & 2 deletions

@@ -3,7 +3,7 @@
 # Julia and Reactant semantics should be considered on the higher abstractions that use these ops.
 module Ops
 using ..MLIR: MLIR
-using ..MLIR.Dialects: stablehlo, chlo, enzyme
+using ..MLIR.Dialects: stablehlo, chlo, enzyme, enzymexla
 using ..Reactant:
     Reactant,
     TracedRArray,
@@ -3003,7 +3003,7 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
     permutation_shape = vcat(batch_shape, size(x, ndims(x) - 1))
     info_shape = batch_shape

-    op = MLIR.Dialects.enzymexla.linalg_lu(
+    op = enzymexla.linalg_lu(
         x.mlir_data;
         output=MLIR.IR.TensorType(output_shape, MLIR.IR.Type(unwrapped_eltype(T))),
         pivots=MLIR.IR.TensorType(pivots_shape, MLIR.IR.Type(pT)),
@@ -3210,4 +3210,19 @@ end
     end
 end

+@noinline function gelu(
+    x::TracedRArray{T,N},
+    approximation::String;
+    location=mlir_stacktrace("gelu", @__FILE__, @__LINE__),
+) where {T,N}
+    @assert approximation in ("NONE", "TANH", "SIGMOID")
+    return TracedRArray{T,N}(
+        (),
+        MLIR.IR.result(
+            enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approximation, location), 1
+        ),
+        size(x),
+    )
+end
+
 end # module Ops
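For orientation, reference formulas for the three approximation modes accepted by Ops.gelu. This is a sketch of the standard GELU definitions, not the enzymexla lowering itself; in particular the 1.702 sigmoid coefficient is the commonly cited value and is an assumption here:

    using SpecialFunctions: erf   # assumption: SpecialFunctions.jl is available

    gelu_exact(x)       = x * (1 + erf(x / sqrt(2))) / 2                           # "NONE"
    gelu_tanh_ref(x)    = x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))) / 2  # "TANH"
    gelu_sigmoid_ref(x) = x / (1 + exp(-1.702 * x))                                # "SIGMOID": x * σ(1.702x)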
