499499
500500NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams}
501501
502-
503502GPUCompiler. can_safepoint (@nospecialize (job:: NativeCompilerJob )) = false
504503GPUCompiler. can_throw (@nospecialize (job:: NativeCompilerJob )) = true
505504GPUCompiler. needs_byval (@nospecialize (job:: NativeCompilerJob )) = false
@@ -515,9 +514,12 @@ ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{
515514 typeof (Reactant. set_reactant_abi)
516515}
517516
518- GPUCompiler. get_interpreter (@nospecialize (job:: NativeCompilerJob )) = Reactant. ReactantInterpreter (; world = job. world)
519- GPUCompiler. method_table (@nospecialize (job:: NativeCompilerJob )) = CC. method_table (GPUCompiler. get_interpreter (job))
520-
517+ function GPUCompiler. get_interpreter (@nospecialize (job:: NativeCompilerJob ))
518+ return Reactant. ReactantInterpreter (; world= job. world)
519+ end
520+ function GPUCompiler. method_table (@nospecialize (job:: NativeCompilerJob ))
521+ return CC. method_table (GPUCompiler. get_interpreter (job))
522+ end
521523
522524function CC. optimize (
523525 interp:: ReactantInter , opt:: CC.OptimizationState , caller:: CC.InferenceResult
@@ -526,14 +528,16 @@ function CC.optimize(
526528 CC. ipo_dataflow_analysis! (interp, ir, caller)
527529
528530 mi = caller. linfo
529- if false && mi in mi_set && ! (
530- is_reactant_method (mi) || (
531- mi. def. sig isa DataType &&
532- ! should_rewrite_invoke (
533- mi. def. sig. parameters[1 ], Tuple{mi. def. sig. parameters[2 : end ]. .. }
531+ if false &&
532+ mi in mi_set &&
533+ ! (
534+ is_reactant_method (mi) || (
535+ mi. def. sig isa DataType &&
536+ ! should_rewrite_invoke (
537+ mi. def. sig. parameters[1 ], Tuple{mi. def. sig. parameters[2 : end ]. .. }
538+ )
534539 )
535540 )
536- )
537541 @info ir
538542 ir, has_changed = rewrite_insts! (ir, interp, false )
539543 @info ir
546550using GPUCompiler
547551CC = Core. Compiler
548552
549- function GPUCompiler. ci_cache_populate (interp:: Reactant.ReactantInter , cache:: CC.WorldView{CC.InternalCodeCache} , mi:: Core.MethodInstance , min_world:: UInt64 , max_world:: UInt64 )
553+ function GPUCompiler. ci_cache_populate (
554+ interp:: Reactant.ReactantInter ,
555+ cache:: CC.WorldView{CC.InternalCodeCache} ,
556+ mi:: Core.MethodInstance ,
557+ min_world:: UInt64 ,
558+ max_world:: UInt64 ,
559+ )
550560 @warn mi min_world max_world CC. get_inference_world (interp)
551- @invoke GPUCompiler. ci_cache_populate (interp:: CC.AbstractInterpreter , cache, mi, min_world, max_world)
561+ @invoke GPUCompiler. ci_cache_populate (
562+ interp:: CC.AbstractInterpreter , cache, mi, min_world, max_world
563+ )
552564end
553565
554566# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
@@ -577,7 +589,8 @@ function call_with_reactant_generator(
577589 fn = args[1 + offset_error]
578590
579591 if fn <: Core.Builtin
580- builtin_error = :(throw (AssertionError (" Unsupported call_with_reactant of builtin $fn " )))
592+ builtin_error =
593+ :(throw (AssertionError (" Unsupported call_with_reactant of builtin $fn " )))
581594 return stub (world, source, builtin_error)
582595 end
583596
@@ -591,17 +604,27 @@ function call_with_reactant_generator(
591604 rt = Union{}
592605 end
593606
594- source = GPUCompiler. methodinstance (fn, Base. to_tuple_type (args[2 + offset_error: end ]), world)
607+ source = GPUCompiler. methodinstance (
608+ fn, Base. to_tuple_type (args[(2 + offset_error): end ]), world
609+ )
595610 if source === nothing
596611 method_error = :(throw (
597- MethodError ($ REDUB_ARGUMENTS_NAME[1 + offset_error], $ REDUB_ARGUMENTS_NAME[2 + offset_error: end ], $ world)
612+ MethodError (
613+ $ REDUB_ARGUMENTS_NAME[1 + offset_error],
614+ $ REDUB_ARGUMENTS_NAME[(2 + offset_error): end ],
615+ $ world,
616+ ),
598617 ))
599618 return stub (world, source, method_error)
600- end
619+ end
601620 config = CompilerConfig (
602621 Reactant. NativeCompilerTarget (),
603- Reactant. CompilerParams ()
604- ; kernel= false , libraries= false , toplevel= true , validate= false , strip= true
622+ Reactant. CompilerParams ();
623+ kernel= false ,
624+ libraries= false ,
625+ toplevel= true ,
626+ validate= false ,
627+ strip= true ,
605628 )
606629
607630 job = GPUCompiler. CompilerJob (source, config, world)
@@ -612,11 +635,11 @@ function call_with_reactant_generator(
612635 mm = meta_. compiled[job. source]
613636 @warn typeof (mm)
614637 code_instance = mm. ci
615-
638+
616639 # CodeInfo placehold
617640 code_info = begin
618641 ir = CC. IRCode ()
619- src = ccall (:jl_new_code_info_uninit , Ref{CC. CodeInfo}, ());
642+ src = ccall (:jl_new_code_info_uninit , Ref{CC. CodeInfo}, ())
620643 src. slotnames = fill (:none , length (ir. argtypes) + 1 )
621644 src. slotflags = fill (zero (UInt8), length (ir. argtypes))
622645 src. slottypes = copy (ir. argtypes)
@@ -701,14 +724,16 @@ function call_with_reactant_generator(
701724 )
702725 end
703726
704- push_inst! (Expr (
705- :call ,
706- GlobalRef (Base, :llvmcall ),
707- (string (llvm_module), mm. specfunc),
708- rt,
709- Tuple{args[2 : end ]. .. },
710- fn_args... ,
711- ))
727+ push_inst! (
728+ Expr (
729+ :call ,
730+ GlobalRef (Base, :llvmcall ),
731+ (string (llvm_module), mm. specfunc),
732+ rt,
733+ Tuple{args[2 : end ]. .. },
734+ fn_args... ,
735+ ),
736+ )
712737
713738 push_inst! (Core. ReturnNode (Core. SSAValue (length (overdubbed_code))))
714739
730755 $ (Expr (:meta , :generated_only ))
731756 return $ (Expr (:meta , :generated , call_with_reactant_generator))
732757end
733-
734-
735-
736-
737-
0 commit comments