@@ -316,7 +316,7 @@ function certain_error()
316316 )
317317end
318318
319- function rewrite_inst (inst, ir, interp, RT, guaranteed_error)
319+ function rewrite_inst (inst, ir:: CC.IRCode , interp, RT, guaranteed_error)
320320 if Meta. isexpr (inst, :call )
321321 # Even if type unstable we do not want (or need) to replace intrinsic
322322 # calls or builtins with our version.
@@ -532,7 +532,7 @@ const DEBUG_INTERP = Ref(false)
532532# to Any if our interpreter would change the return type of any result.
533533# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
534534# screws up type inference after this (TODO this should be fixed).
535- function rewrite_insts! (ir, interp, guaranteed_error)
535+ function rewrite_insts! (ir:: CC.IRCode , interp, guaranteed_error)
536536 any_changed = false
537537 for (i, inst) in enumerate (ir. stmts)
538538 # Explicitly skip any code which returns Union{} so that we throw the error
@@ -839,11 +839,11 @@ function call_with_reactant_generator(
839839 if DEBUG_INTERP[]
840840 safe_print (" code_info" , code_info)
841841 end
842-
842+ # @lk code_info oc
843843 return code_info
844844end
845845
846- @eval function call_with_reactant ($ REDUB_ARGUMENTS_NAME... )
846+ @eval function call_with_reactant0 ($ REDUB_ARGUMENTS_NAME... )
847847 $ (Expr (:meta , :generated_only ))
848848 return $ (Expr (:meta , :generated , call_with_reactant_generator))
849849end
@@ -854,3 +854,109 @@ end
854854nmantissa (:: Type{Float16} ) = 10
855855nmantissa (:: Type{Float32} ) = 23
856856nmantissa (:: Type{Float64} ) = 52
857+
858+ using GPUCompiler
859+ using GPUCompiler: AbstractCompilerParams, CompilerJob, NativeCompilerTarget
860+
861+ Base. Experimental. @MethodTable (test_method_table)
862+
863+ struct CompilerParams <: AbstractCompilerParams
864+ entry_safepoint:: Bool
865+ method_table
866+
867+ function CompilerParams (entry_safepoint:: Bool = false , method_table= test_method_table)
868+ return new (entry_safepoint, method_table)
869+ end
870+ end
871+
872+ NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams}
873+
874+ function GPUCompiler. method_table (@nospecialize (job:: NativeCompilerJob ))
875+ return job. config. params. method_table
876+ end
877+ function GPUCompiler. can_safepoint (@nospecialize (job:: NativeCompilerJob ))
878+ return job. config. params. entry_safepoint
879+ end
880+
881+ GPUCompiler. can_throw (@nospecialize (job:: NativeCompilerJob )) = true
882+ GPUCompiler. needs_byval (@nospecialize (job:: NativeCompilerJob )) = false
883+
884+ function GPUCompiler. optimize! (
885+ @nospecialize (job:: NativeCompilerJob ), mod:: GPUCompiler.LLVM.Module ; opt_level
886+ )
887+ return nothing # TODO : add all except GPU stuff passes
888+ end
889+
890+ function create_job (
891+ @nospecialize (func),
892+ @nospecialize (types);
893+ entry_safepoint:: Bool = false ,
894+ method_table= test_method_table,
895+ kwargs... ,
896+ )
897+ config_kwargs, kwargs = split_kwargs (kwargs, GPUCompiler. CONFIG_KWARGS)
898+ source = methodinstance (
899+ typeof (func), Base. to_tuple_type (types), Base. get_world_counter ()
900+ )
901+ target = NativeCompilerTarget ()
902+ params = CompilerParams (entry_safepoint, method_table)
903+ config = CompilerConfig (
904+ target, params; kernel= false , libraries= false , toplevel= true , config_kwargs...
905+ )
906+ return CompilerJob (source, config), kwargs
907+ end
908+
909+ using Enzyme
910+ ReactantInter = Enzyme. Compiler. Interpreter. EnzymeInterpreter{
911+ typeof (Reactant. set_reactant_abi)
912+ }
913+
914+ GPUCompiler. get_interpreter (:: NativeCompilerJob ) = Reactant. ReactantInterpreter ()
915+
916+
917+ function CC. optimize (
918+ interp:: ReactantInter , opt:: CC.OptimizationState , caller:: CC.InferenceResult
919+ )
920+ CC. @timeit " optimizer" ir = CC. run_passes_ipo_safe (opt. src, opt, caller)
921+ CC. ipo_dataflow_analysis! (interp, ir, caller)
922+
923+ mi = caller. linfo
924+ if false && ! (
925+ is_reactant_method (mi) || (
926+ mi. def. sig isa DataType &&
927+ ! should_rewrite_invoke (
928+ mi. def. sig. parameters[1 ], Tuple{mi. def. sig. parameters[2 : end ]. .. }
929+ )
930+ )
931+ )
932+ @info ir
933+ ir, has_changed = rewrite_insts! (ir, interp, false )
934+ @info ir
935+ has_changed && @info " rewrite instruction $mi "
936+ end
937+
938+
939+ return CC. finish (interp, opt, ir, caller)
940+ end
941+
942+ function call_with_reactant (@nospecialize (args... ))
943+ f = args[1 ]
944+ types = typeof .(args[2 : end ])
945+
946+ job, meta = Reactant. create_job (f, types; validate= false )
947+ llvm_module, meta_ = Reactant. JuliaContext () do ctx
948+ GPUCompiler. compile (:llvm , job)
949+ end
950+ mm = meta_. compiled[job. source]
951+ @error mm. ci. def types args
952+ expr = Expr (
953+ :call ,
954+ GlobalRef (Base, :llvmcall ),
955+ (string (llvm_module), mm. specfunc),
956+ mm. ci. rettype,
957+ Tuple{types... },
958+ args[2 : end ]. .. ,
959+ )
960+ # TODO : replace with a generated function
961+ @eval $ expr
962+ end
0 commit comments