Skip to content

Commit 620f9f1

Browse files
committed
Record factory as an edge for all processed CodeInstances
1 parent 88454ee commit 620f9f1

File tree

13 files changed

+113
-47
lines changed

13 files changed

+113
-47
lines changed

src/analysis/ADAnalyzer.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,23 @@ end
8888
error(lazy"Could not find single target method for `$sig`")
8989
end
9090

91-
function ad_typeinf(world, tt; force_inline_all=false)
92-
@assert !force_inline_all
93-
interp = ADAnalyzer(;world)
91+
function get_method_instance(@nospecialize(tt), world)
9492
match = Base._methods_by_ftype(tt, 1, world)
9593
isempty(match) && single_match_error(tt)
9694
match = only(match)
9795
mi = Compiler.specialize_method(match)
96+
end
97+
98+
function ad_typeinf(world, tt; force_inline_all=false, edges=nothing)
99+
@assert !force_inline_all
100+
interp = ADAnalyzer(;world)
101+
mi = get_method_instance(tt, world)
98102
ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI)
103+
if edges !== nothing
104+
prev = @atomic ci.edges
105+
# XXX: Should we return the extended edges and use them in the other CodeInstances?
106+
@atomic ci.edges = Core.svec(prev..., edges...)
107+
Compiler.store_backedges(ci, edges)
108+
end
109+
ci
99110
end

src/analysis/refiner.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ Compiler.cache_owner(::StructuralRefiner) = StructureCache()
4949
end
5050

5151
callee_codeinst = invokee
52-
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp))
52+
# XXX: We should propagate edges down here too.
53+
edges = Core.svec()
54+
callee_result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(interp), edges)
5355

5456
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
5557
# If this is uncompilable, we will be notfiying the user in the outer loop - here we just ignore it
@@ -375,4 +377,4 @@ is_any_incidence(@nospecialize args...) = any(@nospecialize(x)->isa(x, Incidence
375377
f::Any, bargtypes::Vector{Any}, sv::Union{AbsIntState,Nothing})
376378

377379
return rt
378-
end
380+
end

src/analysis/structural.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@ function find_matching_ci(predicate, mi::MethodInstance, world::UInt)
1616
return nothing
1717
end
1818

19-
function structural_analysis!(ci::CodeInstance, world::UInt)
19+
function structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVector)
2020
# Check if we have aleady done this work - if so return the cached result
2121
result_ci = find_matching_ci(ci->ci.owner == StructureCache(), ci.def, world)
2222
if result_ci !== nothing
2323
return result_ci.inferred
2424
end
2525

26-
result = _structural_analysis!(ci, world)
26+
result = _structural_analysis!(ci, world, edges)
2727
# TODO: The world bounds might have been narrowed
28-
cache_dae_ci!(ci, result, nothing, nothing, StructureCache())
28+
cache_dae_ci!(ci, result, nothing, nothing, StructureCache(), edges)
2929

3030
return result
3131
end
3232

33-
function _structural_analysis!(ci::CodeInstance, world::UInt)
33+
function _structural_analysis!(ci::CodeInstance, world::UInt, edges::SimpleVector)
3434
# Variables
3535
var_to_diff = DiffGraph(0)
3636
varclassification = VarEqClassification[]
@@ -291,7 +291,7 @@ function _structural_analysis!(ci::CodeInstance, world::UInt)
291291
(; result, mapping) = info
292292
else
293293
callee_codeinst = stmt.args[1]
294-
result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(refiner))
294+
result = structural_analysis!(callee_codeinst, Compiler.get_inference_world(refiner), edges)
295295

296296
if isa(result, UncompilableIPOResult)
297297
# TODO: Stack trace?

src/interface.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ This is the compile-time entry point for DAECompiler code generation. It drives
2222
"""
2323
function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings, @nospecialize(fT))
2424
settings = settings.parameters[1]
25+
factory_mi = get_method_instance(Tuple{typeof(factory),Val{settings},typeof(fT)}, world)
26+
edges = Core.svec(factory_mi)
2527

2628
# First, perform ordinary type inference, under the assumption that we may need to AD
2729
# parts of the function later.
28-
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all)
30+
ci = ad_typeinf(world, Tuple{fT}; force_inline_all=settings.force_inline_all, edges)
2931

3032
# Perform or lookup DAECompiler specific analysis for this system.
31-
result = structural_analysis!(ci, world)
33+
result = structural_analysis!(ci, world, edges)
3234

3335
if isa(result, UncompilableIPOResult)
3436
return Base.generated_body_to_codeinfo(
@@ -42,19 +44,19 @@ function factory_gen(world::UInt, source::Method, @nospecialize(_gen), settings,
4244
(diff_key, init_key) = top_level_state_selection!(tstate)
4345

4446
if settings.mode in (DAE, DAENoInit, ODE, ODENoInit)
45-
tearing_schedule!(tstate, ci, diff_key, world)
47+
tearing_schedule!(tstate, ci, diff_key, world, edges)
4648
end
4749
if settings.mode in (InitUncompress, DAE, ODE)
48-
tearing_schedule!(tstate, ci, init_key, world)
50+
tearing_schedule!(tstate, ci, init_key, world, edges)
4951
end
5052

5153
# Generate the IR implementation of `factory`, returning the DAEFunction/ODEFunction
5254
if settings.mode in (DAE, DAENoInit)
53-
ir_factory = dae_factory_gen(tstate, ci, diff_key, world, settings.mode == DAE ? init_key : nothing)
55+
ir_factory = dae_factory_gen(tstate, ci, diff_key, world, edges, settings.mode == DAE ? init_key : nothing)
5456
elseif settings.mode in (ODE, ODENoInit)
55-
ir_factory = ode_factory_gen(tstate, ci, diff_key, world, settings.mode == ODE ? init_key : nothing)
57+
ir_factory = ode_factory_gen(tstate, ci, diff_key, world, edges, settings.mode == ODE ? init_key : nothing)
5658
elseif settings.mode == InitUncompress
57-
ir_factory = init_uncompress_gen(result, ci, init_key, diff_key, world)
59+
ir_factory = init_uncompress_gen(result, ci, init_key, diff_key, world, edges)
5860
else
5961
return :(error("Unknown generation mode: $(settings.mode)"))
6062
end

src/transform/codegen/dae_factory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
```
6161
6262
"""
63-
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
63+
function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector, init_key::Union{TornCacheKey, Nothing})
6464
result = state.result
6565
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
6666
torn_ir = torn_ci.inferred
@@ -88,7 +88,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
8888

8989
argt = Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}
9090

91-
daef_ci = rhs_finish!(state, ci, key, world, 1)
91+
daef_ci = rhs_finish!(state, ci, key, world, 1, edges)
9292

9393
# Create a small opaque closure to adapt from SciML ABI to our own internal
9494
# ABI
@@ -188,7 +188,7 @@ function dae_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
188188
differential_states = Bool[v in key.diff_states for v in all_states]
189189

190190
if init_key !== nothing
191-
initf = init_uncompress_gen!(compact, result, ci, init_key, key, world)
191+
initf = init_uncompress_gen!(compact, result, ci, init_key, key, world, edges)
192192
daef = insert_node_here!(compact, NewInstruction(Expr(:call, make_daefunction, new_oc, initf),
193193
DAEFunction, line), true)
194194
else

src/transform/codegen/init_factory.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11

2-
function init_uncompress_gen(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt)
2+
function init_uncompress_gen(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, edges::SimpleVector)
33
ir_factory = copy(result.ir)
44
pushfirst!(ir_factory.argtypes, Settings)
55
pushfirst!(ir_factory.argtypes, typeof(factory))
66
compact = IncrementalCompact(ir_factory)
77

8-
new_oc = init_uncompress_gen!(compact, result, ci, init_key, diff_key, world)
8+
new_oc = init_uncompress_gen!(compact, result, ci, init_key, diff_key, world, edges)
99
insert_node_here!(compact, NewInstruction(ReturnNode(new_oc), Core.OpaqueClosure, result.ir[SSAValue(1)][:line]), true)
1010

1111
ir_factory = Compiler.finish(compact)
1212

1313
return ir_factory
1414
end
1515

16-
function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt)
16+
function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, edges::SimpleVector)
1717
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == init_key, ci.def, world)
1818
@assert torn_ci !== nothing
1919
torn_ir = torn_ci.inferred
@@ -35,7 +35,7 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
3535

3636
# (nlsol,)
3737
argt = Tuple{Any}
38-
daef_ci = gen_init_uncompress!(result, ci, init_key, diff_key, world, 1)
38+
daef_ci = gen_init_uncompress!(result, ci, init_key, diff_key, world, 1, edges)
3939

4040
# Create a small opaque closure to adapt from SciML ABI to our own internal
4141
# ABI
@@ -102,4 +102,4 @@ function init_uncompress_gen!(compact::Compiler.IncrementalCompact, result::DAEI
102102
argt, Vector{Float64}, Vector{Float64}, true, oc_source_method, sicm), Core.OpaqueClosure, line), true)
103103

104104
return new_oc
105-
end
105+
end

src/transform/codegen/init_uncompress.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ struct InitUncompressSpec
99
ordinal::Int
1010
end
1111

12-
function gen_init_uncompress!(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, ordinal::Int, indexT=Int)
12+
function gen_init_uncompress!(result::DAEIPOResult, ci::CodeInstance, init_key::TornCacheKey, diff_key::TornCacheKey, world::UInt, ordinal::Int, edges::SimpleVector, indexT=Int)
1313
structure = make_structure_from_ipo(result)
1414
tstate = TransformationState(result, structure, copy(result.total_incidence))
15-
return gen_init_uncompress!(tstate, ci, init_key, diff_key, world, ordinal, indexT)
15+
return gen_init_uncompress!(tstate, ci, init_key, diff_key, world, ordinal, edges, indexT)
1616
end
1717

1818
function gen_init_uncompress!(
@@ -22,6 +22,7 @@ function gen_init_uncompress!(
2222
diff_key::TornCacheKey,
2323
world::UInt,
2424
ordinal::Int,
25+
edges::SimpleVector,
2526
indexT=Int)
2627

2728
(; result, structure) = state
@@ -97,8 +98,8 @@ function gen_init_uncompress!(
9798
spec_data = stmt.args[1]
9899
callee_key = stmt.args[1][2]
99100
callee_ordinal = stmt.args[1][end]::Int
100-
callee_result = structural_analysis!(callee_ci, world)
101-
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal)
101+
callee_result = structural_analysis!(callee_ci, world, edges)
102+
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal, edges)
102103
# Allocate a continuous block of variables for all callee alg and diff states
103104

104105
empty!(stmt.args)
@@ -154,7 +155,8 @@ function gen_init_uncompress!(
154155
src = ir_to_src(ir)
155156

156157
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Vector{Float64}, Float64}
157-
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, InitUncompressSpec(init_key, diff_key, ir_ordinal))
158+
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, InitUncompressSpec(init_key, diff_key, ir_ordinal), edges)
159+
158160
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src)
159161

160162
push!(cis, daef_ci)

src/transform/codegen/ode_factory.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ end
5252
```
5353
5454
"""
55-
function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, init_key::Union{TornCacheKey, Nothing})
55+
function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::TornCacheKey, world::UInt, edges::SimpleVector, init_key::Union{TornCacheKey, Nothing})
5656
result = state.result
5757
torn_ci = find_matching_ci(ci->isa(ci.owner, TornIRSpec) && ci.owner.key == key, ci.def, world)
5858
torn_ir = torn_ci.inferred
@@ -75,7 +75,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
7575
sicm = ()
7676
end
7777

78-
odef_ci = rhs_finish!(state, ci, key, world, 1)
78+
odef_ci = rhs_finish!(state, ci, key, world, 1, edges)
7979

8080
# Create a small opaque closure to adapt from SciML ABI to our own internal ABI
8181

@@ -140,7 +140,7 @@ function ode_factory_gen(state::TransformationState, ci::CodeInstance, key::Torn
140140
nd = numstates[AssignedDiff] + numstates[UnassignedDiff]
141141
na = numstates[Algebraic] + numstates[AlgebraicDerivative]
142142
mass_matrix = na == 0 ? GlobalRef(LinearAlgebra, :I) : @insert_node_here compact line generate_ode_mass_matrix(nd, na)::Matrix{Float64}
143-
initf = init_key !== nothing ? init_uncompress_gen!(compact, result, ci, init_key, key, world) : nothing
143+
initf = init_key !== nothing ? init_uncompress_gen!(compact, result, ci, init_key, key, world, edges) : nothing
144144
odef = @insert_node_here compact line make_odefunction(new_oc, mass_matrix, initf)::ODEFunction true
145145

146146
odef_and_n = @insert_node_here compact line tuple(odef, nd + na)::Tuple true

src/transform/codegen/rhs.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ function compute_slot_ranges(info::MappingInfo, callee_key, var_assignment, eq_a
6363
return state_ranges
6464
end
6565

66-
function rhs_finish!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, ordinal::Int, indexT=Int)
66+
function rhs_finish!(result::DAEIPOResult, ci::CodeInstance, key::TornCacheKey, world::UInt, ordinal::Int, edges::SimpleVector, indexT=Int)
6767
structure = make_structure_from_ipo(result)
6868
tstate = TransformationState(result, structure, copy(result.total_incidence))
69-
return rhs_finish!(tstate, ci, key, world, ordinal, indexT)
69+
return rhs_finish!(tstate, ci, key, world, ordinal, edges, indexT)
7070
end
7171

7272
function rhs_finish!(
@@ -75,6 +75,7 @@ function rhs_finish!(
7575
key::TornCacheKey,
7676
world::UInt,
7777
ordinal::Int,
78+
edges::SimpleVector,
7879
indexT=Int)
7980

8081
(; result, structure) = state
@@ -145,8 +146,8 @@ function rhs_finish!(
145146
spec_data = stmt.args[1]
146147
callee_key = spec_data[2]
147148
callee_ordinal = spec_data[end]::Int
148-
callee_result = structural_analysis!(callee_ci, world)
149-
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal)
149+
callee_result = structural_analysis!(callee_ci, world, edges)
150+
callee_daef_ci = rhs_finish!(callee_result, callee_ci, callee_key, world, callee_ordinal, edges)
150151
# Allocate a continuous block of variables for all callee alg and diff states
151152

152153
empty!(stmt.args)
@@ -218,7 +219,7 @@ function rhs_finish!(
218219
src = ir_to_src(ir)
219220

220221
abi = Tuple{Tuple, Tuple, (VectorViewType for _ in arg_range)..., Float64}
221-
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, RHSSpec(key, ir_ordinal))
222+
daef_ci = cache_dae_ci!(ci, src, src.debuginfo, abi, RHSSpec(key, ir_ordinal), edges)
222223
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), daef_ci, src)
223224

224225
push!(cis, daef_ci)

src/transform/common.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,11 @@ function ir_to_src(ir::IRCode)
6363
return src
6464
end
6565

66-
function cache_dae_ci!(old_ci, src, debuginfo, abi, owner)
66+
function cache_dae_ci!(old_ci, src, debuginfo, abi, owner, edges::SimpleVector)
6767
daef_ci = CodeInstance(abi === nothing ? old_ci.def : Core.ABIOverride(abi, old_ci.def), owner, Tuple, Union{}, nothing, src, Int32(0),
6868
UInt(1)#=ci.min_world=#, old_ci.max_world, old_ci.ipo_purity_bits,
69-
nothing, debuginfo, Compiler.empty_edges)
69+
nothing, debuginfo, edges)
70+
Compiler.store_backedges(daef_ci, edges)
7071
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), old_ci.def, daef_ci)
7172
return daef_ci
7273
end

0 commit comments

Comments
 (0)