Skip to content

Commit 3b78329

Browse files
author
William Moses
committed
inf rec
1 parent e0a3ec4 commit 3b78329

File tree

3 files changed

+189
-50
lines changed

3 files changed

+189
-50
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ function compiler_cache(ctx::MLIR.IR.Context)
437437
return cache
438438
end
439439

440-
Reactant.@overlay function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
440+
Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
441441
@show "recufunction", f, tt
442442
res = Base.@lock CUDA.cufunction_lock begin
443443
# compile the function

src/Interpreter.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ import Core.Compiler:
2222
MethodResultPure
2323

2424

25-
Base.Experimental.@MethodTable REACTANT_METHOD_TABLE
25+
Base.Experimental.@MethodTable(REACTANT_METHOD_TABLE)
2626

27-
macro overlay(method_expr)
28-
def = splitdef(method_expr)
29-
def[:name] = Expr(:overlay, :(Reactant.REACTANT_METHOD_TABLE), def[:name])
30-
return esc(combinedef(def))
27+
function var"@reactant_override"(__source__::LineNumberNode, __module__::Module, def)
28+
return Base.Experimental.var"@overlay"(
29+
__source__, __module__, :(Reactant.REACTANT_METHOD_TABLE), def
30+
)
3131
end
3232

3333
function set_reactant_abi(

src/utils.jl

Lines changed: 183 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,43 +39,183 @@ end
3939

4040
function call_with_reactant end
4141

42-
function rewrite_inst(inst)
43-
@show inst
44-
if Meta.isexpr(inst, :call)
45-
rep = Expr(:call, call_with_reactant, inst.args...)
46-
@show rep
47-
return rep
48-
end
49-
return inst
42+
# generate a LineInfoNode for the current source code location
43+
macro LineInfoNode(method)
44+
Core.LineInfoNode(__module__, method, __source__.file, Int32(__source__.line), Int32(0))
5045
end
5146

52-
function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nospecialize(F::Type), @nospecialize(N::Int), self, @nospecialize(f::Type), @nospecialize(args))
47+
48+
const REDUB_ARGUMENTS_NAME = gensym("redub_arguments")
49+
50+
function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(args))
5351
@nospecialize
54-
@show f, args
52+
53+
@show args
5554

56-
stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :f, :args), Core.svec())
55+
stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec())
5756

5857
# look up the method match
59-
method_error = :(throw(MethodError(f, args, $world)))
58+
builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $args")))
59+
60+
if args[1] <: Core.Builtin
61+
return stub(world, source, builtin_error)
62+
end
63+
64+
method_error = :(throw(MethodError(args[1], args[2:end], $world)))
6065

6166
interp = ReactantInterpreter(; world)
6267

63-
mt = interp.method_table
68+
sig = Tuple{args...}
69+
lookup_result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp)).matches
70+
71+
if lookup_result === nothing || lookup_result === missing
72+
return stub(world, source, method_error)
73+
end
74+
75+
matches = lookup_result.matches
6476

65-
sig = Tuple{F, args...}
66-
min_world = Ref{UInt}(typemin(UInt))
67-
max_world = Ref{UInt}(typemax(UInt))
68-
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
69-
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
70-
sig, mt, world, min_world, max_world)
71-
match === nothing && return stub(world, source, method_error)
77+
if length(matches) != 1
78+
return stub(world, source, method_error)
79+
end
7280

81+
match = matches[1]::Core.MethodMatch
82+
7383
# look up the method and code instance
7484
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
7585
(Any, Any, Any), match.method, match.spec_types, match.sparams)
76-
86+
7787
result = Core.Compiler.InferenceResult(mi, Core.Compiler.typeinf_lattice(interp))
78-
frame = Core.Compiler.InferenceState(result, #=cache_mode=#:global, interp)
88+
src = Core.Compiler.retrieve_code_info(mi, world)
89+
90+
# prepare a new code info
91+
code_info = copy(src)
92+
method = match.method
93+
static_params = match.sparams
94+
signature = sig
95+
is_invoke = args[1] === typeof(Core.invoke)
96+
97+
# propagate edge metadata
98+
code_info.edges = Core.MethodInstance[mi]
99+
code_info.min_world = lookup_result.valid_worlds.min_world
100+
code_info.max_world = lookup_result.valid_worlds.max_world
101+
102+
code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...]
103+
code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...]
104+
#code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...]
105+
#code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...]
106+
n_prepended_slots = 2
107+
overdub_args_slot = Core.SlotNumber(n_prepended_slots)
108+
109+
# For the sake of convenience, the rest of this pass will translate `code_info`'s fields
110+
# into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
111+
# the end of the pass, we'll reset `code_info` fields accordingly.
112+
overdubbed_code = Any[]
113+
overdubbed_codelocs = Int32[]
114+
115+
# destructure the generated argument slots into the overdubbed method's argument slots.
116+
n_actual_args = fieldcount(signature)
117+
n_method_args = Int(method.nargs)
118+
offset = 1
119+
fn_args = Any[]
120+
for i in 1:n_method_args
121+
if is_invoke && (i == 1 || i == 2)
122+
# With an invoke call, we have: 1 is invoke, 2 is f, 3 is Tuple{}, 4... is args.
123+
# In the first loop iteration, we should skip invoke and process f.
124+
# In the second loop iteration, we should skip the Tuple type and process args[1].
125+
offset += 1
126+
end
127+
slot = i + n_prepended_slots
128+
actual_argument = Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset)
129+
push!(overdubbed_code, :($(Core.SlotNumber(slot)) = $actual_argument))
130+
push!(overdubbed_codelocs, code_info.codelocs[1])
131+
code_info.slotflags[slot] |= 0x02 # ensure this slotflag has the "assigned" bit set
132+
offset += 1
133+
134+
#push!(overdubbed_code, actual_argument)
135+
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
136+
end
137+
138+
# If `method` is a varargs method, we have to restructure the original method call's
139+
# trailing arguments into a tuple and assign that tuple to the expected argument slot.
140+
if method.isva
141+
if !isempty(overdubbed_code)
142+
# remove the final slot reassignment leftover from the previous destructuring
143+
pop!(overdubbed_code)
144+
pop!(overdubbed_codelocs)
145+
pop!(fn_args)
146+
end
147+
trailing_arguments = Expr(:call, Core.GlobalRef(Core, :tuple))
148+
for i in n_method_args:n_actual_args
149+
push!(overdubbed_code, Expr(:call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset - 1))
150+
push!(overdubbed_codelocs, code_info.codelocs[1])
151+
push!(trailing_arguments.args, Core.SSAValue(length(overdubbed_code)))
152+
offset += 1
153+
end
154+
push!(overdubbed_code, Expr(:(=), Core.SlotNumber(n_method_args + n_prepended_slots), trailing_arguments))
155+
push!(overdubbed_codelocs, code_info.codelocs[1])
156+
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
157+
end
158+
159+
#=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#
160+
161+
# substitute static parameters, offset slot numbers by number of added slots, and
162+
# offset statement indices by the number of additional statements
163+
@show code_info.code
164+
165+
@show n_prepended_slots
166+
Base.Meta.partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...],
167+
n_prepended_slots, length(overdubbed_code), :propagate)
168+
@show code_info.code
169+
170+
#callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...)
171+
#push!(overdubbed_code, callexpr)
172+
#push!(overdubbed_codelocs, code_info.codelocs[1])
173+
174+
#push!(new_ci.code, Core.Compiler.ReturnNode(Core.SSAValue(length(overdubbed_code))))
175+
#push!(overdubbed_codelocs, code_info.codelocs[1])
176+
177+
# original_code_start_index = length(overdubbed_code) + 1
178+
179+
append!(overdubbed_code, code_info.code)
180+
append!(overdubbed_codelocs, code_info.codelocs)
181+
182+
@show overdubbed_code
183+
184+
for i in eachindex(overdubbed_code)
185+
prev = overdubbed_code[i]
186+
if Base.Meta.isexpr(prev, :call)
187+
@show prev
188+
@show prev.args[1]
189+
@show prev.args[1] isa Core.IntrinsicFunction
190+
if !(prev.args[1] isa Core.IntrinsicFunction)
191+
overdubbed_code[i] = Expr(:call, GlobalRef(Reactant, :call_with_reactant), prev.args...)
192+
@show "post", overdubbed_code[i]
193+
end
194+
end
195+
end
196+
197+
#=== set `code_info`/`reflection` fields accordingly ===#
198+
199+
if code_info.method_for_inference_limit_heuristics === nothing
200+
code_info.method_for_inference_limit_heuristics = method
201+
end
202+
203+
code_info.code = overdubbed_code
204+
code_info.codelocs = overdubbed_codelocs
205+
code_info.ssavaluetypes = length(overdubbed_code)
206+
code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code
207+
self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp))
208+
209+
@show code_info
210+
211+
@show self
212+
self_meths = Base._methods_by_ftype(Tuple{self, Vararg{Any}}, -1, world)
213+
@show self_meths
214+
self_method = (self_meths[1]::Core.MethodMatch).method
215+
self_mi = Core.Compiler.specialize_method(self_method, Tuple{typeof(Reactant.call_with_reactant), sig.parameters...}, Core.svec())
216+
@show self_mi
217+
self_result = Core.Compiler.InferenceResult(self_mi, Core.Compiler.typeinf_lattice(interp))
218+
frame = Core.Compiler.InferenceState(self_result, code_info, #=cache_mode=#:global, interp)
79219
@assert frame !== nothing
80220
Core.Compiler.typeinf(interp, frame)
81221
@assert Core.Compiler.is_inferred(frame)
@@ -85,36 +225,37 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, @nosp
85225
# src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val)
86226
#else
87227
opt = Core.Compiler.OptimizationState(frame, interp)
228+
229+
ir = opt.src
230+
@show ir
231+
for (i, stmt) in enumerate(ir.stmts)
232+
@show stmt
233+
234+
end
235+
236+
@show ir
237+
88238
caller = frame.result
89239
@static if VERSION < v"1.11-"
90-
ir = Core.Compiler.run_passes(opt.src, opt, caller)
240+
ir = Core.Compiler.run_passes(ir, opt, caller)
91241
else
92-
ir = Core.Compiler.run_passes_ipo_safe(opt.src, opt, caller)
242+
ir = Core.Compiler.run_passes_ipo_safe(ir, opt, caller)
93243
Core.Compiler.ipo_dataflow_analysis!(interp, opt, ir, caller)
94244
end
95-
@show ir
96-
for (i, inst) in enumerate(ir.stmts)
97-
@static if VERSION < v"1.11"
98-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:inst]), :inst)
99-
else
100-
Core.Compiler.setindex!(ir.stmts[i], rewrite_inst(inst[:stmt]), :stmt)
101-
end
102-
end
103-
@show ir
104245
Core.Compiler.finish(interp, opt, ir, caller)
246+
105247
src = Core.Compiler.ir_to_codeinf!(opt)
106248
#end
107249

108-
new_ci = copy(src)
109-
new_ci.slotnames = Symbol[Symbol("#self#"), :f, :args]
110-
new_ci.edges = Core.MethodInstance[mi]
111-
new_ci.min_world = min_world[]
112-
new_ci.max_world = max_world[]
250+
src = copy(src)
251+
src.ssavaluetypes = length(src.code)
113252

114-
return new_ci
253+
@show src
254+
255+
return src
115256
end
116257

117-
@eval function call_with_reactant(f::F, args::Vararg{Any, N}) where {F, N}
258+
@eval function call_with_reactant($REDUB_ARGUMENTS_NAME...)
118259
$(Expr(:meta, :generated_only))
119260
$(Expr(:meta, :generated, call_with_reactant_generator))
120261
end
@@ -214,12 +355,10 @@ function make_mlir_fn(
214355

215356
# TODO replace with `Base.invoke_within` if julia#52964 lands
216357
# TODO fix it for kwargs
217-
oc = call_with_reactant # Core.OpaqueClosure(ir)
218-
219358
if f === Reactant.apply
220-
oc(f, traced_args[1], (traced_args[2:end]...,))
359+
call_with_reactant(f, traced_args[1], (traced_args[2:end]...,))
221360
else
222-
oc(f, traced_args...)
361+
call_with_reactant(f, traced_args...)
223362
end
224363
end
225364

0 commit comments

Comments
 (0)