Skip to content

Commit f3d890c

Browse files
committed
Address review comments, fix CodeInstances not being invalidated
1 parent 64bcf7c commit f3d890c

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

src/analysis/ADAnalyzer.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Compiler
44
using Diffractor
55
using Core: SimpleVector, CodeInstance, Const
66
using Compiler: ArgInfo, StmtInfo, AbstractInterpreter, InferenceParams, OptimizationParams,
7-
AbsIntState, CallInfo, InferenceResult
7+
AbsIntState, CallInfo, InferenceResult, InferenceState
88

99
struct ADCache; end
1010

@@ -17,10 +17,12 @@ AD'd using Diffractor.
1717
struct ADAnalyzer <: Compiler.AbstractInterpreter
1818
world::UInt
1919
inf_cache::Vector{Compiler.InferenceResult}
20+
edges::SimpleVector # additional edges
2021
function ADAnalyzer(;
2122
world::UInt = Base.get_world_counter(),
22-
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
23-
new(world, inf_cache)
23+
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[],
24+
edges = Compiler.empty_edges)
25+
new(world, inf_cache, edges)
2426
end
2527
end
2628

@@ -60,6 +62,11 @@ struct AnalyzedSource
6062
inline_cost::Compiler.InlineCostType
6163
end
6264

65+
@override function Compiler.result_edges(interp::ADAnalyzer, caller::InferenceState)
66+
edges = @invoke Compiler.result_edges(interp::AbstractInterpreter, caller::InferenceState)
67+
Core.svec(edges..., interp.edges...)
68+
end
69+
6370
@override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
6471
ir = result.src.optresult.ir
6572
params = Compiler.OptimizationParams(interp)
@@ -95,16 +102,9 @@ function get_method_instance(@nospecialize(tt), world)
95102
mi = Compiler.specialize_method(match)
96103
end
97104

98-
function ad_typeinf(world, tt; force_inline_all=false, edges=nothing)
105+
function ad_typeinf(world, tt; force_inline_all=false, edges=Compiler.empty_edges)
99106
@assert !force_inline_all
100-
interp = ADAnalyzer(;world)
107+
interp = ADAnalyzer(; world, edges)
101108
mi = get_method_instance(tt, world)
102109
ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI)
103-
if edges !== nothing
104-
prev = @atomic ci.edges
105-
# XXX: Should we return the extended edges and use them in the other CodeInstances?
106-
@atomic ci.edges = Core.svec(prev..., edges...)
107-
Compiler.store_backedges(ci, edges)
108-
end
109-
ci
110110
end

src/analysis/refiner.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ struct StructuralRefiner <: Compiler.AbstractInterpreter
1010
var_to_diff::DiffGraph
1111
varkinds::Vector{Intrinsics.VarKind}
1212
varclassification::Vector{VarEqClassification}
13+
edges::SimpleVector
1314
end
1415

1516
struct StructureCache; end
@@ -49,9 +50,7 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
4950
end
5051

5152
callee_codeinst = invokee
52-
# XXX: We should propagate edges down here too.
53-
edges = Core.svec()
54-
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp), edges)
53+
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp), interp.edges)
5554

5655
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
5756
# If this is uncompilable, we will be notfiying the user in the outer loop - here we just ignore it

src/analysis/structural.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVecto
6666
ir = copy(ci.inferred.ir)
6767

6868
# Allocate variable and equation numbers of any incoming arguments
69-
refiner = StructuralRefiner(world, var_to_diff, varkinds, varclassification)
69+
refiner = StructuralRefiner(world, var_to_diff, varkinds, varclassification, Core.svec(ci))
7070
argtypes = Any[make_argument_lattice_elem(Compiler.typeinf_lattice(refiner), Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
7171
nexternalvars = length(var_to_diff)
7272
nexternaleqs = length(eqssas)

src/interface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ This is the compile-time entry point for DAECompiler code generation. It drives
2323
function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings, @nospecialize(fT))
2424
settings = settings.parameters[1]
2525
factory_mi = get_method_instance(Tuple{typeof(factory),Val{settings},typeof(fT)}, world)
26-
edges = Core.svec(factory_mi)
2726

2827
# First, perform ordinary type inference, under the assumption that we may need to AD
2928
# parts of the function later.
30-
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all, edges)
29+
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all, edges=Core.svec(factory_mi))
30+
edges = Core.svec(ci)
3131

3232
# Perform or lookup DAECompiler specific analysis for this system.
3333
result = structural_analysis!(ci, world, edges)

src/transform/common.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,19 @@ function cache_dae_ci!(old_ci, src, debuginfo, abi, owner, edges::SimpleVector)
6767
daef_ci = CodeInstance(abi === nothing ? old_ci.def : Core.ABIOverride(abi, old_ci.def), owner, Tuple, Union{}, nothing, src, Int32(0),
6868
UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
6969
nothing, debuginfo, edges)
70-
Compiler.store_backedges(daef_ci, edges)
70+
add_backedges_to_callees(daef_ci, edges)
7171
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), old_ci.def, daef_ci)
7272
return daef_ci
7373
end
7474

75+
# Equivalent to `Compiler.store_backedges` in our case, but we allow `caller.def.def` to not be a `Method`.
76+
function add_backedges_to_callees(caller::CodeInstance, edges::SimpleVector)
77+
for edge in edges
78+
isa(edge, CodeInstance) && (edge = edge.def)
79+
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), edge::MethodInstance, nothing, caller)
80+
end
81+
end
82+
7583
function replace_call!(ir::Union{IRCode,IncrementalCompact}, idx::SSAValue, new_call::Expr)
7684
@assert !isa(ir[idx][:inst], PhiNode)
7785
ir[idx][:inst] = new_call

test/invalidation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ trivialeq!() = always!(ddt(continuous()))
3030
@test ci !== nothing
3131

3232
cached = get_cached_code_instances(mi)
33-
@test getproperty.(cached, :max_world) == fill(typemax(UInt), length(cached))
3433
world_before = Base.get_world_counter()
34+
@test getproperty.(cached, :max_world) == fill(typemax(UInt), length(cached))
3535
refresh()
3636
@test getproperty.(cached, :max_world) == fill(world_before, length(cached))
3737

@@ -43,4 +43,4 @@ trivialeq!() = always!(ddt(continuous()))
4343
@test ci !== nothing
4444

4545
@test length(get_cached_code_instances(mi)) == 2 * length(cached)
46-
end
46+
end;

0 commit comments

Comments
 (0)