Skip to content

Commit 350c483

Browse files
committed
feat: overlay error
1 parent 5b2664c commit 350c483

File tree

2 files changed

+18
-20
lines changed

2 files changed

+18
-20
lines changed

src/Ops.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,16 @@ using Functors: fmap
1717

1818
using Reactant_jll: Reactant_jll
1919

20+
function unsafe_print(x)
21+
print(unsafe_string(x))
22+
return nothing
23+
end
24+
2025
function __init__()
2126
if Reactant_jll.is_available()
22-
error_fn_ptr = @cfunction(error, Union{}, (String,))
23-
println_fn_ptr = @cfunction(println, Nothing, (String,))
27+
print_fn_ptr = @cfunction(unsafe_print, Nothing, (Cstring,))
2428

25-
for (ptr, enzymexla_name) in
26-
[(error_fn_ptr, :enzymexla_error), (println_fn_ptr, :enzymexla_println)]
29+
for (ptr, enzymexla_name) in [(print_fn_ptr, :enzymexla_print)]
2730
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
2831
enzymexla_name::Cstring, ptr::Ptr{Cvoid}
2932
)::Cvoid
@@ -3146,27 +3149,13 @@ end
31463149
verify=true,
31473150
)
31483151

3149-
if MLIR.IR.mlirIsNull(
3150-
MLIR.API.mlirSymbolTableLookup(
3151-
MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), "enzymexla_println"
3152-
),
3153-
)
3154-
MLIR.IR.inject!(
3155-
"enzymexla_println",
3156-
"llvm.func external @enzymexla_println(!llvm.ptr)";
3157-
mod,
3158-
location,
3159-
)
3160-
end
3161-
31623152
error_fn = string(Reactant.TracedUtils.__lookup_unique_name_in_module(mod, "error"))
31633153
MLIR.IR.inject!(
31643154
error_fn,
31653155
"""
3166-
func.func @$(error_fn)() attributes {no_inline} {
3156+
func.func @$(error_fn)() -> (!llvm.ptr) {
31673157
%err_msg_ptr = llvm.mlir.addressof @$(sym_name) : !llvm.ptr
3168-
llvm.call @enzymexla_println(%err_msg_ptr) : (!llvm.ptr) -> ()
3169-
func.return
3158+
return %err_msg_ptr : !llvm.ptr
31703159
}
31713160
""";
31723161
mod,

src/Overlay.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,12 @@ end
178178
return Base.inferencebarrier(Base.any)(f, x, dims)
179179
end
180180
end
181+
182+
# Exception Handling
183+
184+
## Ideally we would want to overlay `throw` but that is built-in function and overlaying it
185+
## doesn't work
186+
## @reactant_overlay @noinline Base.throw(err) = Ops.throw(sprint(showerror, err))
187+
188+
@reactant_overlay @noinline Base.error(s::AbstractString) = Ops.throw(String(s))
189+
@reactant_overlay @noinline Base.error(s::Vararg{Any,N}) where {N} = Ops.throw(string(s...))

0 commit comments

Comments
 (0)