3939
4040function call_with_reactant end
4141
42- function rewrite_inst (inst)
43- @show inst
44- if Meta. isexpr (inst, :call )
45- rep = Expr (:call , call_with_reactant, inst. args... )
46- @show rep
47- return rep
48- end
49- return inst
42+ # generate a LineInfoNode for the current source code location
43+ macro LineInfoNode (method)
44+ Core. LineInfoNode (__module__, method, __source__. file, Int32 (__source__. line), Int32 (0 ))
5045end
5146
52- function call_with_reactant_generator (world:: UInt , source:: LineNumberNode , @nospecialize (F:: Type ), @nospecialize (N:: Int ), self, @nospecialize (f:: Type ), @nospecialize (args))
47+
48+ const REDUB_ARGUMENTS_NAME = gensym (" redub_arguments" )
49+
50+ function call_with_reactant_generator (world:: UInt , source:: LineNumberNode , self, @nospecialize (args))
5351 @nospecialize
54- @show f, args
52+
53+ @show args
5554
56- stub = Core. GeneratedFunctionStub (identity, Core. svec (:methodinstance , :f , :args ), Core. svec ())
55+ stub = Core. GeneratedFunctionStub (identity, Core. svec (:call_with_reactant , REDUB_ARGUMENTS_NAME ), Core. svec ())
5756
5857 # look up the method match
59- method_error = :(throw (MethodError (f, args, $ world)))
58+ builtin_error = :(throw (AssertionError (" Unsupported call_with_reactant of builtin $args " )))
59+
60+ if args[1 ] <: Core.Builtin
61+ return stub (world, source, builtin_error)
62+ end
63+
64+ method_error = :(throw (MethodError (args[1 ], args[2 : end ], $ world)))
6065
6166 interp = ReactantInterpreter (; world)
6267
63- mt = interp. method_table
68+ sig = Tuple{args... }
69+ lookup_result = Core. Compiler. findall (sig, Core. Compiler. method_table (interp)). matches
70+
71+ if lookup_result === nothing || lookup_result === missing
72+ return stub (world, source, method_error)
73+ end
74+
75+ matches = lookup_result. matches
6476
65- sig = Tuple{F, args... }
66- min_world = Ref {UInt} (typemin (UInt))
67- max_world = Ref {UInt} (typemax (UInt))
68- match = ccall (:jl_gf_invoke_lookup_worlds , Any,
69- (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
70- sig, mt, world, min_world, max_world)
71- match === nothing && return stub (world, source, method_error)
77+ if length (matches) != 1
78+ return stub (world, source, method_error)
79+ end
7280
81+ match = matches[1 ]:: Core.MethodMatch
82+
7383 # look up the method and code instance
7484 mi = ccall (:jl_specializations_get_linfo , Ref{Core. MethodInstance},
7585 (Any, Any, Any), match. method, match. spec_types, match. sparams)
76-
86+
7787 result = Core. Compiler. InferenceResult (mi, Core. Compiler. typeinf_lattice (interp))
78- frame = Core. Compiler. InferenceState (result, #= cache_mode=# :global , interp)
88+ src = Core. Compiler. retrieve_code_info (mi, world)
89+
90+ # prepare a new code info
91+ code_info = copy (src)
92+ method = match. method
93+ static_params = match. sparams
94+ signature = sig
95+ is_invoke = args[1 ] === typeof (Core. invoke)
96+
97+ # propagate edge metadata
98+ code_info. edges = Core. MethodInstance[mi]
99+ code_info. min_world = lookup_result. valid_worlds. min_world
100+ code_info. max_world = lookup_result. valid_worlds. max_world
101+
102+ code_info. slotnames = Any[:call_with_reactant , REDUB_ARGUMENTS_NAME, code_info. slotnames... ]
103+ code_info. slotflags = UInt8[0x00 , 0x00 , code_info. slotflags... ]
104+ # code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...]
105+ # code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...]
106+ n_prepended_slots = 2
107+ overdub_args_slot = Core. SlotNumber (n_prepended_slots)
108+
109+ # For the sake of convenience, the rest of this pass will translate `code_info`'s fields
110+ # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
111+ # the end of the pass, we'll reset `code_info` fields accordingly.
112+ overdubbed_code = Any[]
113+ overdubbed_codelocs = Int32[]
114+
115+ # destructure the generated argument slots into the overdubbed method's argument slots.
116+ n_actual_args = fieldcount (signature)
117+ n_method_args = Int (method. nargs)
118+ offset = 1
119+ fn_args = Any[]
120+ for i in 1 : n_method_args
121+ if is_invoke && (i == 1 || i == 2 )
122+ # With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args.
123+ # In the first loop iteration, we should skip invoke and process f.
124+ # In the second loop iteration, we should skip the Tuple type and process args[1].
125+ offset += 1
126+ end
127+ slot = i + n_prepended_slots
128+ actual_argument = Expr (:call , Core. GlobalRef (Core, :getfield ), overdub_args_slot, offset)
129+ push! (overdubbed_code, :($ (Core. SlotNumber (slot)) = $ actual_argument))
130+ push! (overdubbed_codelocs, code_info. codelocs[1 ])
131+ code_info. slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set
132+ offset += 1
133+
134+ # push!(overdubbed_code, actual_argument)
135+ push! (fn_args, Core. SSAValue (length (overdubbed_code)))
136+ end
137+
138+ # If `method` is a varargs method, we have to restructure the original method call's
139+ # trailing arguments into a tuple and assign that tuple to the expected argument slot.
140+ if method. isva
141+ if ! isempty (overdubbed_code)
142+ # remove the final slot reassignment leftover from the previous destructuring
143+ pop! (overdubbed_code)
144+ pop! (overdubbed_codelocs)
145+ pop! (fn_args)
146+ end
147+ trailing_arguments = Expr (:call , Core. GlobalRef (Core, :tuple ))
148+ for i in n_method_args: n_actual_args
149+ push! (overdubbed_code, Expr (:call , Core. GlobalRef (Core, :getfield ), overdub_args_slot, offset - 1 ))
150+ push! (overdubbed_codelocs, code_info. codelocs[1 ])
151+ push! (trailing_arguments. args, Core. SSAValue (length (overdubbed_code)))
152+ offset += 1
153+ end
154+ push! (overdubbed_code, Expr (:(= ), Core. SlotNumber (n_method_args + n_prepended_slots), trailing_arguments))
155+ push! (overdubbed_codelocs, code_info. codelocs[1 ])
156+ push! (fn_args, Core. SSAValue (length (overdubbed_code)))
157+ end
158+
159+ #= == finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#
160+
161+ # substitute static parameters, offset slot numbers by number of added slots, and
162+ # offset statement indices by the number of additional statements
163+ @show code_info. code
164+
165+ @show n_prepended_slots
166+ Base. Meta. partially_inline! (code_info. code, fn_args, method. sig, Any[static_params... ],
167+ n_prepended_slots, length (overdubbed_code), :propagate )
168+ @show code_info. code
169+
170+ # callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...)
171+ # push!(overdubbed_code, callexpr)
172+ # push!(overdubbed_codelocs, code_info.codelocs[1])
173+
174+ # push!(new_ci.code, Core.Compiler.ReturnNode(Core.SSAValue(length(overdubbed_code))))
175+ # push!(overdubbed_codelocs, code_info.codelocs[1])
176+
177+ # original_code_start_index = length(overdubbed_code) + 1
178+
179+ append! (overdubbed_code, code_info. code)
180+ append! (overdubbed_codelocs, code_info. codelocs)
181+
182+ @show overdubbed_code
183+
184+ for i in eachindex (overdubbed_code)
185+ prev = overdubbed_code[i]
186+ if Base. Meta. isexpr (prev, :call )
187+ @show prev
188+ @show prev. args[1 ]
189+ @show prev. args[1 ] isa Core. IntrinsicFunction
190+ if ! (prev. args[1 ] isa Core. IntrinsicFunction)
191+ overdubbed_code[i] = Expr (:call , GlobalRef (Reactant, :call_with_reactant ), prev. args... )
192+ @show " post" , overdubbed_code[i]
193+ end
194+ end
195+ end
196+
197+ #= == set `code_info`/`reflection` fields accordingly ===#
198+
199+ if code_info. method_for_inference_limit_heuristics === nothing
200+ code_info. method_for_inference_limit_heuristics = method
201+ end
202+
203+ code_info. code = overdubbed_code
204+ code_info. codelocs = overdubbed_codelocs
205+ code_info. ssavaluetypes = length (overdubbed_code)
206+ code_info. ssaflags = [0x00 for _ in 1 : length (overdubbed_code)] # XXX we need to copy flags that are set for the original code
207+ self_result = Core. Compiler. InferenceResult (self_mi, Core. Compiler. typeinf_lattice (interp))
208+
209+ @show code_info
210+
211+ @show self
212+ self_meths = Base. _methods_by_ftype (Tuple{self, Vararg{Any}}, - 1 , world)
213+ @show self_meths
214+ self_method = (self_meths[1 ]:: Core.MethodMatch ). method
215+ self_mi = Core. Compiler. specialize_method (self_method, Tuple{typeof (Reactant. call_with_reactant), sig. parameters... }, Core. svec ())
216+ @show self_mi
217+ self_result = Core. Compiler. InferenceResult (self_mi, Core. Compiler. typeinf_lattice (interp))
218+ frame = Core. Compiler. InferenceState (self_result, code_info, #= cache_mode=# :global , interp)
79219 @assert frame != = nothing
80220 Core. Compiler. typeinf (interp, frame)
81221 @assert Core. Compiler. is_inferred (frame)
@@ -85,36 +225,37 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nosp
85225 # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
86226 # else
87227 opt = Core. Compiler. OptimizationState (frame, interp)
228+
229+ ir = opt. src
230+ @show ir
231+ for (i, stmt) in enumerate (ir. stmts)
232+ @show stmt
233+
234+ end
235+
236+ @show ir
237+
88238 caller = frame. result
89239 @static if VERSION < v " 1.11-"
90- ir = Core. Compiler. run_passes (opt . src , opt, caller)
240+ ir = Core. Compiler. run_passes (ir , opt, caller)
91241 else
92- ir = Core. Compiler. run_passes_ipo_safe (opt . src , opt, caller)
242+ ir = Core. Compiler. run_passes_ipo_safe (ir , opt, caller)
93243 Core. Compiler. ipo_dataflow_analysis! (interp, opt, ir, caller)
94244 end
95- @show ir
96- for (i, inst) in enumerate (ir. stmts)
97- @static if VERSION < v " 1.11"
98- Core. Compiler. setindex! (ir. stmts[i], rewrite_inst (inst[:inst ]), :inst )
99- else
100- Core. Compiler. setindex! (ir. stmts[i], rewrite_inst (inst[:stmt ]), :stmt )
101- end
102- end
103- @show ir
104245 Core. Compiler. finish (interp, opt, ir, caller)
246+
105247 src = Core. Compiler. ir_to_codeinf! (opt)
106248 # end
107249
108- new_ci = copy (src)
109- new_ci. slotnames = Symbol[Symbol (" #self#" ), :f , :args ]
110- new_ci. edges = Core. MethodInstance[mi]
111- new_ci. min_world = min_world[]
112- new_ci. max_world = max_world[]
250+ src = copy (src)
251+ src. ssavaluetypes = length (src. code)
113252
114- return new_ci
253+ @show src
254+
255+ return src
115256end
116257
117- @eval function call_with_reactant (f :: F , args :: Vararg{Any, N} ) where {F, N}
258+ @eval function call_with_reactant ($ REDUB_ARGUMENTS_NAME ... )
118259 $ (Expr (:meta , :generated_only ))
119260 $ (Expr (:meta , :generated , call_with_reactant_generator))
120261end
@@ -214,12 +355,10 @@ function make_mlir_fn(
214355
215356 # TODO replace with `Base.invoke_within` if julia#52964 lands
216357 # TODO fix it for kwargs
217- oc = call_with_reactant # Core.OpaqueClosure(ir)
218-
219358 if f === Reactant. apply
220- oc (f, traced_args[1 ], (traced_args[2 : end ]. .. ,))
359+ call_with_reactant (f, traced_args[1 ], (traced_args[2 : end ]. .. ,))
221360 else
222- oc (f, traced_args... )
361+ call_with_reactant (f, traced_args... )
223362 end
224363 end
225364
0 commit comments