Skip to content

Commit 411a5f9

Browse files
authored
tidy up ADInterpreter constructor (#132)
1 parent 16634c1 commit 411a5f9

File tree

1 file changed

+38
-20
lines changed

1 file changed

+38
-20
lines changed

src/stage2/interpreter.jl

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ end
3434
using Core.Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
3535
InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
3636

37+
const OptCache = Dict{MethodInstance, CodeInstance}
38+
const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu.InferredSource}
39+
const RemarksCache = Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}
40+
3741
struct ADInterpreter <: AbstractInterpreter
3842
# Modes settings
3943
forward::Bool
@@ -47,21 +51,47 @@ struct ADInterpreter <: AbstractInterpreter
4751
# Level 1 == Gradients
4852
# Level 2 == Seconds Derivatives
4953
# and so on
50-
opt::OffsetVector{Dict{MethodInstance, CodeInstance}}
51-
unopt::Union{OffsetVector{Dict{Union{MethodInstance, InferenceResult}, Cthulhu.InferredSource}}, Nothing}
52-
transformed::OffsetVector{Dict{MethodInstance, CodeInstance}}
54+
opt::OffsetVector{OptCache}
55+
unopt::Union{OffsetVector{UnoptCache},Nothing}
56+
transformed::OffsetVector{OptCache}
5357

5458
native_interpreter::NativeInterpreter
5559
current_level::Int
56-
remarks::OffsetVector{Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}}
60+
remarks::OffsetVector{RemarksCache}
61+
62+
function _ADInterpreter()
63+
return new(
64+
#=forward::Bool=#false,
65+
#=backward::Bool=#true,
66+
#=reinference::Bool=#false,
67+
#=opt::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
68+
#=unopt::Union{OffsetVector{UnoptCache},Nothing}=#OffsetVector([UnoptCache(), UnoptCache()], 0:1),
69+
#=transformed::OffsetVector{OptCache}=#OffsetVector([OptCache(), OptCache()], 0:1),
70+
#=native_interpreter::NativeInterpreter=#NativeInterpreter(),
71+
#=current_level::Int=#0,
72+
#=remarks::OffsetVector{RemarksCache}=#OffsetVector([RemarksCache()], 0:0))
73+
end
74+
function ADInterpreter(interp::ADInterpreter = _ADInterpreter();
75+
forward::Bool = interp.forward,
76+
backward::Bool = interp.backward,
77+
reinference::Bool = interp.reinference,
78+
opt::OffsetVector{OptCache} = interp.opt,
79+
unopt::Union{OffsetVector{UnoptCache},Nothing} = interp.unopt,
80+
transformed::OffsetVector{OptCache} = interp.transformed,
81+
native_interpreter::NativeInterpreter = interp.native_interpreter,
82+
current_level::Int = interp.current_level,
83+
remarks::OffsetVector{RemarksCache} = interp.remarks)
84+
return new(forward, backward, reinference, opt, unopt, transformed, native_interpreter, current_level, remarks)
85+
end
5786
end
58-
change_level(interp::ADInterpreter, new_level::Int) = ADInterpreter(interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, new_level, interp.remarks)
87+
88+
change_level(interp::ADInterpreter, new_level::Int) = ADInterpreter(interp; current_level=new_level)
5989
raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level + 1)
6090
lower_level(interp::ADInterpreter) = change_level(interp, interp.current_level - 1)
6191

62-
disable_forward(interp::ADInterpreter) = ADInterpreter(false, interp.backward, interp.reinference, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks)
63-
disable_reinference(interp::ADInterpreter) = ADInterpreter(interp.forward, interp.backward, false, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks)
64-
enable_reinference(interp::ADInterpreter) = ADInterpreter(interp.forward, interp.backward, true, interp.opt, interp.unopt, interp.transformed, interp.native_interpreter, interp.current_level, interp.remarks)
92+
disable_forward(interp::ADInterpreter) = ADInterpreter(interp; forward=false)
93+
disable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=false)
94+
enable_reinference(interp::ADInterpreter) = ADInterpreter(interp; reinference=true)
6595

6696
function Cthulhu.get_optimized_codeinst(interp::ADInterpreter, curs::ADCursor)
6797
@show curs
@@ -226,18 +256,6 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
226256
interp, info, argtypes, rt, optimize)
227257
end
228258

229-
ADInterpreter(;forward = false, backward=true, reinference=false) = ADInterpreter(forward, backward, reinference,
230-
OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1),
231-
OffsetVector([Dict{MethodInstance, Cthulhu.InferredSource}(), Dict{MethodInstance, Cthulhu.InferredSource}()], 0:1),
232-
OffsetVector([Dict{MethodInstance, CodeInstance}(), Dict{MethodInstance, CodeInstance}()], 0:1),
233-
NativeInterpreter(),
234-
0,
235-
OffsetVector([Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}()], 0:0)
236-
)
237-
238-
ADInterpreter(fg::ADGraph, level) =
239-
ADInterpreter(fg.code, NativeInterpreter(), level, fg.msgs)
240-
241259
Core.Compiler.InferenceParams(ei::ADInterpreter) = InferenceParams(ei.native_interpreter)
242260
Core.Compiler.OptimizationParams(ei::ADInterpreter) = OptimizationParams(ei.native_interpreter)
243261
Core.Compiler.get_world_counter(ei::ADInterpreter) = get_world_counter(ei.native_interpreter)

0 commit comments

Comments
 (0)