@@ -4,7 +4,7 @@ using Compiler
44using Diffractor
55using Core: SimpleVector, CodeInstance, Const
66using Compiler: ArgInfo, StmtInfo, AbstractInterpreter, InferenceParams, OptimizationParams,
7- AbsIntState, CallInfo, InferenceResult
7+ AbsIntState, CallInfo, InferenceResult, InferenceState
88
99struct ADCache; end
1010
@@ -17,10 +17,12 @@ AD'd using Diffractor.
1717struct ADAnalyzer <: Compiler.AbstractInterpreter
1818 world:: UInt
1919 inf_cache:: Vector{Compiler.InferenceResult}
20+ edges:: SimpleVector # additional edges
2021 function ADAnalyzer (;
2122 world:: UInt = Base. get_world_counter (),
22- inf_cache:: Vector{Compiler.InferenceResult} = Compiler. InferenceResult[])
23- new (world, inf_cache)
23+ inf_cache:: Vector{Compiler.InferenceResult} = Compiler. InferenceResult[],
24+ edges = Compiler. empty_edges)
25+ new (world, inf_cache, edges)
2426 end
2527end
2628
@@ -60,6 +62,11 @@ struct AnalyzedSource
6062 inline_cost:: Compiler.InlineCostType
6163end
6264
65+ @override function Compiler. result_edges (interp:: ADAnalyzer , caller:: InferenceState )
66+ edges = @invoke Compiler. result_edges (interp:: AbstractInterpreter , caller:: InferenceState )
67+ Core. svec (edges... , interp. edges... )
68+ end
69+
6370@override function Compiler. transform_result_for_cache (interp:: ADAnalyzer , result:: InferenceResult , edges:: SimpleVector )
6471 ir = result. src. optresult. ir
6572 params = Compiler. OptimizationParams (interp)
8895 error (lazy " Could not find single target method for `$sig`" )
8996end
9097
91- function ad_typeinf (world, tt; force_inline_all= false )
92- @assert ! force_inline_all
93- interp = ADAnalyzer (;world)
98+ function get_method_instance (@nospecialize (tt), world)
9499 match = Base. _methods_by_ftype (tt, 1 , world)
95100 isempty (match) && single_match_error (tt)
96101 match = only (match)
97102 mi = Compiler. specialize_method (match)
103+ end
104+
105+ function ad_typeinf (world, tt; force_inline_all= false , edges= Compiler. empty_edges)
106+ @assert ! force_inline_all
107+ interp = ADAnalyzer (; world, edges)
108+ mi = get_method_instance (tt, world)
98109 ci = Compiler. typeinf_ext (interp, mi, Compiler. SOURCE_MODE_ABI)
99110end
0 commit comments