Skip to content

Commit eed967d

Browse files
committed
Add invalidation edges transitively, don't pass them around
1 parent f3d890c commit eed967d

File tree

10 files changed

+42
-45
lines changed

10 files changed

+42
-45
lines changed

src/analysis/refiner.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ struct StructuralRefiner <: Compiler.AbstractInterpreter
1010
var_to_diff::DiffGraph
1111
varkinds::Vector{Intrinsics.VarKind}
1212
varclassification::Vector{VarEqClassification}
13-
edges::SimpleVector
1413
end
1514

1615
struct StructureCache; end
@@ -50,7 +49,7 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
5049
end
5150

5251
callee_codeinst = invokee
53-
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp), interp.edges)
52+
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp))
5453

5554
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
5655
# If this is uncompilable, we will be notfiying the user in the outer loop - here we just ignore it

src/analysis/structural.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ function find_matching_ci(predicate, mi::MethodInstance, world::UInt)
1616
return nothing
1717
end
1818

19-
function structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVector)
19+
function structural_analysis!(ci::CodeInstance, world::UInt)
2020
# Check if we have aleady done this work - if so return the cached result
2121
result_ci = find_matching_ci(ci->ci.owner == StructureCache(), ci.def, world)
2222
if result_ci !== nothing
2323
return result_ci.inferred
2424
end
2525

26-
result = _structural_analysis!(ci, world, edges)
26+
result = _structural_analysis!(ci, world)
2727
# TODO: The world bounds might have been narrowed
28-
cache_dae_ci!(ci, result, nothing, nothing, StructureCache(), edges)
28+
cache_dae_ci!(ci, result, nothing, nothing, StructureCache())
2929

3030
return result
3131
end
3232

33-
function _structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVector)
33+
function _structural_analysis!(ci::CodeInstance, world::UInt)
3434
# Variables
3535
var_to_diff = DiffGraph(0)
3636
varclassification = VarEqClassification[]
@@ -66,7 +66,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVecto
6666
ir = copy(ci.inferred.ir)
6767

6868
# Allocate variable and equation numbers of any incoming arguments
69-
refiner = StructuralRefiner(world, var_to_diff, varkinds, varclassification, Core.svec(ci))
69+
refiner = StructuralRefiner(world, var_to_diff, varkinds, varclassification)
7070
argtypes = Any[make_argument_lattice_elem(Compiler.typeinf_lattice(refiner), Argument(i), argt, add_variable!, add_equation!, add_scope!) for (i, argt) in enumerate(ir.argtypes)]
7171
nexternalvars = length(var_to_diff)
7272
nexternaleqs = length(eqssas)
@@ -291,7 +291,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVecto
291291
(; result, mapping) = info
292292
else
293293
callee_codeinst = stmt.args[1]
294-
result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(refiner), edges)
294+
result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(refiner))
295295

296296
if isa(result, UncompilableIPOResult)
297297
# TODO: Stack trace?

src/interface.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,9 @@ function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings,
2727
# First, perform ordinary type inference, under the assumption that we may need to AD
2828
# parts of the function later.
2929
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all, edges=Core.svec(factory_mi))
30-
edges = Core.svec(ci)
3130

3231
# Perform or lookup DAECompiler specific analysis for this system.
33-
result = structural_analysis!(ci, world, edges)
32+
result = structural_analysis!(ci, world)
3433

3534
if isa(result, UncompilableIPOResult)
3635
return Base.generated_body_to_codeinfo(
@@ -44,19 +43,19 @@ function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings,
4443
(diff_key, init_key) = top_level_state_selection!(tstate)
4544

4645
if settings.mode in (DAE, DAENoInit, ODE, ODENoInit)
47-
tearing_schedule!(tstate, ci, diff_key, world, edges)
46+
tearing_schedule!(tstate, ci, diff_key, world)
4847
end
4948
if settings.mode in (InitUncompress, DAE, ODE)
50-
tearing_schedule!(tstate, ci, init_key, world, edges)
49+
tearing_schedule!(tstate, ci, init_key, world)
5150
end
5251

5352
# Generate the IR implementation of `factory`, returning the DAEFunction/ODEFunction
5453
if settings.mode in (DAE, DAENoInit)
55-
ir_factory = dae_factory_gen(tstate, ci, diff_key, world, edges, settings.mode == DAE ? init_key : nothing)
54+
ir_factory = dae_factory_gen(tstate, ci, diff_key, world, settings.mode == DAE ? init_key : nothing)
5655
elseif settings.mode in (ODE, ODENoInit)
57-
ir_factory = ode_factory_gen(tstate, ci, diff_key, world, edges, settings.mode == ODE ? init_key : nothing)
56+
ir_factory = ode_factory_gen(tstate, ci, diff_key, world, settings.mode == ODE ? init_key : nothing)
5857
elseif settings.mode == InitUncompress
59-
ir_factory = init_uncompress_gen(result, ci, init_key, diff_key, world, edges)
58+
ir_factory = init_uncompress_gen(result, ci, init_key, diff_key, world)
6059
else
6160
return :(error("Unknown generation mode: $(settings.mode)"))
6261
end

src/transform/codegen/dae_factory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
```
6161
6262
"""
63-
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector, init_key::Union{TornCacheKey, Nothing})
63+
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
6464
result = state.result
6565
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
6666
torn_ir = torn_ci.inferred
@@ -88,7 +88,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
8888

8989
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
9090

91-
daef_ci = rhs_finish!(state, ci, key, world, 1, edges)
91+
daef_ci = rhs_finish!(state, ci, key, world, 1)
9292

9393
# Create a small opaque closure to adapt from SciML ABI to our own internal
9494
# ABI
@@ -188,7 +188,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
188188
differential_states = Bool[v in key.diff_states for v in all_states]
189189

190190
if init_key !== nothing
191-
initf = init_uncompress_gen!(compact, result, ci, init_key, key, world, edges)
191+
initf = init_uncompress_gen!(compact, result, ci, init_key, key, world)
192192
daef = insert_node_here!(compact, NewInstruction(Expr(:call, make_daefunction, new_oc, initf),
193193
DAEFunction, line), true)
194194
else

src/transform/codegen/init_factory.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11

2-
function init_uncompress_gen(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, edges::SimpleVector)
2+
function init_uncompress_gen(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt)
33
ir_factory = copy(result.ir)
44
pushfirst!(ir_factory.argtypes, Settings)
55
pushfirst!(ir_factory.argtypes, typeof(factory))
66
compact = IncrementalCompact(ir_factory)
77

8-
new_oc = init_uncompress_gen!(compact, result, ci, init_key, diff_key, world, edges)
8+
new_oc = init_uncompress_gen!(compact, result, ci, init_key, diff_key, world)
99
insert_node_here!(compact, NewInstruction(ReturnNode(new_oc), Core.OpaqueClosure, result.ir[SSAValue(1)][:line]), true)
1010

1111
ir_factory = Compiler.finish(compact)
1212

1313
return ir_factory
1414
end
1515

16-
function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, edges::SimpleVector)
16+
function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt)
1717
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == init_key, ci.def, world)
1818
@assert torn_ci !== nothing
1919
torn_ir = torn_ci.inferred
@@ -35,7 +35,7 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
3535

3636
# (nlsol,)
3737
argt = Tuple{Any}
38-
daef_ci = gen_init_uncompress!(result, ci, init_key, diff_key, world, 1, edges)
38+
daef_ci = gen_init_uncompress!(result, ci, init_key, diff_key, world, 1)
3939

4040
# Create a small opaque closure to adapt from SciML ABI to our own internal
4141
# ABI

src/transform/codegen/init_uncompress.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ struct InitUncompressSpec
99
ordinal::Int
1010
end
1111

12-
function gen_init_uncompress!(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, ordinal::Int, edges::SimpleVector, indexT=Int)
12+
function gen_init_uncompress!(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, ordinal::Int, indexT=Int)
1313
structure = make_structure_from_ipo(result)
1414
tstate = TransformationState(result, structure, copy(result.total_incidence))
15-
return gen_init_uncompress!(tstate, ci, init_key, diff_key, world, ordinal, edges, indexT)
15+
return gen_init_uncompress!(tstate, ci, init_key, diff_key, world, ordinal, indexT)
1616
end
1717

1818
function gen_init_uncompress!(
@@ -22,7 +22,6 @@ function gen_init_uncompress!(
2222
diff_key::TornCacheKey,
2323
world::UInt,
2424
ordinal::Int,
25-
edges::SimpleVector,
2625
indexT=Int)
2726

2827
(; result, structure) = state
@@ -98,8 +97,8 @@ function gen_init_uncompress!(
9897
spec_data = stmt.args[1]
9998
callee_key = stmt.args[1][2]
10099
callee_ordinal = stmt.args[1][end]::Int
101-
callee_result = structural_analysis!(callee_ci, world, edges)
102-
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal, edges)
100+
callee_result = structural_analysis!(callee_ci, world)
101+
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal)
103102
# Allocate a continuous block of variables for all callee alg and diff states
104103

105104
empty!(stmt.args)
@@ -155,7 +154,7 @@ function gen_init_uncompress!(
155154
src = ir_to_src(ir)
156155

157156
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Vector{Float64}, Float64}
158-
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, InitUncompressSpec(init_key, diff_key, ir_ordinal), edges)
157+
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, InitUncompressSpec(init_key, diff_key, ir_ordinal))
159158
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src)
160159

161160
push!(cis, daef_ci)

src/transform/codegen/ode_factory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
```
5353
5454
"""
55-
function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector, init_key::Union{TornCacheKey, Nothing})
55+
function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
5656
result = state.result
5757
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
5858
torn_ir = torn_ci.inferred
@@ -75,7 +75,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
7575
sicm = ()
7676
end
7777

78-
odef_ci = rhs_finish!(state, ci, key, world, 1, edges)
78+
odef_ci = rhs_finish!(state, ci, key, world, 1)
7979

8080
# Create a small opaque closure to adapt from SciML ABI to our own internal ABI
8181

@@ -140,7 +140,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
140140
nd = numstates[AssignedDiff] + numstates[UnassignedDiff]
141141
na = numstates[Algebraic] + numstates[AlgebraicDerivative]
142142
mass_matrix = na == 0 ? GlobalRef(LinearAlgebra, :I) : @insert_node_here compact line generate_ode_mass_matrix(nd, na)::Matrix{Float64}
143-
initf = init_key !== nothing ? init_uncompress_gen!(compact, result, ci, init_key, key, world, edges) : nothing
143+
initf = init_key !== nothing ? init_uncompress_gen!(compact, result, ci, init_key, key, world) : nothing
144144
odef = @insert_node_here compact line make_odefunction(new_oc, mass_matrix, initf)::ODEFunction true
145145

146146
odef_and_n = @insert_node_here compact line tuple(odef, nd + na)::Tuple true

src/transform/codegen/rhs.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ function compute_slot_ranges(info::MappingInfo, callee_key, var_assignment, eq_a
6363
return state_ranges
6464
end
6565

66-
function rhs_finish!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, ordinal::Int, edges::SimpleVector, indexT=Int)
66+
function rhs_finish!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, ordinal::Int, indexT=Int)
6767
structure = make_structure_from_ipo(result)
6868
tstate = TransformationState(result, structure, copy(result.total_incidence))
69-
return rhs_finish!(tstate, ci, key, world, ordinal, edges, indexT)
69+
return rhs_finish!(tstate, ci, key, world, ordinal, indexT)
7070
end
7171

7272
function rhs_finish!(
@@ -75,7 +75,6 @@ function rhs_finish!(
7575
key::TornCacheKey,
7676
world::UInt,
7777
ordinal::Int,
78-
edges::SimpleVector,
7978
indexT=Int)
8079

8180
(; result, structure) = state
@@ -146,8 +145,8 @@ function rhs_finish!(
146145
spec_data = stmt.args[1]
147146
callee_key = spec_data[2]
148147
callee_ordinal = spec_data[end]::Int
149-
callee_result = structural_analysis!(callee_ci, world, edges)
150-
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal, edges)
148+
callee_result = structural_analysis!(callee_ci, world)
149+
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal)
151150
# Allocate a continuous block of variables for all callee alg and diff states
152151

153152
empty!(stmt.args)
@@ -219,7 +218,7 @@ function rhs_finish!(
219218
src = ir_to_src(ir)
220219

221220
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64}
222-
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, RHSSpec(key, ir_ordinal), edges)
221+
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, RHSSpec(key, ir_ordinal))
223222
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src)
224223

225224
push!(cis, daef_ci)

src/transform/common.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ function ir_to_src(ir::IRCode)
6363
return src
6464
end
6565

66-
function cache_dae_ci!(old_ci, src, debuginfo, abi, owner, edges::SimpleVector)
66+
function cache_dae_ci!(old_ci, src, debuginfo, abi, owner)
67+
edges = Core.svec(old_ci)
6768
daef_ci = CodeInstance(abi === nothing ? old_ci.def : Core.ABIOverride(abi, old_ci.def), owner, Tuple, Union{}, nothing, src, Int32(0),
6869
UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
6970
nothing, debuginfo, edges)

src/transform/tearing/schedule.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -536,13 +536,13 @@ function matching_for_key(result::DAEIPOResult, key::TornCacheKey, structure = m
536536
return var_eq_matching
537537
end
538538

539-
function tearing_schedule!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector)
539+
function tearing_schedule!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt)
540540
structure = make_structure_from_ipo(result)
541541
tstate = TransformationState(result, structure, copy(result.total_incidence))
542-
return tearing_schedule!(tstate, ci, key, world, edges)
542+
return tearing_schedule!(tstate, ci, key, world)
543543
end
544544

545-
function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector)
545+
function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt)
546546
result_ci = find_matching_ci(ci->isa(ci.owner, SICMSpec) && ci.owner.key == key, ci.def, world)
547547
if result_ci !== nothing
548548
return result_ci
@@ -709,8 +709,8 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
709709
if isa(callee_codeinst, MethodInstance)
710710
callee_codeinst = Compiler.get(Compiler.code_cache(interp), callee_codeinst, nothing)
711711
end
712-
callee_result = structural_analysis!(callee_codeinst, world, edges)
713-
callee_sicm_ci = tearing_schedule!(callee_result, callee_codeinst, callee_key, world, edges)
712+
callee_result = structural_analysis!(callee_codeinst, world)
713+
callee_sicm_ci = tearing_schedule!(callee_result, callee_codeinst, callee_key, world)
714714

715715
inst[:type] = Any
716716
inst[:flag] = UInt32(0)
@@ -1029,10 +1029,10 @@ function tearing_schedule!(state::TransformationState, ci::CodeInstance, key::To
10291029
debuginfo = src.debuginfo
10301030
end
10311031

1032-
sicm_ci = cache_dae_ci!(ci, src, debuginfo, sig, SICMSpec(key), edges)
1032+
sicm_ci = cache_dae_ci!(ci, src, debuginfo, sig, SICMSpec(key))
10331033
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), sicm_ci, src)
10341034

1035-
torn_ci = cache_dae_ci!(ci, TornIR(ir_sicm, irs), nothing, sig, TornIRSpec(key), edges)
1035+
torn_ci = cache_dae_ci!(ci, TornIR(ir_sicm, irs), nothing, sig, TornIRSpec(key))
10361036

10371037
return sicm_ci
10381038
end

0 commit comments

Comments
 (0)