Skip to content

Commit 3b6e14a

Browse files
authored
Invalidate caches on refresh() (#24)
* Record `factory` as an edge for all processed `CodeInstance`s * Address review comments, fix `CodeInstance`s not being invalidated * Add invalidation edges transitively, don't pass them around * Simplify backedge storage when we know we have only 1 edge
1 parent 9610244 commit 3b6e14a

File tree

7 files changed

+73
-11
lines changed

7 files changed

+73
-11
lines changed

src/analysis/ADAnalyzer.jl

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using Compiler
44
using Diffractor
55
using Core: SimpleVector, CodeInstance, Const
66
using Compiler: ArgInfo, StmtInfo, AbstractInterpreter, InferenceParams, OptimizationParams,
7-
AbsIntState, CallInfo, InferenceResult
7+
AbsIntState, CallInfo, InferenceResult, InferenceState
88

99
struct ADCache; end
1010

@@ -17,10 +17,12 @@ AD'd using Diffractor.
1717
struct ADAnalyzer <: Compiler.AbstractInterpreter
1818
world::UInt
1919
inf_cache::Vector{Compiler.InferenceResult}
20+
edges::SimpleVector # additional edges
2021
function ADAnalyzer(;
2122
world::UInt = Base.get_world_counter(),
22-
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
23-
new(world, inf_cache)
23+
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[],
24+
edges = Compiler.empty_edges)
25+
new(world, inf_cache, edges)
2426
end
2527
end
2628

@@ -60,6 +62,11 @@ struct AnalyzedSource
6062
inline_cost::Compiler.InlineCostType
6163
end
6264

65+
@override function Compiler.result_edges(interp::ADAnalyzer, caller::InferenceState)
66+
edges = @invoke Compiler.result_edges(interp::AbstractInterpreter, caller::InferenceState)
67+
Core.svec(edges..., interp.edges...)
68+
end
69+
6370
@override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
6471
ir = result.src.optresult.ir
6572
params = Compiler.OptimizationParams(interp)
@@ -88,12 +95,16 @@ end
8895
error(lazy"Could not find single target method for `$sig`")
8996
end
9097

91-
function ad_typeinf(world, tt; force_inline_all=false)
92-
@assert !force_inline_all
93-
interp = ADAnalyzer(;world)
98+
function get_method_instance(@nospecialize(tt), world)
9499
match = Base._methods_by_ftype(tt, 1, world)
95100
isempty(match) && single_match_error(tt)
96101
match = only(match)
97102
mi = Compiler.specialize_method(match)
103+
end
104+
105+
function ad_typeinf(world, tt; force_inline_all=false, edges=Compiler.empty_edges)
106+
@assert !force_inline_all
107+
interp = ADAnalyzer(; world, edges)
108+
mi = get_method_instance(tt, world)
98109
ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI)
99110
end

src/analysis/refiner.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,4 +415,4 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
415415
f::Any, bargtypes::Vector{Any}, sv::Union{AbsIntState,Nothing})
416416

417417
return rt
418-
end
418+
end

src/interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ This is the compile-time entry point for DAECompiler code generation. It drives
2222
"""
2323
function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings, @nospecialize(fT))
2424
settings = settings.parameters[1]
25+
factory_mi = get_method_instance(Tuple{typeof(factory),Val{settings},typeof(fT)}, world)
2526

2627
# First, perform ordinary type inference, under the assumption that we may need to AD
2728
# parts of the function later.
28-
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all)
29+
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all, edges=Core.svec(factory_mi))
2930

3031
# Perform or lookup DAECompiler specific analysis for this system.
3132
result = structural_analysis!(ci, world)

src/transform/codegen/init_factory.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,4 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
102102
argt, Vector{Float64}, Vector{Float64}, true, oc_source_method, sicm), Core.OpaqueClosure, line), true)
103103

104104
return new_oc
105-
end
105+
end

src/transform/common.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,13 @@ function ir_to_src(ir::IRCode)
6464
end
6565

6666
function cache_dae_ci!(old_ci, src, debuginfo, abi, owner)
67+
mi = old_ci.def
68+
edges = Core.svec(mi)
6769
daef_ci = CodeInstance(abi === nothing ? old_ci.def : Core.ABIOverride(abi, old_ci.def), owner, Tuple, Union{}, nothing, src, Int32(0),
6870
UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
69-
nothing, debuginfo, Compiler.empty_edges)
70-
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), old_ci.def, daef_ci)
71+
nothing, debuginfo, edges)
72+
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), mi, nothing, daef_ci)
73+
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, daef_ci)
7174
return daef_ci
7275
end
7376

test/invalidation.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
using Test
2+
using SciMLBase
3+
using Sundials
4+
using Compiler
5+
using Core.IR
6+
using DAECompiler
7+
using DAECompiler: get_method_instance, refresh
8+
using DAECompiler.Intrinsics
9+
10+
function get_cached_code_instances(mi::MethodInstance)
11+
ci = mi.cache
12+
cached = CodeInstance[ci]
13+
while isdefined(ci, :next)
14+
push!(cached, ci.next)
15+
ci = ci.next
16+
end
17+
return cached
18+
end
19+
20+
trivialeq!() = always!(ddt(continuous()))
21+
22+
@testset "Invalidation" begin
23+
mi = get_method_instance(Tuple{typeof(trivialeq!)}, Base.get_world_counter())
24+
25+
ci = DAECompiler.find_matching_ci(ci->ci.owner == DAECompiler.StructureCache(), mi, Base.get_world_counter())
26+
@test ci === nothing
27+
28+
solve(DAECProblem(trivialeq!, (1,) .=> 1.), IDA())
29+
ci = DAECompiler.find_matching_ci(ci->ci.owner == DAECompiler.StructureCache(), mi, Base.get_world_counter())
30+
@test ci !== nothing
31+
32+
cached = get_cached_code_instances(mi)
33+
world_before = Base.get_world_counter()
34+
@test getproperty.(cached, :max_world) == fill(typemax(UInt), length(cached))
35+
refresh()
36+
@test getproperty.(cached, :max_world) == fill(world_before, length(cached))
37+
38+
ci = DAECompiler.find_matching_ci(ci->ci.owner == DAECompiler.StructureCache(), mi, Base.get_world_counter())
39+
@test ci === nothing
40+
41+
solve(DAECProblem(trivialeq!, (1,) .=> 1.), IDA())
42+
ci = DAECompiler.find_matching_ci(ci->ci.owner == DAECompiler.StructureCache(), mi, Base.get_world_counter())
43+
@test ci !== nothing
44+
45+
@test length(get_cached_code_instances(mi)) == 2 * length(cached)
46+
end;

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ include("basic.jl")
22
include("ipo.jl")
33
include("ssrm.jl")
44
include("regression.jl")
5+
include("invalidation.jl")

0 commit comments

Comments
 (0)