diff --git a/Project.toml b/Project.toml index 4c00538b75..29f46e882f 100644 --- a/Project.toml +++ b/Project.toml @@ -90,7 +90,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.12" -Reactant_jll = "0.0.198" +Reactant_jll = "0.0.199" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/src/Ops.jl b/src/Ops.jl index 0f0d9b99dd..d4f6f33626 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -15,6 +15,27 @@ using ..Reactant: using ReactantCore: ReactantCore using Functors: fmap +using Reactant_jll: Reactant_jll + +function unsafe_print(x) + print(unsafe_string(x)) + return nothing +end + +function __init__() + if Reactant_jll.is_available() + print_fn_ptr = @cfunction(unsafe_print, Nothing, (Cstring,)) + + for (ptr, enzymexla_name) in [(print_fn_ptr, :enzymexla_print)] + @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol( + enzymexla_name::Cstring, ptr::Ptr{Cvoid} + )::Cvoid + end + end + + return nothing +end + function mlir_type(x::Union{RNumber,RArray})::MLIR.IR.Type return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x))) end @@ -3114,4 +3135,41 @@ end end end +@noinline function throw( + msg::String; location=mlir_stacktrace("throw", @__FILE__, @__LINE__) +) + mod = MLIR.IR.mmodule() + + sym_name = string(Reactant.TracedUtils.__lookup_unique_name_in_module(mod, "error_msg")) + MLIR.IR.inject!( + sym_name, + "llvm.mlir.global constant @$(sym_name)(\"$(msg)\")"; + mod, + location, + verify=true, + ) + + error_fn = string(Reactant.TracedUtils.__lookup_unique_name_in_module(mod, "error")) + MLIR.IR.inject!( + error_fn, + """ + func.func @$(error_fn)() -> (!llvm.ptr) { + %err_msg_ptr = llvm.mlir.addressof @$(sym_name) : !llvm.ptr + return %err_msg_ptr : !llvm.ptr + } + """; + mod, + location, + ) + + MLIR.Dialects.enzymexla.jit_call( + MLIR.IR.Value[]; + fn=MLIR.IR.FlatSymbolRefAttribute(error_fn), + result_0=MLIR.IR.Type[], + location, + ) + + return nothing +end + end # module Ops diff --git a/src/Overlay.jl b/src/Overlay.jl index 3f1740f62e..44f520437c 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -178,3 +178,12 @@ end return Base.inferencebarrier(Base.any)(f, x, dims) end end + +# Exception Handling + +## Ideally we would want to overlay `throw` but that is built-in function and overlaying it +## doesn't work +## @reactant_overlay @noinline Base.throw(err) = Ops.throw(sprint(showerror, err)) + +@reactant_overlay @noinline Base.error(s::AbstractString) = Ops.throw(String(s)) +@reactant_overlay @noinline Base.error(s::Vararg{Any,N}) where {N} = Ops.throw(string(s...)) diff --git a/test/callback.jl b/test/callback.jl new file mode 100644 index 0000000000..29b00ea838 --- /dev/null +++ b/test/callback.jl @@ -0,0 +1,16 @@ +using Reactant + +function fn(x) + error("This should error at runtime") + return x .+ 1 +end + +@testset "error" begin + x = Reactant.to_rarray(ones(4)) + + hlo = repr(@code_hlo fn(x)) + @test contains(hlo, "stablehlo.custom_call") + + fn_compiled = @compile fn(x) + @test_throws Reactant.XLA.ReactantInternalError fn_compiled(x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 411cf443ea..bbac2f230e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Cluster Detection" include("cluster_detector.jl") @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") + @safetestset "Callbacks" include("callback.jl") end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"