@@ -17,13 +17,16 @@ using Functors: fmap
17
17
18
18
using Reactant_jll: Reactant_jll
19
19
20
+ function unsafe_print (x)
21
+ print (unsafe_string (x))
22
+ return nothing
23
+ end
24
+
20
25
function __init__ ()
21
26
if Reactant_jll. is_available ()
22
- error_fn_ptr = @cfunction (error, Union{}, (String,))
23
- println_fn_ptr = @cfunction (println, Nothing, (String,))
27
+ print_fn_ptr = @cfunction (unsafe_print, Nothing, (Cstring,))
24
28
25
- for (ptr, enzymexla_name) in
26
- [(error_fn_ptr, :enzymexla_error ), (println_fn_ptr, :enzymexla_println )]
29
+ for (ptr, enzymexla_name) in [(print_fn_ptr, :enzymexla_print )]
27
30
@ccall MLIR. API. mlir_c. EnzymeJaXMapSymbol (
28
31
enzymexla_name:: Cstring , ptr:: Ptr{Cvoid}
29
32
):: Cvoid
@@ -3146,27 +3149,13 @@ end
3146
3149
verify= true ,
3147
3150
)
3148
3151
3149
- if MLIR. IR. mlirIsNull (
3150
- MLIR. API. mlirSymbolTableLookup (
3151
- MLIR. IR. SymbolTable (MLIR. IR. Operation (mod)), " enzymexla_println"
3152
- ),
3153
- )
3154
- MLIR. IR. inject! (
3155
- " enzymexla_println" ,
3156
- " llvm.func external @enzymexla_println(!llvm.ptr)" ;
3157
- mod,
3158
- location,
3159
- )
3160
- end
3161
-
3162
3152
error_fn = string (Reactant. TracedUtils. __lookup_unique_name_in_module (mod, " error" ))
3163
3153
MLIR. IR. inject! (
3164
3154
error_fn,
3165
3155
"""
3166
- func.func @$(error_fn) () attributes {no_inline} {
3156
+ func.func @$(error_fn) () -> (!llvm.ptr) {
3167
3157
%err_msg_ptr = llvm.mlir.addressof @$(sym_name) : !llvm.ptr
3168
- llvm.call @enzymexla_println(%err_msg_ptr) : (!llvm.ptr) -> ()
3169
- func.return
3158
+ return %err_msg_ptr : !llvm.ptr
3170
3159
}
3171
3160
""" ;
3172
3161
mod,
0 commit comments