34
34
using Core. Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
35
35
InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
36
36
37
+ const OptCache = Dict{MethodInstance, CodeInstance}
38
+ const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu. InferredSource}
39
+ const RemarksCache = Dict{Union{MethodInstance,InferenceResult}, Cthulhu. PC2Remarks}
40
+
37
41
struct ADInterpreter <: AbstractInterpreter
38
42
# Modes settings
39
43
forward:: Bool
@@ -47,21 +51,47 @@ struct ADInterpreter <: AbstractInterpreter
47
51
# Level 1 == Gradients
48
52
# Level 2 == Seconds Derivatives
49
53
# and so on
50
- opt:: OffsetVector{Dict{MethodInstance, CodeInstance} }
51
- unopt:: Union {OffsetVector{Dict{Union{MethodInstance, InferenceResult}, Cthulhu . InferredSource}}, Nothing}
52
- transformed:: OffsetVector{Dict{MethodInstance, CodeInstance} }
54
+ opt:: OffsetVector{OptCache }
55
+ unopt:: Union{OffsetVector{UnoptCache}, Nothing}
56
+ transformed:: OffsetVector{OptCache }
53
57
54
58
native_interpreter:: NativeInterpreter
55
59
current_level:: Int
56
- remarks:: OffsetVector{Dict{Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks}}
60
+ remarks:: OffsetVector{RemarksCache}
61
+
62
+ function _ADInterpreter ()
63
+ return new (
64
+ #= forward::Bool=# false ,
65
+ #= backward::Bool=# true ,
66
+ #= reinference::Bool=# false ,
67
+ #= opt::OffsetVector{OptCache}=# OffsetVector ([OptCache (), OptCache ()], 0 : 1 ),
68
+ #= unopt::Union{OffsetVector{UnoptCache},Nothing}=# OffsetVector ([UnoptCache (), UnoptCache ()], 0 : 1 ),
69
+ #= transformed::OffsetVector{OptCache}=# OffsetVector ([OptCache (), OptCache ()], 0 : 1 ),
70
+ #= native_interpreter::NativeInterpreter=# NativeInterpreter (),
71
+ #= current_level::Int=# 0 ,
72
+ #= remarks::OffsetVector{RemarksCache}=# OffsetVector ([RemarksCache ()], 0 : 0 ))
73
+ end
74
+ function ADInterpreter (interp:: ADInterpreter = _ADInterpreter ();
75
+ forward:: Bool = interp. forward,
76
+ backward:: Bool = interp. backward,
77
+ reinference:: Bool = interp. reinference,
78
+ opt:: OffsetVector{OptCache} = interp. opt,
79
+ unopt:: Union{OffsetVector{UnoptCache},Nothing} = interp. unopt,
80
+ transformed:: OffsetVector{OptCache} = interp. transformed,
81
+ native_interpreter:: NativeInterpreter = interp. native_interpreter,
82
+ current_level:: Int = interp. current_level,
83
+ remarks:: OffsetVector{RemarksCache} = interp. remarks)
84
+ return new (forward, backward, reinference, opt, unopt, transformed, native_interpreter, current_level, remarks)
85
+ end
57
86
end
58
- change_level (interp:: ADInterpreter , new_level:: Int ) = ADInterpreter (interp. opt, interp. unopt, interp. transformed, interp. native_interpreter, new_level, interp. remarks)
87
+
88
+ change_level (interp:: ADInterpreter , new_level:: Int ) = ADInterpreter (interp; current_level= new_level)
59
89
raise_level (interp:: ADInterpreter ) = change_level (interp, interp. current_level + 1 )
60
90
lower_level (interp:: ADInterpreter ) = change_level (interp, interp. current_level - 1 )
61
91
62
- disable_forward (interp:: ADInterpreter ) = ADInterpreter (false , interp. backward, interp . reinference, interp . opt, interp . unopt, interp . transformed, interp . native_interpreter, interp . current_level, interp . remarks )
63
- disable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp. forward, interp . backward, false , interp . opt, interp . unopt, interp . transformed, interp . native_interpreter, interp . current_level, interp . remarks )
64
- enable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp. forward, interp . backward, true , interp . opt, interp . unopt, interp . transformed, interp . native_interpreter, interp . current_level, interp . remarks )
92
+ disable_forward (interp:: ADInterpreter ) = ADInterpreter (interp; forward = false )
93
+ disable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp; reinference = false )
94
+ enable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp; reinference = true )
65
95
66
96
function Cthulhu. get_optimized_codeinst (interp:: ADInterpreter , curs:: ADCursor )
67
97
@show curs
@@ -226,18 +256,6 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
226
256
interp, info, argtypes, rt, optimize)
227
257
end
228
258
229
- ADInterpreter (;forward = false , backward= true , reinference= false ) = ADInterpreter (forward, backward, reinference,
230
- OffsetVector ([Dict {MethodInstance, CodeInstance} (), Dict {MethodInstance, CodeInstance} ()], 0 : 1 ),
231
- OffsetVector ([Dict {MethodInstance, Cthulhu.InferredSource} (), Dict {MethodInstance, Cthulhu.InferredSource} ()], 0 : 1 ),
232
- OffsetVector ([Dict {MethodInstance, CodeInstance} (), Dict {MethodInstance, CodeInstance} ()], 0 : 1 ),
233
- NativeInterpreter (),
234
- 0 ,
235
- OffsetVector ([Dict {Union{MethodInstance,InferenceResult}, Cthulhu.PC2Remarks} ()], 0 : 0 )
236
- )
237
-
238
- ADInterpreter (fg:: ADGraph , level) =
239
- ADInterpreter (fg. code, NativeInterpreter (), level, fg. msgs)
240
-
241
259
Core. Compiler. InferenceParams (ei:: ADInterpreter ) = InferenceParams (ei. native_interpreter)
242
260
Core. Compiler. OptimizationParams (ei:: ADInterpreter ) = OptimizationParams (ei. native_interpreter)
243
261
Core. Compiler. get_world_counter (ei:: ADInterpreter ) = get_world_counter (ei. native_interpreter)
0 commit comments