Skip to content

Commit a839436

Browse files
committed
feat: enable for number
1 parent 7fcb9df commit a839436

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

src/Ops.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3211,18 +3211,21 @@ end
32113211
end
32123212

32133213
@noinline function gelu(
3214-
x::TracedRArray{T,N},
3214+
x::Union{TracedRArray{T,N},TracedRNumber{T}},
32153215
approximation::String;
32163216
location=mlir_stacktrace("gelu", @__FILE__, @__LINE__),
32173217
) where {T,N}
32183218
@assert approximation in ("NONE", "TANH", "SIGMOID")
3219-
return TracedRArray{T,N}(
3220-
(),
3221-
MLIR.IR.result(
3222-
enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approximation, location), 1
3223-
),
3224-
size(x),
3219+
3220+
res = MLIR.IR.result(
3221+
enzymexla.ml_gelu(x.mlir_data; gelu_approximation=approximation, location), 1
32253222
)
3223+
3224+
if x isa TracedRArray
3225+
return TracedRArray{T,N}((), res, size(x))
3226+
else
3227+
return TracedRNumber{T}((), res)
3228+
end
32263229
end
32273230

32283231
end # module Ops

0 commit comments

Comments
 (0)