@@ -31,8 +31,9 @@ function Compiler3.get_codeinstance(graph::ADGraph, cursor::ADCursor)
31
31
end
32
32
=#
33
33
34
- using Core. Compiler: AbstractInterpreter, NativeInterpreter, InferenceState,
35
- InferenceResult, CodeInstance, WorldRange, ArgInfo, StmtInfo
34
+ using Core: MethodInstance, CodeInstance
35
+ using . CC: AbstractInterpreter, ArgInfo, Effects, InferenceResult, InferenceState,
36
+ IRInterpretationState, NativeInterpreter, OptimizationState, StmtInfo, WorldRange
36
37
37
38
const OptCache = Dict{MethodInstance, CodeInstance}
38
39
const UnoptCache = Dict{Union{MethodInstance, InferenceResult}, Cthulhu. InferredSource}
@@ -42,7 +43,6 @@ struct ADInterpreter <: AbstractInterpreter
42
43
# Modes settings
43
44
forward:: Bool
44
45
backward:: Bool
45
- reinference:: Bool
46
46
47
47
# This cache is stratified by AD nesting level. Depending on the
48
48
# nesting level of the derivative, The AD primitives may behave
@@ -63,7 +63,6 @@ struct ADInterpreter <: AbstractInterpreter
63
63
return new (
64
64
#= forward::Bool=# false ,
65
65
#= backward::Bool=# true ,
66
- #= reinference::Bool=# false ,
67
66
#= opt::OffsetVector{OptCache}=# OffsetVector ([OptCache (), OptCache ()], 0 : 1 ),
68
67
#= unopt::Union{OffsetVector{UnoptCache},Nothing}=# OffsetVector ([UnoptCache (), UnoptCache ()], 0 : 1 ),
69
68
#= transformed::OffsetVector{OptCache}=# OffsetVector ([OptCache (), OptCache ()], 0 : 1 ),
@@ -74,14 +73,13 @@ struct ADInterpreter <: AbstractInterpreter
74
73
function ADInterpreter (interp:: ADInterpreter = _ADInterpreter ();
75
74
forward:: Bool = interp. forward,
76
75
backward:: Bool = interp. backward,
77
- reinference:: Bool = interp. reinference,
78
76
opt:: OffsetVector{OptCache} = interp. opt,
79
77
unopt:: Union{OffsetVector{UnoptCache},Nothing} = interp. unopt,
80
78
transformed:: OffsetVector{OptCache} = interp. transformed,
81
79
native_interpreter:: NativeInterpreter = interp. native_interpreter,
82
80
current_level:: Int = interp. current_level,
83
81
remarks:: OffsetVector{RemarksCache} = interp. remarks)
84
- return new (forward, backward, reinference, opt, unopt, transformed, native_interpreter, current_level, remarks)
82
+ return new (forward, backward, opt, unopt, transformed, native_interpreter, current_level, remarks)
85
83
end
86
84
end
87
85
@@ -90,8 +88,6 @@ raise_level(interp::ADInterpreter) = change_level(interp, interp.current_level +
90
88
lower_level (interp:: ADInterpreter ) = change_level (interp, interp. current_level - 1 )
91
89
92
90
disable_forward (interp:: ADInterpreter ) = ADInterpreter (interp; forward= false )
93
- disable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp; reinference= false )
94
- enable_reinference (interp:: ADInterpreter ) = ADInterpreter (interp; reinference= true )
95
91
96
92
function Cthulhu. get_optimized_codeinst (interp:: ADInterpreter , curs:: ADCursor )
97
93
@show curs
@@ -120,7 +116,7 @@ function Cthulhu.lookup(interp::ADInterpreter, curs::ADCursor, optimize::Bool; a
120
116
opt = codeinst. inferred
121
117
if opt != = nothing
122
118
opt = opt:: Cthulhu.OptimizedSource
123
- src = Core . Compiler . copy (opt. ir)
119
+ src = CC . copy (opt. ir)
124
120
codeinf = opt. src
125
121
infos = src. stmts. info
126
122
slottypes = src. argtypes
@@ -162,7 +158,6 @@ function Cthulhu.custom_toggles(interp::ADInterpreter)
162
158
end
163
159
164
160
# TODO : Something is going very wrong here
165
- using Core. Compiler: Effects, OptimizationState
166
161
function Cthulhu. get_effects (interp:: ADInterpreter , mi:: MethodInstance , opt:: Bool )
167
162
if haskey (interp. unopt[0 ], mi)
168
163
return interp. unopt[0 ][mi]. effects
@@ -171,7 +166,7 @@ function Cthulhu.get_effects(interp::ADInterpreter, mi::MethodInstance, opt::Boo
171
166
end
172
167
end
173
168
174
- function Core . Compiler . is_same_frame (interp:: ADInterpreter , linfo:: MethodInstance , frame:: InferenceState )
169
+ function CC . is_same_frame (interp:: ADInterpreter , linfo:: MethodInstance , frame:: InferenceState )
175
170
linfo === frame. linfo || return false
176
171
return interp. current_level === frame. interp. current_level
177
172
end
@@ -224,7 +219,7 @@ function Cthulhu.navigate(curs::ADCursor, callsite::Cthulhu.Callsite)
224
219
return ADCursor (curs. level, Cthulhu. get_mi (callsite))
225
220
end
226
221
227
- function Cthulhu. process_info (interp:: ADInterpreter , @nospecialize (info:: Core.Compiler .CallInfo ), argtypes:: Cthulhu.ArgTypes , @nospecialize (rt), optimize:: Bool )
222
+ function Cthulhu. process_info (interp:: ADInterpreter , @nospecialize (info:: CC .CallInfo ), argtypes:: Cthulhu.ArgTypes , @nospecialize (rt), optimize:: Bool )
228
223
if isa (info, RecurseInfo)
229
224
newargtypes = argtypes[2 : end ]
230
225
callinfos = Cthulhu. process_info (interp, info. info, newargtypes, Cthulhu. unwrapType (widenconst (rt)), optimize)
@@ -252,33 +247,33 @@ function Cthulhu.process_info(interp::ADInterpreter, @nospecialize(info::Core.Co
252
247
elseif isa (info, CompClosInfo)
253
248
return Any[CompClosCallInfo (rt)]
254
249
end
255
- return invoke (Cthulhu. process_info, Tuple{AbstractInterpreter, Core . Compiler . CallInfo, Cthulhu. ArgTypes, Any, Bool},
250
+ return invoke (Cthulhu. process_info, Tuple{AbstractInterpreter, CC . CallInfo, Cthulhu. ArgTypes, Any, Bool},
256
251
interp, info, argtypes, rt, optimize)
257
252
end
258
253
259
- Core . Compiler . InferenceParams (ei:: ADInterpreter ) = InferenceParams (ei. native_interpreter)
260
- Core . Compiler . OptimizationParams (ei:: ADInterpreter ) = OptimizationParams (ei. native_interpreter)
261
- Core . Compiler . get_world_counter (ei:: ADInterpreter ) = get_world_counter (ei. native_interpreter)
262
- Core . Compiler . get_inference_cache (ei:: ADInterpreter ) = get_inference_cache (ei. native_interpreter)
254
+ CC . InferenceParams (ei:: ADInterpreter ) = InferenceParams (ei. native_interpreter)
255
+ CC . OptimizationParams (ei:: ADInterpreter ) = OptimizationParams (ei. native_interpreter)
256
+ CC . get_world_counter (ei:: ADInterpreter ) = get_world_counter (ei. native_interpreter)
257
+ CC . get_inference_cache (ei:: ADInterpreter ) = get_inference_cache (ei. native_interpreter)
263
258
264
259
# No need to do any locking since we're not putting our results into the runtime cache
265
- Core . Compiler . lock_mi_inference (ei:: ADInterpreter , mi:: MethodInstance ) = nothing
266
- Core . Compiler . unlock_mi_inference (ei:: ADInterpreter , mi:: MethodInstance ) = nothing
260
+ CC . lock_mi_inference (ei:: ADInterpreter , mi:: MethodInstance ) = nothing
261
+ CC . unlock_mi_inference (ei:: ADInterpreter , mi:: MethodInstance ) = nothing
267
262
268
263
struct CodeInfoView
269
264
d:: Dict{MethodInstance, Any}
270
265
end
271
266
272
- function Core . Compiler . code_cache (ei:: ADInterpreter )
267
+ function CC . code_cache (ei:: ADInterpreter )
273
268
while ei. current_level > lastindex (ei. opt)
274
269
push! (ei. opt, Dict {MethodInstance, Any} ())
275
270
end
276
271
ei. opt[ei. current_level]
277
272
end
278
- Core . Compiler . may_optimize (ei:: ADInterpreter ) = true
279
- Core . Compiler . may_compress (ei:: ADInterpreter ) = false
280
- Core . Compiler . may_discard_trees (ei:: ADInterpreter ) = false
281
- function Core . Compiler . get (view:: CodeInfoView , mi:: MethodInstance , default)
273
+ CC . may_optimize (ei:: ADInterpreter ) = true
274
+ CC . may_compress (ei:: ADInterpreter ) = false
275
+ CC . may_discard_trees (ei:: ADInterpreter ) = false
276
+ function CC . get (view:: CodeInfoView , mi:: MethodInstance , default)
282
277
r = get (view. d, mi, nothing )
283
278
if r === nothing
284
279
return default
@@ -298,23 +293,23 @@ end
298
293
Cthulhu. get_remarks (interp:: ADInterpreter , key:: Union{MethodInstance,InferenceResult} ) = get (interp. remarks[interp. current_level], key, nothing )
299
294
300
295
#=
301
- function Core.Compiler .const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
296
+ function CC .const_prop_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
302
297
return true
303
298
end
304
299
=#
305
300
306
- function Core . Compiler . finish (state:: InferenceState , interp:: ADInterpreter )
307
- res = @invoke Core . Compiler . finish (state:: InferenceState , interp:: AbstractInterpreter )
308
- key = Core . Compiler . any (state. result. overridden_by_const) ? state. result : state. linfo
301
+ function CC . finish (state:: InferenceState , interp:: ADInterpreter )
302
+ res = @invoke CC . finish (state:: InferenceState , interp:: AbstractInterpreter )
303
+ key = CC . any (state. result. overridden_by_const) ? state. result : state. linfo
309
304
interp. unopt[interp. current_level][key] = Cthulhu. InferredSource (
310
305
copy (state. src),
311
306
copy (state. stmt_info),
312
- isdefined (Core . Compiler, :Effects ) ? state. ipo_effects : nothing ,
307
+ state. ipo_effects,
313
308
state. result. result)
314
309
return res
315
310
end
316
311
317
- function Core . Compiler . transform_result_for_cache (interp:: ADInterpreter ,
312
+ function CC . transform_result_for_cache (interp:: ADInterpreter ,
318
313
linfo:: MethodInstance , valid_worlds:: WorldRange , result:: InferenceResult )
319
314
return Cthulhu. create_cthulhu_source (result. src, result. ipo_effects)
320
315
end
@@ -325,75 +320,62 @@ function CC.inlining_policy(interp::ADInterpreter,
325
320
if isa (info, FRuleCallInfo)
326
321
return nothing
327
322
end
328
- if isdefined (CC, :SemiConcreteResult ) && isa (src, CC. SemiConcreteResult)
323
+ if isa (src, CC. SemiConcreteResult)
329
324
return src
330
325
end
331
326
@assert isa (src, Cthulhu. OptimizedSource) || isnothing (src)
332
327
if isa (src, Cthulhu. OptimizedSource)
333
328
if CC. is_stmt_inline (stmt_flag) || src. isinlineable
334
329
return src. ir
335
330
end
336
- else
337
- # the default inlining policy may try additional effor to find the source in a local cache
338
- return @invoke CC. inlining_policy (interp:: AbstractInterpreter ,
339
- nothing , info:: CC.CallInfo , stmt_flag:: UInt8 , mi:: MethodInstance , argtypes:: Vector{Any} )
331
+ return nothing
340
332
end
341
- return nothing
333
+ # the default inlining policy may try additional effor to find the source in a local cache
334
+ return @invoke CC. inlining_policy (interp:: AbstractInterpreter ,
335
+ nothing , info:: CC.CallInfo , stmt_flag:: UInt8 , mi:: MethodInstance , argtypes:: Vector{Any} )
342
336
end
343
337
344
- function dummy () end
345
- const dummym = first (methods (dummy))
346
-
338
+ # TODO remove this overload once https://github.com/JuliaLang/julia/pull/49191 gets merged
347
339
function CC. abstract_call_gf_by_type (interp:: ADInterpreter , @nospecialize (f),
348
340
arginfo:: ArgInfo , si:: StmtInfo , @nospecialize (atype),
349
- sv:: IRCode , max_methods:: Int )
350
-
351
- if interp. reinference
352
- # Create a dummy inference state to serve as the root
353
- # TODO : This is terrible - how can we refactor this to do better?
354
- mi = CC. specialize_method (dummym, Tuple{typeof (dummy)}, Core. svec ())
355
- result = InferenceResult (mi)
356
- interp′ = disable_forward (disable_reinference (interp))
357
- sv′ = InferenceState (result, :no , interp′)
358
- r = abstract_call_gf_by_type (interp′, f, arginfo, si, atype, sv′, - 1 )
359
- return r
360
- end
361
-
362
- return CallMeta (Any, CC. Effects (), CC. NoCallInfo ())
341
+ sv:: IRInterpretationState , max_methods:: Int )
342
+ return @invoke CC. abstract_call_gf_by_type (interp:: AbstractInterpreter , f:: Any ,
343
+ arginfo:: ArgInfo , si:: StmtInfo , atype:: Any ,
344
+ sv:: CC.AbsIntState , max_methods:: Int )
363
345
end
364
346
365
347
#=
366
- function Core.Compiler .optimize(interp::ADInterpreter, opt::OptimizationState,
348
+ function CC .optimize(interp::ADInterpreter, opt::OptimizationState,
367
349
params::OptimizationParams, caller::InferenceResult)
368
350
369
351
# TODO : Enable some amount of inlining
370
352
#@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
371
353
372
354
sv = opt
373
355
ci = opt.src
374
- ir = Core.Compiler .convert_to_ircode(ci, sv)
375
- ir = Core.Compiler .slot2reg(ir, ci, sv)
356
+ ir = CC .convert_to_ircode(ci, sv)
357
+ ir = CC .slot2reg(ir, ci, sv)
376
358
# TODO : Domsorting can produce an updated domtree - no need to recompute here
377
- ir = Core.Compiler .compact!(ir)
378
- return Core.Compiler .finish(interp, opt, params, ir, caller)
359
+ ir = CC .compact!(ir)
360
+ return CC .finish(interp, opt, params, ir, caller)
379
361
end
380
362
=#
381
363
382
- function Core . Compiler . finish! (interp:: ADInterpreter , caller:: InferenceResult )
364
+ function CC . finish! (interp:: ADInterpreter , caller:: InferenceResult )
383
365
effects = caller. ipo_effects
384
366
caller. src = Cthulhu. create_cthulhu_source (caller. src, effects)
385
367
end
386
368
387
369
function ir2codeinst (ir:: IRCode , inst:: CodeInstance , ci:: CodeInfo )
388
370
CodeInstance (inst. def, inst. rettype, isdefined (inst, :rettype_const ) ? inst. rettype_const : nothing ,
389
- Cthulhu. OptimizedSource (Core . Compiler . copy (ir), ci, inst. inferred. isinlineable, Core . Compiler . decode_effects (inst. purity_bits)),
371
+ Cthulhu. OptimizedSource (CC . copy (ir), ci, inst. inferred. isinlineable, CC . decode_effects (inst. purity_bits)),
390
372
Int32 (0 ), inst. min_world, inst. max_world, inst. ipo_purity_bits, inst. purity_bits,
391
373
inst. argescapes, inst. relocatability)
392
374
end
393
375
394
376
using Core: OpaqueClosure
395
377
function codegen (interp:: ADInterpreter , curs:: ADCursor , cache= Dict {ADCursor, OpaqueClosure} ())
396
- ir = Core . Compiler . copy (Cthulhu. get_optimized_codeinst (interp, curs). inferred. ir)
378
+ ir = CC . copy (Cthulhu. get_optimized_codeinst (interp, curs). inferred. ir)
397
379
codeinst = interp. opt[curs. level][curs. mi]
398
380
ci = codeinst. inferred. src
399
381
if curs. level >= 1
0 commit comments