Skip to content

Commit 1a319c1

Browse files
committed
feat: initial throw impl
1 parent 56c07f6 commit 1a319c1

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

src/Ops.jl

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,24 @@ using ..Reactant:
1515
using ReactantCore: ReactantCore
1616
using Functors: fmap
1717

18+
using Reactant_jll: Reactant_jll
19+
20+
function __init__()
21+
if Reactant_jll.is_available()
22+
error_fn_ptr = @cfunction(error, Union{}, (String,))
23+
println_fn_ptr = @cfunction(println, Nothing, (String,))
24+
25+
for (ptr, enzymexla_name) in
26+
[(error_fn_ptr, :enzymexla_error), (println_fn_ptr, :enzymexla_println)]
27+
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
28+
enzymexla_name::Cstring, ptr::Ptr{Cvoid}
29+
)::Cvoid
30+
end
31+
end
32+
33+
return nothing
34+
end
35+
1836
function mlir_type(x::Union{RNumber,RArray})::MLIR.IR.Type
1937
return MLIR.IR.TensorType(collect(Int, size(x)), MLIR.IR.Type(unwrapped_eltype(x)))
2038
end
@@ -3114,4 +3132,55 @@ end
31143132
end
31153133
end
31163134

3135+
@noinline function throw(
3136+
msg::String; location=mlir_stacktrace("throw", @__FILE__, @__LINE__)
3137+
)
3138+
mod = MLIR.IR.mmodule()
3139+
3140+
sym_name = string(Reactant.TracedUtils.__lookup_unique_name_in_module(mod, "error_msg"))
3141+
MLIR.IR.inject!(
3142+
sym_name,
3143+
"llvm.mlir.global constant @$(sym_name)(\"$(msg)\")";
3144+
mod,
3145+
location,
3146+
verify=true,
3147+
)
3148+
3149+
if MLIR.IR.mlirIsNull(
3150+
MLIR.API.mlirSymbolTableLookup(
3151+
MLIR.IR.SymbolTable(MLIR.IR.Operation(mod)), "enzymexla_println"
3152+
),
3153+
)
3154+
MLIR.IR.inject!(
3155+
"enzymexla_println",
3156+
"llvm.func external @enzymexla_println(!llvm.ptr)";
3157+
mod,
3158+
location,
3159+
)
3160+
end
3161+
3162+
error_fn = string(Reactant.TracedUtils.__lookup_unique_name_in_module(mod, "error"))
3163+
MLIR.IR.inject!(
3164+
error_fn,
3165+
"""
3166+
func.func @$(error_fn)() attributes {no_inline} {
3167+
%err_msg_ptr = llvm.mlir.addressof @$(sym_name) : !llvm.ptr
3168+
llvm.call @enzymexla_println(%err_msg_ptr) : (!llvm.ptr) -> ()
3169+
func.return
3170+
}
3171+
""";
3172+
mod,
3173+
location,
3174+
)
3175+
3176+
MLIR.Dialects.enzymexla.jit_call(
3177+
MLIR.IR.Value[];
3178+
fn=MLIR.IR.FlatSymbolRefAttribute(error_fn),
3179+
result_0=MLIR.IR.Type[],
3180+
location,
3181+
)
3182+
3183+
return nothing
3184+
end
3185+
31173186
end # module Ops

0 commit comments

Comments
 (0)