@@ -15,6 +15,24 @@ using ..Reactant:
15
15
using ReactantCore: ReactantCore
16
16
using Functors: fmap
17
17
18
+ using Reactant_jll: Reactant_jll
19
+
20
+ function __init__ ()
21
+ if Reactant_jll. is_available ()
22
+ error_fn_ptr = @cfunction (error, Union{}, (String,))
23
+ println_fn_ptr = @cfunction (println, Nothing, (String,))
24
+
25
+ for (ptr, enzymexla_name) in
26
+ [(error_fn_ptr, :enzymexla_error ), (println_fn_ptr, :enzymexla_println )]
27
+ @ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
28
+ enzymexla_name:: Cstring , ptr:: Ptr{Cvoid}
29
+ ):: Cvoid
30
+ end
31
+ end
32
+
33
+ return nothing
34
+ end
35
+
18
36
function mlir_type (x:: Union{RNumber,RArray} ):: MLIR.IR.Type
19
37
return MLIR. IR. TensorType (collect (Int, size (x)), MLIR. IR. Type (unwrapped_eltype (x)))
20
38
end
@@ -3114,4 +3132,55 @@ end
3114
3132
end
3115
3133
end
3116
3134
3135
+ @noinline function throw (
3136
+ msg:: String ; location= mlir_stacktrace (" throw" , @__FILE__ , @__LINE__ )
3137
+ )
3138
+ mod = MLIR. IR. mmodule ()
3139
+
3140
+ sym_name = string (Reactant. TracedUtils. __lookup_unique_name_in_module (mod, " error_msg" ))
3141
+ MLIR. IR. inject! (
3142
+ sym_name,
3143
+ " llvm.mlir.global constant @$(sym_name) (\" $(msg) \" )" ;
3144
+ mod,
3145
+ location,
3146
+ verify= true ,
3147
+ )
3148
+
3149
+ if MLIR. IR. mlirIsNull (
3150
+ MLIR. API. mlirSymbolTableLookup (
3151
+ MLIR. IR. SymbolTable (MLIR. IR. Operation (mod)), " enzymexla_println"
3152
+ ),
3153
+ )
3154
+ MLIR. IR. inject! (
3155
+ " enzymexla_println" ,
3156
+ " llvm.func external @enzymexla_println(!llvm.ptr)" ;
3157
+ mod,
3158
+ location,
3159
+ )
3160
+ end
3161
+
3162
+ error_fn = string (Reactant. TracedUtils. __lookup_unique_name_in_module (mod, " error" ))
3163
+ MLIR. IR. inject! (
3164
+ error_fn,
3165
+ """
3166
+ func.func @$(error_fn) () attributes {no_inline} {
3167
+ %err_msg_ptr = llvm.mlir.addressof @$(sym_name) : !llvm.ptr
3168
+ llvm.call @enzymexla_println(%err_msg_ptr) : (!llvm.ptr) -> ()
3169
+ func.return
3170
+ }
3171
+ """ ;
3172
+ mod,
3173
+ location,
3174
+ )
3175
+
3176
+ MLIR. Dialects. enzymexla. jit_call (
3177
+ MLIR. IR. Value[];
3178
+ fn= MLIR. IR. FlatSymbolRefAttribute (error_fn),
3179
+ result_0= MLIR. IR. Type[],
3180
+ location,
3181
+ )
3182
+
3183
+ return nothing
3184
+ end
3185
+
3117
3186
end # module Ops
0 commit comments