Skip to content

Commit a9d17dd

Browse files
committed
Rewrite and support flattening, improve external equation support
1 parent 57f65d7 commit a9d17dd

File tree

11 files changed

+367
-105
lines changed

11 files changed

+367
-105
lines changed

src/DAECompiler.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module DAECompiler
66
using Diffractor
77
using OrderedCollections
88
using Compiler
9-
using Compiler: IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst
9+
using Compiler: AbstractLattice, IRCode, IncrementalCompact, DebugInfoStream, NewInstruction, argextype, singleton_type, isexpr, widenconst
1010
using Core.IR
1111
using SciMLBase
1212
using AutoHashEquals
@@ -21,11 +21,11 @@ module DAECompiler
2121
include("analysis/lattice.jl")
2222
include("analysis/ADAnalyzer.jl")
2323
include("analysis/scopes.jl")
24+
include("analysis/flattening.jl")
2425
include("analysis/cache.jl")
2526
include("analysis/refiner.jl")
2627
include("analysis/ipoincidence.jl")
2728
include("analysis/structural.jl")
28-
include("analysis/flattening.jl")
2929
include("transform/state_selection.jl")
3030
include("transform/common.jl")
3131
include("transform/runtime.jl")

src/analysis/ADAnalyzer.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ end
6262

6363
struct AnalyzedSource
6464
ir::Compiler.IRCode
65+
slotnames::Vector{Any}
6566
inline_cost::Compiler.InlineCostType
6667
nargs::UInt
6768
isva::Bool
@@ -72,19 +73,26 @@ end
7273
Core.svec(edges..., interp.edges...)
7374
end
7475

76+
function get_slotnames(def::Method)
77+
names = split(def.slot_syms, '\0')
78+
return map(Symbol, names)
79+
end
80+
7581
@override function Compiler.transform_result_for_cache(interp::ADAnalyzer, result::InferenceResult, edges::SimpleVector)
7682
ir = result.src.optresult.ir
83+
slotnames = get_slotnames(result.linfo.def)
7784
params = Compiler.OptimizationParams(interp)
78-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
85+
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
7986
end
8087

8188
@override function Compiler.transform_result_for_local_cache(interp::ADAnalyzer, result::InferenceResult)
8289
if Compiler.result_is_constabi(interp, result)
8390
return nothing
8491
end
8592
ir = result.src.optresult.ir
93+
slotnames = get_slotnames(result.linfo.def)
8694
params = Compiler.OptimizationParams(interp)
87-
return AnalyzedSource(ir, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
95+
return AnalyzedSource(ir, slotnames, Compiler.compute_inlining_cost(interp, result), result.src.src.nargs, result.src.src.isva)
8896
end
8997

9098
function Compiler.retrieve_ir_for_inlining(ci::CodeInstance, result::AnalyzedSource)

src/analysis/cache.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct DAEIPOResult
3535
opaque_eligible::Bool
3636
extended_rt::Any
3737
argtypes
38+
argmap::ArgumentMap
3839
nexternalargvars::Int # total vars is length(var_to_diff)
3940
nsysmscopes::Int
4041
nexternaleqs::Int

src/analysis/flattening.jl

Lines changed: 171 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,129 @@
1+
const CompositeIndex = Vector{Int}
2+
3+
struct ArgumentMap
4+
variables::Vector{CompositeIndex} # index into argument tuple type
5+
equations::Vector{CompositeIndex} # index into argument tuple type
6+
end
7+
ArgumentMap() = ArgumentMap(CompositeIndex[], CompositeIndex[])
8+
9+
function ArgumentMap(argtypes::Vector{Any})
10+
map = ArgumentMap()
11+
index = CompositeIndex()
12+
fill_argument_map!(map, index, argtypes)
13+
return map
14+
end
15+
16+
function fill_argument_map!(map::ArgumentMap, index::CompositeIndex, types::Vector{Any})
17+
for (i, type) in enumerate(types)
18+
push!(index, i)
19+
fill_argument_map!(map, index, type)
20+
pop!(index)
21+
end
22+
end
23+
24+
function fill_argument_map!(map::ArgumentMap, index::CompositeIndex, @nospecialize(type))
25+
if isprimitivetype(type) || isa(type, Incidence)
26+
push!(map.variables, copy(index))
27+
elseif type === equation
28+
push!(map.equations, copy(index))
29+
elseif isa(type, PartialStruct) || isstructtype(type)
30+
fields = isa(type, PartialStruct) ? type.fields : collect(Any, fieldtypes(type))
31+
fill_argument_map!(map, index, fields)
32+
end
33+
end
34+
35+
struct FlatteningState
36+
compact::IncrementalCompact
37+
settings::Settings
38+
map::ArgumentMap
39+
nvariables::Int
40+
nequations::Int
41+
new_argtypes::Vector{Any}
42+
end
43+
44+
function FlatteningState(compact::IncrementalCompact, settings::Settings, map::ArgumentMap)
45+
FlatteningState(compact, settings, deepcopy(map), length(map.variables), length(map.equations), Any[])
46+
end
47+
48+
function next_variable!(state::FlatteningState)
49+
popfirst!(state.map.variables)
50+
return state.nvariables - length(state.map.variables)
51+
end
52+
53+
function next_equation!(state::FlatteningState)
54+
popfirst!(state.map.equations)
55+
return state.nequations - length(state.map.equations)
56+
end
57+
58+
function flatten_arguments!(state::FlatteningState, argtypes::Vector{Any})
59+
args = Any[]
60+
# push!(state.new_argtypes, argtypes[1])
61+
for argt in argtypes
62+
arg = flatten_argument!(state, argt)
63+
arg === nothing && return nothing
64+
push!(args, arg)
65+
end
66+
@assert isempty(state.map.variables)
67+
@assert isempty(state.map.equations)
68+
return args
69+
end
70+
71+
function flatten_argument!(state::FlatteningState, @nospecialize(argt))
72+
@assert !isa(argt, Incidence) && !isa(argt, Eq)
73+
(; compact, settings) = state
74+
if isa(argt, Const)
75+
return argt.val
76+
elseif Base.issingletontype(argt)
77+
return argt.instance
78+
elseif isprimitivetype(argt)
79+
push!(state.new_argtypes, argt)
80+
return Argument(next_variable!(state))
81+
elseif argt === equation
82+
eq = next_equation!(state)
83+
line = compact[Compiler.OldSSAValue(1)][:line]
84+
ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eq))
85+
return ssa
86+
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
87+
line = compact[Compiler.OldSSAValue(1)][:line]
88+
ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{})
89+
return nothing
90+
else
91+
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
92+
line = compact[Compiler.OldSSAValue(1)][:line]
93+
ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{})
94+
return nothing
95+
end
96+
fields = isa(argt, PartialStruct) ? argt.fields : collect(Any, fieldtypes(argt))
97+
args = flatten_arguments!(state, fields)
98+
args === nothing && return nothing
99+
this = Expr(:new, isa(argt, PartialStruct) ? argt.typ : argt, args...)
100+
line = compact[Compiler.OldSSAValue(1)][:line]
101+
ssa = @insert_instruction_here(compact, line, settings, this::argt)
102+
return ssa
103+
end
104+
end
105+
106+
function flatten_arguments_for_callee!(compact::IncrementalCompact, map::ArgumentMap, argtypes, 𝕃, line, settings)
107+
list = Any[]
108+
this = nothing
109+
last_index = Int[]
110+
for index in map.variables
111+
from = findfirst(j -> get(last_index, j, -1) !== index[j], eachindex(index))::Int
112+
for i in from:length(index)
113+
field = index[i]
114+
if i == 1
115+
this = Argument(2 + field)
116+
else
117+
thistype = argextype(this, compact)
118+
fieldtype = Compiler.getfield_tfunc(𝕃, Const(field))
119+
this = @insert_instruction_here(compact, line, settings, getfield(this, field)::fieldtype)
120+
end
121+
end
122+
push!(list, this)
123+
end
124+
return list
125+
end
126+
1127
function _flatten_parameter!(𝕃, compact, argtypes, ntharg, line, settings)
2128
list = Any[]
3129
for (argn, argt) in enumerate(argtypes)
@@ -67,51 +193,55 @@ function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, t
67193
return Pair{Int, Int}(offset, eqoffset)
68194
end
69195

70-
struct TransformedArg
71-
ssa::Any
72-
offset::Int
73-
eqoffset::Int
74-
TransformedArg(@nospecialize(arg), new_offset::Int, new_eqoffset::Int) = new(arg, new_offset, new_eqoffset)
75-
end
76196

77-
function flatten_argument!(compact::Compiler.IncrementalCompact, settings::Settings, @nospecialize(argt), offset::Int, eqoffset::Int, argtypes::Vector{Any})::TransformedArg
78-
@assert !isa(argt, Incidence) && !isa(argt, Eq)
79-
if isa(argt, Const)
80-
return TransformedArg(argt.val, offset, eqoffset)
81-
elseif Base.issingletontype(argt)
82-
return TransformedArg(argt.instance, offset, eqoffset)
83-
elseif Base.isprimitivetype(argt)
84-
push!(argtypes, argt)
85-
return TransformedArg(Argument(offset+1), offset+1, eqoffset)
86-
elseif argt === equation
87-
line = compact[Compiler.OldSSAValue(1)][:line]
88-
ssa = @insert_instruction_here(compact, line, settings, (:invoke)(nothing, InternalIntrinsics.external_equation)::Eq(eqoffset+1))
89-
return TransformedArg(ssa, offset, eqoffset+1)
90-
elseif isabstracttype(argt) || ismutabletype(argt) || (!isa(argt, DataType) && !isa(argt, PartialStruct))
91-
line = compact[Compiler.OldSSAValue(1)][:line]
92-
ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{})
93-
return TransformedArg(ssa, -1, eqoffset)
94-
else
95-
if !isa(argt, PartialStruct) && Base.datatype_fieldcount(argt) === nothing
96-
line = compact[Compiler.OldSSAValue(1)][:line]
97-
ssa = @insert_instruction_here(compact, line, settings, error("Cannot IPO model arg type $argt")::Union{})
98-
return TransformedArg(ssa, -1, eqoffset)
197+
remove_variable_and_equation_annotations(argtypes) = Any[widenconst(T) for T in argtypes]
198+
199+
function annotate_variables_and_equations(argtypes::Vector{Any}, map::ArgumentMap)
200+
argtypes_annotated = Any[]
201+
pstructs = Dict{CompositeIndex,PartialStruct}()
202+
for (i, arg) in enumerate(argtypes)
203+
if arg !== equation && arg !== Incidence && isstructtype(arg) && (any(==(i) first, map.variables) || any(==(i) first, map.equations))
204+
arg = init_partialstruct(arg)
205+
pstructs[[i]] = arg
99206
end
100-
(args, _, offset) = flatten_arguments!(compact, settings, isa(argt, PartialStruct) ? argt.fields : collect(Any, fieldtypes(argt)), offset, eqoffset, argtypes)
101-
offset == -1 && return TransformedArg(ssa, -1, eqoffset)
102-
this = Expr(:new, isa(argt, PartialStruct) ? argt.typ : argt, args...)
103-
line = compact[Compiler.OldSSAValue(1)][:line]
104-
ssa = @insert_instruction_here(compact, line, settings, this::argt)
105-
return TransformedArg(ssa, offset, eqoffset)
207+
push!(argtypes_annotated, arg)
106208
end
209+
210+
function fields_for_index(index)
211+
length(index) > 1 || return argtypes_annotated
212+
# Find the parent `PartialStruct` that holds the variable field,
213+
# creating any further `PartialStruct` going down if necessary.
214+
i, base = find_base(pstructs, index)
215+
local fields = base.fields
216+
for j in @view index[(i + 1):(end - 1)]
217+
pstruct = init_partialstruct(fields[j])
218+
fields[j] = pstruct
219+
fields = pstruct.fields
220+
end
221+
return fields
222+
end
223+
224+
# Populate `PartialStruct` variable fields with an `Incidence` lattice element.
225+
for (variable, index) in enumerate(map.variables)
226+
fields = fields_for_index(index)
227+
type = get_fieldtype(argtypes, index)
228+
fields[index[end]] = Incidence(type, variable)
229+
end
230+
231+
# Do the same for equations with an `Eq` lattice element.
232+
for (equation, index) in enumerate(map.equations)
233+
fields = fields_for_index(index)
234+
fields[index[end]] = Eq(equation)
235+
end
236+
237+
return argtypes_annotated
107238
end
108239

109-
function flatten_arguments!(compact::Compiler.IncrementalCompact, settings::Settings, argtypes::Vector{Any}, offset::Int=0, eqoffset::Int=0, new_argtypes::Vector{Any} = Any[])
110-
args = Any[]
111-
for argt in argtypes
112-
(; ssa, offset, eqoffset) = flatten_argument!(compact, settings, argt, offset, eqoffset, new_argtypes)
113-
offset == -1 && break
114-
push!(args, ssa)
240+
init_partialstruct(@nospecialize(T)) = PartialStruct(T, collect(Any, fieldtypes(T)))
241+
242+
function find_base(dict::Dict{CompositeIndex}, index::CompositeIndex)
243+
for i in reverse(eachindex(index))
244+
base = get(dict, @view(index[1:i]), nothing)
245+
base !== nothing && return i, base
115246
end
116-
return (args, new_argtypes, offset, eqoffset)
117247
end

src/analysis/ipoincidence.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,44 @@ function apply_linear_incidence(𝕃, ret::PartialStruct, caller::CallerMappingS
155155
return PartialStruct(𝕃, ret.typ, Any[apply_linear_incidence(𝕃, f, caller, mapping) for f in ret.fields])
156156
end
157157

158-
function CalleeMapping(𝕃::Compiler.AbstractLattice, argtypes::Vector{Any}, callee_ci::CodeInstance, callee_result::DAEIPOResult, template_argtypes)
158+
function CalleeMapping(𝕃::AbstractLattice, argtypes::Vector{Any}, callee_ci::CodeInstance, callee_result::DAEIPOResult)
159+
caller_argtypes = Compiler.va_process_argtypes(𝕃, argtypes, callee_ci.inferred.nargs, callee_ci.inferred.isva)
160+
callee_argtypes = callee_ci.inferred.ir.argtypes
161+
argmap = ArgumentMap(callee_argtypes)
162+
nvars = length(callee_result.var_to_diff)
163+
neqs = length(callee_result.total_incidence)
164+
@assert length(argmap.variables) nvars
165+
@assert length(argmap.equations) neqs
166+
159167
applied_scopes = Any[]
160-
coeffs = Vector{Any}(undef, length(callee_result.var_to_diff))
161-
eq_mapping = fill(0, length(callee_result.total_incidence))
168+
coeffs = Vector{Any}(undef, nvars)
169+
eq_mapping = fill(0, neqs)
170+
mapping = CalleeMapping(coeffs, eq_mapping, applied_scopes)
162171

163-
va_argtypes = Compiler.va_process_argtypes(𝕃, argtypes, callee_ci.inferred.nargs, callee_ci.inferred.isva)
164-
process_template!(𝕃, coeffs, eq_mapping, applied_scopes, va_argtypes, template_argtypes)
172+
fill_callee_mapping!(mapping, argmap, caller_argtypes, 𝕃)
173+
return mapping
174+
end
165175

166-
return CalleeMapping(coeffs, eq_mapping, applied_scopes)
176+
function fill_callee_mapping!(mapping::CalleeMapping, argmap::ArgumentMap, argtypes::Vector{Any}, 𝕃::AbstractLattice)
177+
for (i, index) in enumerate(argmap.variables)
178+
type = get_fieldtype(argtypes, index, 𝕃)
179+
mapping.var_coeffs[i] = type
180+
end
181+
for (i, index) in enumerate(argmap.equations)
182+
eq = get_fieldtype(argtypes, index, 𝕃)::Eq
183+
mapping.eqs[i] = eq.id
184+
end
167185
end
168186

187+
function get_fieldtype(argtypes::Vector{Any}, index::CompositeIndex, 𝕃::AbstractLattice = Compiler.fallback_lattice)
188+
@assert !isempty(index)
189+
index = copy(index)
190+
type = argtypes[popfirst!(index)]
191+
while !isempty(index)
192+
type = Compiler.getfield_tfunc(𝕃, type, Const(popfirst!(index)))
193+
end
194+
return type
195+
end
169196

170197
struct MappingInfo <: Compiler.CallInfo
171198
info::Any

src/analysis/lattice.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using SparseArrays
44

55
########################## EqStructureLattice ####################################
66
"""
7-
struct EqStructureLattice <: Compiler.AbstractLattice
7+
struct EqStructureLattice <: AbstractLattice
88
99
This lattice implements the `AbstractLattice` interface. It adjoins `Incidence` and `Eq`.
1010
@@ -34,7 +34,7 @@ the taint of %phi depends not only on `%a` and `%b`, but also on the taint of
3434
the branch condition `%cond`. This is a common feature of taint analysis, but
3535
is somewhat unusual from the perspective of other Julia type lattices.
3636
"""
37-
struct EqStructureLattice <: Compiler.AbstractLattice; end
37+
struct EqStructureLattice <: AbstractLattice; end
3838
Compiler.widenlattice(::EqStructureLattice) = Compiler.ConstsLattice()
3939
Compiler.is_valid_lattice_norec(::EqStructureLattice, @nospecialize(v)) = isa(v, Incidence) || isa(v, Eq) || isa(v, PartialScope) || isa(v, PartialKeyValue)
4040
Compiler.has_extended_unionsplit(::EqStructureLattice) = true
@@ -537,7 +537,7 @@ struct PartialKeyValue
537537
end
538538
PartialKeyValue(typ) = PartialKeyValue(typ, typ, IdDict{Any, Any}())
539539

540-
function getkeyvalue_tfunc(𝕃::Compiler.AbstractLattice,
540+
function getkeyvalue_tfunc(𝕃::AbstractLattice,
541541
@nospecialize(collection), @nospecialize(key))
542542
isa(key, Const) || return Tuple{Any}
543543
if haskey(collection.vals, key.val)

src/analysis/refiner.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Compiler.cache_owner(interp::StructuralRefiner) = StructureCache(interp.settings
6464
end
6565

6666
argtypes = Compiler.collect_argtypes(interp, stmt.args, Compiler.StatementState(nothing, false), irsv)[2:end]
67-
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result, callee_codeinst.inferred.ir.argtypes)
67+
mapping = CalleeMapping(Compiler.optimizer_lattice(interp), argtypes, callee_codeinst, callee_result)
6868
new_rt = apply_linear_incidence(Compiler.optimizer_lattice(interp), callee_result.extended_rt,
6969
CallerMappingState(callee_result, interp.var_to_diff, interp.varclassification, interp.varkinds, interp.eqclassification, interp.eqkinds), mapping)
7070

0 commit comments

Comments
 (0)