Skip to content

Commit c56641b

Browse files
committed
remove reinference hack
1 parent a9ae1d2 commit c56641b

File tree

2 files changed

+5
-23
lines changed

2 files changed

+5
-23
lines changed

src/codegen/forward_demand.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ function forward_diff!(interp::ADInterpreter, ir::IRCode, src::CodeInfo, mi::Met
308308
argtypes = ir.argtypes[1:mi.def.nargs]
309309
world = CC.get_world_counter(interp)
310310
irsv = IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world)
311-
rt = CC._ir_abstract_constant_propagation(enable_reinference(interp), irsv; extra_reprocess)
311+
rt = CC._ir_abstract_constant_propagation(interp, irsv; extra_reprocess)
312312

313313
ir = compact!(ir)
314314

src/stage2/interpreter.jl

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ struct ADInterpreter <: AbstractInterpreter
4343
# Modes settings
4444
forward::Bool
4545
backward::Bool
46-
reinference::Bool
4746

4847
# This cache is stratified by AD nesting level. Depending on the
4948
# nesting level of the derivative, The AD primitives may behave
@@ -64,7 +63,6 @@ struct ADInterpreter <: AbstractInterpreter
6463
return new(
6564
#=forward::Bool=#false,
6665
#=backward::Bool=#true,
67-
#=reinference::Bool=#false,
6866
#=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
6967
#=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
7068
#=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
@@ -75,14 +73,13 @@ struct ADInterpreter <: AbstractInterpreter
7573
function ADInterpreter(interp::ADInterpreter = _ADInterpreter();
7674
forward::Bool = interp.forward,
7775
backward::Bool = interp.backward,
78-
reinference::Bool = interp.reinference,
7976
opt::OffsetVector{OptCache} = interp.opt,
8077
unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
8178
transformed::OffsetVector{OptCache} = interp.transformed,
8279
native_interpreter::NativeInterpreter = interp.native_interpreter,
8380
current_level::Int = interp.current_level,
8481
remarks::OffsetVector{RemarksCache} = interp.remarks)
85-
return new(forward, backward, reinference, opt, unopt, transformed, native_interpreter, current_level, remarks)
82+
return new(forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
8683
end
8784
end
8885

@@ -91,8 +88,6 @@ raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level +
9188
lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - 1)
9289

9390
disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)
94-
disable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=false)
95-
enable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=true)
9691

9792
function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
9893
@show curs
@@ -340,25 +335,12 @@ function CC.inlining_policy(interp::ADInterpreter,
340335
nothing, info::CC.CallInfo, stmt_flag::UInt8, mi::MethodInstance, argtypes::Vector{Any})
341336
end
342337

343-
function dummy() end
344-
const dummym = first(methods(dummy))
345-
346338
function CC.abstract_call_gf_by_type(interp::ADInterpreter, @nospecialize(f),
347339
arginfo::ArgInfo, si::StmtInfo, @nospecialize(atype),
348340
sv::IRInterpretationState, max_methods::Int)
349-
350-
if interp.reinference
351-
# Create a dummy inference state to serve as the root
352-
# TODO: This is terrible - how can we refactor this to do better?
353-
mi = CC.specialize_method(dummym, Tuple{typeof(dummy)}, Core.svec())
354-
result = InferenceResult(mi)
355-
interp′ = disable_forward(disable_reinference(interp))
356-
sv′ = InferenceState(result, :no, interp′)
357-
r = abstract_call_gf_by_type(interp′, f, arginfo, si, atype, sv′, -1)
358-
return r
359-
end
360-
361-
return CallMeta(Any, Effects(), CC.NoCallInfo())
341+
return @invoke CC.abstract_call_gf_by_type(interp::AbstractInterpreter, f::Any,
342+
arginfo::ArgInfo, si::StmtInfo, atype::Any,
343+
sv::CC.AbsIntState, max_methods::Int)
362344
end
363345

364346
#=

0 commit comments

Comments
 (0)