@@ -37,6 +37,88 @@ function apply(f, args...; kwargs...)
3737 return f (args... ; kwargs... )
3838end
3939
40+ function call_with_reactant end
41+
42+ function rewrite_inst (inst)
43+ @show inst
44+ if Meta. isexpr (inst, :call )
45+ rep = Expr (:call , call_with_reactant, inst. args... )
46+ @show rep
47+ return rep
48+ end
49+ return inst
50+ end
51+
52+ function call_with_reactant_generator (world:: UInt , source:: LineNumberNode , @nospecialize (F:: Type ), @nospecialize (N:: Int ), self, @nospecialize (f:: Type ), @nospecialize (args))
53+ @nospecialize
54+ @show f, args
55+
56+ stub = Core. GeneratedFunctionStub (identity, Core. svec (:methodinstance , :f , :args ), Core. svec ())
57+
58+ # look up the method match
59+ method_error = :(throw (MethodError (f, args, $ world)))
60+
61+ interp = ReactantInterpreter (; world)
62+
63+ mt = interp. method_table
64+
65+ sig = Tuple{F, args... }
66+ min_world = Ref {UInt} (typemin (UInt))
67+ max_world = Ref {UInt} (typemax (UInt))
68+ match = ccall (:jl_gf_invoke_lookup_worlds , Any,
69+ (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
70+ sig, mt, world, min_world, max_world)
71+ match === nothing && return stub (world, source, method_error)
72+
73+ # look up the method and code instance
74+ mi = ccall (:jl_specializations_get_linfo , Ref{Core. MethodInstance},
75+ (Any, Any, Any), match. method, match. spec_types, match. sparams)
76+
77+ result = Core. Compiler. InferenceResult (mi, Core. Compiler. typeinf_lattice (interp))
78+ frame = Core. Compiler. InferenceState (result, #= cache_mode=# :global , interp)
79+ @assert frame != = nothing
80+ Core. Compiler. typeinf (interp, frame)
81+ @assert Core. Compiler. is_inferred (frame)
82+
83+ # if Core.Compiler.result_is_constabi(interp, frame.result)
84+ # rt = frame.result.result::Core.Compiler.Const
85+ # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
86+ # else
87+ opt = Core. Compiler. OptimizationState (frame, interp)
88+ caller = frame. result
89+ @static if VERSION < v " 1.11-"
90+ ir = Core. Compiler. run_passes (opt. src, opt, caller)
91+ else
92+ ir = Core. Compiler. run_passes_ipo_safe (opt. src, opt, caller)
93+ Core. Compiler. ipo_dataflow_analysis! (interp, opt, ir, caller)
94+ end
95+ @show ir
96+ for (i, inst) in enumerate (ir. stmts)
97+ @static if VERSION < v " 1.11"
98+ Core. Compiler. setindex! (ir. stmts[i], rewrite_inst (inst[:inst ]), :inst )
99+ else
100+ Core. Compiler. setindex! (ir. stmts[i], rewrite_inst (inst[:stmt ]), :stmt )
101+ end
102+ end
103+ @show ir
104+ Core. Compiler. finish (interp, opt, ir, caller)
105+ src = Core. Compiler. ir_to_codeinf! (opt)
106+ # end
107+
108+ new_ci = copy (src)
109+ new_ci. slotnames = Symbol[Symbol (" #self#" ), :f , :args ]
110+ new_ci. edges = Core. MethodInstance[mi]
111+ new_ci. min_world = min_world[]
112+ new_ci. max_world = max_world[]
113+
114+ return new_ci
115+ end
116+
117+ @eval function call_with_reactant (f:: F , args:: Vararg{Any, N} ) where {F, N}
118+ $ (Expr (:meta , :generated_only ))
119+ $ (Expr (:meta , :generated , call_with_reactant_generator))
120+ end
121+
40122function make_mlir_fn (
41123 f,
42124 args,
@@ -131,36 +213,13 @@ function make_mlir_fn(
131213 interp = ReactantInterpreter ()
132214
133215 # TODO replace with `Base.invoke_within` if julia#52964 lands
134- # TODO fix it for kwargs
135- ircoderes = Base. code_ircode (f, map (typeof, traced_args); interp)
136-
137- if length (ircoderes) != 1
138- throw (
139- AssertionError (
140- " Could not find unique ircode for $f $traced_args , found $ircoderes "
141- ),
142- )
143- end
144- ir, ty = ircoderes[1 ]
145- oc = Core. OpaqueClosure (ir)
216+ # TODO fix it for kwargs
217+ oc = call_with_reactant # Core.OpaqueClosure(ir)
146218
147219 if f === Reactant. apply
148- oc (traced_args[1 ], (traced_args[2 : end ]. .. ,))
220+ oc (f, traced_args[1 ], (traced_args[2 : end ]. .. ,))
149221 else
150- if (length (traced_args) + 1 != length (ir. argtypes)) || (
151- length (traced_args) > 0 &&
152- length (ir. argtypes) > 0 &&
153- ! (last (ir. argtypes) isa Core. Const) &&
154- last (ir. argtypes) != typeof (traced_args[end ])
155- )
156- @assert ir. argtypes[end ] <: Tuple
157- oc (
158- traced_args[1 : (length (ir. argtypes) - 2 )]. .. ,
159- (traced_args[(length (ir. argtypes) - 1 ): end ]. .. ,),
160- )
161- else
162- oc (traced_args... )
163- end
222+ oc (f, traced_args... )
164223 end
165224 end
166225
0 commit comments