Skip to content

Commit f6f7cc1

Browse files
authored
Merge pull request #135 from JuliaDiff/avi/update
adjust to the upstream irinterp refactoring
2 parents e0b24da + 651e041 commit f6f7cc1

File tree

3 files changed

+59
-72
lines changed

3 files changed

+59
-72
lines changed

src/codegen/forward_demand.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,10 +189,10 @@ Internal method which generates the code for forward mode diffentiation
189189
190190
191191
- `ir` the IR being differnetation
192-
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
192+
- `to_diff`: collection of all SSA values for which the derivative is to be taken,
193193
paired with the order (first deriviative, second derivative etc)
194194
195-
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
195+
- `visit_custom!(ir, stmt, order::Int, recurse::Bool)`:
196196
decides if the custom `transform!` should be applied to a `stmt` or not
197197
Default: `false` for all statements
198198
- `transform!(ir, ssa::SSAValue, order::Int)` mutates `ir` to do a custom tranformation.
@@ -289,10 +289,12 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue, Int}};
289289
end
290290

291291

292-
function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
292+
function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::MethodInstance,
293+
to_diff::Vector{Pair{SSAValue, Int}}; kwargs...)
293294
forward_diff_no_inf!(ir, to_diff; kwargs...)
294295

295296
# Step 3: Re-inference
297+
296298
ir = compact!(ir)
297299

298300
extra_reprocess = CC.BitSet()
@@ -302,9 +304,13 @@ function forward_diff!(ir::IRCode, interp, mi::MethodInstance, world, to_diff::V
302304
end
303305
end
304306

305-
interp′ = enable_reinference(interp)
306-
irsv = IRInterpretationState(interp′, ir, mi, world, ir.argtypes[1:mi.def.nargs])
307-
rt = CC._ir_abstract_constant_propagation(interp′, irsv; extra_reprocess)
307+
method_info = CC.MethodInfo(src)
308+
argtypes = ir.argtypes[1:mi.def.nargs]
309+
world = CC.get_world_counter(interp)
310+
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
311+
rt = CC._ir_abstract_constant_propagation(interp, irsv; extra_reprocess)
312+
313+
ir = compact!(ir)
308314

309315
return ir
310316
end

src/stage2/forward.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
66
interp = ADInterpreter(; forward=true, backward=false)
77
match = Base._which(tt)
88
frame = Core.Compiler.typeinf_frame(interp, match.method, match.spec_types, match.sparams, #=run_optimizer=#true)
9+
mi = frame.linfo
910

10-
ir = copy((interp.opt[0][frame.linfo].inferred).ir::IRCode)
11+
src = CC.copy(interp.unopt[0][mi].src)
12+
ir = CC.copy((@atomic :monotonic interp.opt[0][mi].inferred).ir::IRCode)
1113

1214
# Find all Return Nodes
1315
vals = Pair{SSAValue, Int}[]
@@ -43,10 +45,7 @@ function dontuse_nth_order_forward_stage2(tt::Type, order::Int=1)
4345
return insert_node!(ir, SSAValue(1), NewInstruction(Expr(:call, ∂xⁿ{order}(), arg), typeof(∂xⁿ{order}()(1.0))))
4446
end
4547

48+
ir = forward_diff!(interp, ir, src, mi, vals; visit_custom!, transform!)
4649

47-
irsv = CC.IRInterpretationState(interp, ir, frame.linfo, CC.get_world_counter(interp), ir.argtypes[1:frame.linfo.def.nargs])
48-
ir = forward_diff!(ir, interp, frame.linfo, CC.get_world_counter(interp), vals; visit_custom!, transform!)
49-
50-
ir = compact!(ir)
5150
return OpaqueClosure(ir)
5251
end

src/stage2/interpreter.jl

Lines changed: 43 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ function Compiler3.get_codeinstance(graph::ADGraph, cursor::ADCursor)
3131
end
3232
=#
3333

34-
using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
35-
InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
34+
using Core: MethodInstance, CodeInstance
35+
using .CC: AbstractInterpreter, ArgInfo, Effects, InferenceResult, InferenceState,
36+
IRInterpretationState, NativeInterpreter, OptimizationState, StmtInfo, WorldRange
3637

3738
const OptCache = Dict{MethodInstance, CodeInstance}
3839
const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu.InferredSource}
@@ -42,7 +43,6 @@ struct ADInterpreter <: AbstractInterpreter
4243
# Modes settings
4344
forward::Bool
4445
backward::Bool
45-
reinference::Bool
4646

4747
# This cache is stratified by AD nesting level. Depending on the
4848
# nesting level of the derivative, The AD primitives may behave
@@ -63,7 +63,6 @@ struct ADInterpreter <: AbstractInterpreter
6363
return new(
6464
#=forward::Bool=#false,
6565
#=backward::Bool=#true,
66-
#=reinference::Bool=#false,
6766
#=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
6867
#=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
6968
#=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
@@ -74,14 +73,13 @@ struct ADInterpreter <: AbstractInterpreter
7473
function ADInterpreter(interp::ADInterpreter = _ADInterpreter();
7574
forward::Bool = interp.forward,
7675
backward::Bool = interp.backward,
77-
reinference::Bool = interp.reinference,
7876
opt::OffsetVector{OptCache} = interp.opt,
7977
unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
8078
transformed::OffsetVector{OptCache} = interp.transformed,
8179
native_interpreter::NativeInterpreter = interp.native_interpreter,
8280
current_level::Int = interp.current_level,
8381
remarks::OffsetVector{RemarksCache} = interp.remarks)
84-
return new(forward, backward, reinference, opt, unopt, transformed, native_interpreter, current_level, remarks)
82+
return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
8583
end
8684
end
8785

@@ -90,8 +88,6 @@ raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level +
9088
lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - 1)
9189

9290
disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)
93-
disable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=false)
94-
enable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=true)
9591

9692
function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
9793
@show curs
@@ -120,7 +116,7 @@ function Cthulhu.lookup(interp::ADInterpreter, curs::ADCursor, optimize::Bool; a
120116
opt = codeinst.inferred
121117
if opt !== nothing
122118
opt = opt::Cthulhu.OptimizedSource
123-
src = Core.Compiler.copy(opt.ir)
119+
src = CC.copy(opt.ir)
124120
codeinf = opt.src
125121
infos = src.stmts.info
126122
slottypes = src.argtypes
@@ -162,7 +158,6 @@ function Cthulhu.custom_toggles(interp::ADInterpreter)
162158
end
163159

164160
# TODO: Something is going very wrong here
165-
using Core.Compiler: Effects, OptimizationState
166161
function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Bool)
167162
if haskey(interp.unopt[0], mi)
168163
return interp.unopt[0][mi].effects
@@ -171,7 +166,7 @@ function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Boo
171166
end
172167
end
173168

174-
function Core.Compiler.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
169+
function CC.is_same_frame(interp::ADInterpreter, linfo::MethodInstance, frame::InferenceState)
175170
linfo === frame.linfo || return false
176171
return interp.current_level === frame.interp.current_level
177172
end
@@ -224,7 +219,7 @@ function Cthulhu.navigate(curs::ADCursor, callsite::Cthulhu.Callsite)
224219
return ADCursor(curs.level, Cthulhu.get_mi(callsite))
225220
end
226221

227-
function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Compiler.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
222+
function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::CC.CallInfo), argtypes::Cthulhu.ArgTypes, @nospecialize(rt), optimize::Bool)
228223
if isa(info, RecurseInfo)
229224
newargtypes = argtypes[2:end]
230225
callinfos = Cthulhu.process_info(interp, info.info, newargtypes, Cthulhu.unwrapType(widenconst(rt)), optimize)
@@ -252,33 +247,33 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
252247
elseif isa(info, CompClosInfo)
253248
return Any[CompClosCallInfo(rt)]
254249
end
255-
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, Core.Compiler.CallInfo, Cthulhu.ArgTypes, Any, Bool},
250+
return invoke(Cthulhu.process_info, Tuple{AbstractInterpreter, CC.CallInfo, Cthulhu.ArgTypes, Any, Bool},
256251
interp, info, argtypes, rt, optimize)
257252
end
258253

259-
Core.Compiler.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
260-
Core.Compiler.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
261-
Core.Compiler.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
262-
Core.Compiler.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)
254+
CC.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
255+
CC.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
256+
CC.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)
257+
CC.get_inference_cache(ei::ADInterpreter) = get_inference_cache(ei.native_interpreter)
263258

264259
# No need to do any locking since we're not putting our results into the runtime cache
265-
Core.Compiler.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
266-
Core.Compiler.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
260+
CC.lock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
261+
CC.unlock_mi_inference(ei::ADInterpreter, mi::MethodInstance) = nothing
267262

268263
struct CodeInfoView
269264
d::Dict{MethodInstance, Any}
270265
end
271266

272-
function Core.Compiler.code_cache(ei::ADInterpreter)
267+
function CC.code_cache(ei::ADInterpreter)
273268
while ei.current_level > lastindex(ei.opt)
274269
push!(ei.opt, Dict{MethodInstance, Any}())
275270
end
276271
ei.opt[ei.current_level]
277272
end
278-
Core.Compiler.may_optimize(ei::ADInterpreter) = true
279-
Core.Compiler.may_compress(ei::ADInterpreter) = false
280-
Core.Compiler.may_discard_trees(ei::ADInterpreter) = false
281-
function Core.Compiler.get(view::CodeInfoView, mi::MethodInstance, default)
273+
CC.may_optimize(ei::ADInterpreter) = true
274+
CC.may_compress(ei::ADInterpreter) = false
275+
CC.may_discard_trees(ei::ADInterpreter) = false
276+
function CC.get(view::CodeInfoView, mi::MethodInstance, default)
282277
r = get(view.d, mi, nothing)
283278
if r === nothing
284279
return default
@@ -298,23 +293,23 @@ end
298293
Cthulhu.get_remarks(interp::ADInterpreter, key::Union{MethodInstance,InferenceResult}) = get(interp.remarks[interp.current_level], key, nothing)
299294

300295
#=
301-
function Core.Compiler.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
296+
function CC.const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
302297
return true
303298
end
304299
=#
305300

306-
function Core.Compiler.finish(state::InferenceState, interp::ADInterpreter)
307-
res = @invoke Core.Compiler.finish(state::InferenceState, interp::AbstractInterpreter)
308-
key = Core.Compiler.any(state.result.overridden_by_const) ? state.result : state.linfo
301+
function CC.finish(state::InferenceState, interp::ADInterpreter)
302+
res = @invoke CC.finish(state::InferenceState, interp::AbstractInterpreter)
303+
key = CC.any(state.result.overridden_by_const) ? state.result : state.linfo
309304
interp.unopt[interp.current_level][key] = Cthulhu.InferredSource(
310305
copy(state.src),
311306
copy(state.stmt_info),
312-
isdefined(Core.Compiler, :Effects) ? state.ipo_effects : nothing,
307+
state.ipo_effects,
313308
state.result.result)
314309
return res
315310
end
316311

317-
function Core.Compiler.transform_result_for_cache(interp::ADInterpreter,
312+
function CC.transform_result_for_cache(interp::ADInterpreter,
318313
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
319314
return Cthulhu.create_cthulhu_source(result.src, result.ipo_effects)
320315
end
@@ -325,75 +320,62 @@ function CC.inlining_policy(interp::ADInterpreter,
325320
if isa(info, FRuleCallInfo)
326321
return nothing
327322
end
328-
if isdefined(CC, :SemiConcreteResult) && isa(src, CC.SemiConcreteResult)
323+
if isa(src, CC.SemiConcreteResult)
329324
return src
330325
end
331326
@assert isa(src, Cthulhu.OptimizedSource) || isnothing(src)
332327
if isa(src, Cthulhu.OptimizedSource)
333328
if CC.is_stmt_inline(stmt_flag) || src.isinlineable
334329
return src.ir
335330
end
336-
else
337-
# the default inlining policy may try additional effor to find the source in a local cache
338-
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
339-
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
331+
return nothing
340332
end
341-
return nothing
333+
# the default inlining policy may try additional effor to find the source in a local cache
334+
return @invoke CC.inlining_policy(interp::AbstractInterpreter,
335+
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
342336
end
343337

344-
function dummy() end
345-
const dummym = first(methods(dummy))
346-
338+
# TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged
347339
function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
348340
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
349-
sv::IRCode, max_methods::Int)
350-
351-
if interp.reinference
352-
# Create a dummy inference state to serve as the root
353-
# TODO: This is terrible - how can we refactor this to do better?
354-
mi = CC.specialize_method(dummym, Tuple{typeof(dummy)}, Core.svec())
355-
result = InferenceResult(mi)
356-
interp′ = disable_forward(disable_reinference(interp))
357-
sv′ = InferenceState(result, :no, interp′)
358-
r = abstract_call_gf_by_type(interp′, f, arginfo, si, atype, sv′, -1)
359-
return r
360-
end
361-
362-
return CallMeta(Any, CC.Effects(), CC.NoCallInfo())
341+
sv::IRInterpretationState, max_methods::Int)
342+
return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any,
343+
arginfo::ArgInfo, si::StmtInfo, atype::Any,
344+
sv::CC.AbsIntState, max_methods::Int)
363345
end
364346

365347
#=
366-
function Core.Compiler.optimize(interp::ADInterpreter, opt::OptimizationState,
348+
function CC.optimize(interp::ADInterpreter, opt::OptimizationState,
367349
params::OptimizationParams, caller::InferenceResult)
368350
369351
# TODO: Enable some amount of inlining
370352
#@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
371353
372354
sv = opt
373355
ci = opt.src
374-
ir = Core.Compiler.convert_to_ircode(ci, sv)
375-
ir = Core.Compiler.slot2reg(ir, ci, sv)
356+
ir = CC.convert_to_ircode(ci, sv)
357+
ir = CC.slot2reg(ir, ci, sv)
376358
# TODO: Domsorting can produce an updated domtree - no need to recompute here
377-
ir = Core.Compiler.compact!(ir)
378-
return Core.Compiler.finish(interp, opt, params, ir, caller)
359+
ir = CC.compact!(ir)
360+
return CC.finish(interp, opt, params, ir, caller)
379361
end
380362
=#
381363

382-
function Core.Compiler.finish!(interp::ADInterpreter, caller::InferenceResult)
364+
function CC.finish!(interp::ADInterpreter, caller::InferenceResult)
383365
effects = caller.ipo_effects
384366
caller.src = Cthulhu.create_cthulhu_source(caller.src, effects)
385367
end
386368

387369
function ir2codeinst(ir::IRCode, inst::CodeInstance, ci::CodeInfo)
388370
CodeInstance(inst.def, inst.rettype, isdefined(inst, :rettype_const) ? inst.rettype_const : nothing,
389-
Cthulhu.OptimizedSource(Core.Compiler.copy(ir), ci, inst.inferred.isinlineable, Core.Compiler.decode_effects(inst.purity_bits)),
371+
Cthulhu.OptimizedSource(CC.copy(ir), ci, inst.inferred.isinlineable, CC.decode_effects(inst.purity_bits)),
390372
Int32(0), inst.min_world, inst.max_world, inst.ipo_purity_bits, inst.purity_bits,
391373
inst.argescapes, inst.relocatability)
392374
end
393375

394376
using Core: OpaqueClosure
395377
function codegen(interp::ADInterpreter, curs::ADCursor, cache=Dict{ADCursor, OpaqueClosure}())
396-
ir = Core.Compiler.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
378+
ir = CC.copy(Cthulhu.get_optimized_codeinst(interp, curs).inferred.ir)
397379
codeinst = interp.opt[curs.level][curs.mi]
398380
ci = codeinst.inferred.src
399381
if curs.level >= 1

0 commit comments

Comments
 (0)