@@ -4,7 +4,7 @@ using Compiler
44using Diffractor
55using Core: SimpleVector, CodeInstance, Const
66using Compiler: ArgInfo, StmtInfo, AbstractInterpreter, InferenceParams, OptimizationParams,
7- AbsIntState, CallInfo, InferenceResult
7+ AbsIntState, CallInfo, InferenceResult, InferenceState
88
99struct ADCache; end
1010
@@ -17,10 +17,12 @@ AD'd using Diffractor.
1717struct ADAnalyzer <: Compiler.AbstractInterpreter
1818 world:: UInt
1919 inf_cache:: Vector{Compiler.InferenceResult}
20+ edges:: SimpleVector # additional edges
2021 function ADAnalyzer (;
2122 world:: UInt = Base. get_world_counter (),
22- inf_cache:: Vector{Compiler.InferenceResult} = Compiler. InferenceResult[])
23- new (world, inf_cache)
23+ inf_cache:: Vector{Compiler.InferenceResult} = Compiler. InferenceResult[],
24+ edges = Compiler. empty_edges)
25+ new (world, inf_cache, edges)
2426 end
2527end
2628
@@ -60,6 +62,11 @@ struct AnalyzedSource
6062 inline_cost:: Compiler.InlineCostType
6163end
6264
65+ @override function Compiler. result_edges (interp:: ADAnalyzer , caller:: InferenceState )
66+ edges = @invoke Compiler. result_edges (interp:: AbstractInterpreter , caller:: InferenceState )
67+ Core. svec (edges... , interp. edges... )
68+ end
69+
6370@override function Compiler. transform_result_for_cache (interp:: ADAnalyzer , result:: InferenceResult , edges:: SimpleVector )
6471 ir = result. src. optresult. ir
6572 params = Compiler. OptimizationParams (interp)
@@ -95,16 +102,9 @@ function get_method_instance(@nospecialize(tt), world)
95102 mi = Compiler. specialize_method (match)
96103end
97104
98- function ad_typeinf (world, tt; force_inline_all= false , edges= nothing )
105+ function ad_typeinf (world, tt; force_inline_all= false , edges= Compiler . empty_edges )
99106 @assert ! force_inline_all
100- interp = ADAnalyzer (;world)
107+ interp = ADAnalyzer (; world, edges )
101108 mi = get_method_instance (tt, world)
102109 ci = Compiler. typeinf_ext (interp, mi, Compiler. SOURCE_MODE_ABI)
103- if edges != = nothing
104- prev = @atomic ci. edges
105- # XXX : Should we return the extended edges and use them in the other CodeInstances?
106- @atomic ci. edges = Core. svec (prev... , edges... )
107- Compiler. store_backedges (ci, edges)
108- end
109- ci
110110end
0 commit comments