Skip to content

Commit 54c02c1

Browse files
authored
Generalize symbol type for debug scopes (#8)
* Generalize symbol type for debug scopes * More scope adjustment * Adjust to JuliaLang/julia#49260 * More Core.Compiler adjustments
1 parent 20d1ed4 commit 54c02c1

16 files changed

+703
-274
lines changed

Manifest.toml

Lines changed: 542 additions & 140 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project.toml

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ uuid = "32805668-c3d0-42c2-aafd-0d0a9857a104"
33
version = "1.21.0"
44
authors = ["JuliaHub, Inc. and other contributors"]
55

6+
[workspace]
7+
projects = ["test"]
8+
69
[deps]
710
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
811
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
@@ -30,12 +33,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3033
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
3134
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3235
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
33-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3436
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3537
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
38+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3639

3740
[sources]
38-
ModelingToolkitStandardLibrary = {rev = "ox/dae_compatible5", url = "https://github.com/CedarEDA/ModelingToolkitStandardLibrary.jl"}
3941
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
4042
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}
4143

@@ -54,12 +56,10 @@ Cthulhu = "2.10.1"
5456
DiffEqBase = "6.149.2"
5557
Diffractor = "0.2.7"
5658
ForwardDiff = "0.10.36"
57-
ModelingToolkitStandardLibrary = "2.6.0"
5859
NonlinearSolve = "3.5.0"
5960
OrderedCollections = "1.6.3"
6061
PrecompileTools = "1"
6162
Preferences = "1.4"
62-
Roots = "2.0.22"
6363
SciMLBase = "2.24.0"
6464
SciMLSensitivity = "7.47"
6565
StateSelection = "0.2.0"
@@ -68,24 +68,5 @@ Sundials = "4.19"
6868
SymbolicIndexingInterface = "0.3"
6969
julia = "1.11"
7070

71-
[extras]
72-
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
73-
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
74-
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
75-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
76-
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
77-
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
78-
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
79-
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
80-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
81-
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
82-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
83-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
84-
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
85-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
86-
8771
[preferences.LinearSolve]
8872
LoadMKL_JLL = false
89-
90-
[targets]
91-
test = ["ControlSystemsBase", "DataInterpolations", "FiniteDiff", "FiniteDifferences", "IfElse", "InteractiveUtils", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "SafeTestsets", "Sundials", "Test", "Roots", "StaticArrays"]

ext/DAECompilerModelingToolkitExt.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ function declare_parameters(model, struct_name)
7575
backing::B
7676
end
7777
)
78-
79-
78+
79+
8080
constructor_expr =:(
8181
@generated function _check_parameter_names(::Type{$struct_name}, param_kwargs::NamedTuple)
8282
unexpected_parameters = setdiff(fieldnames(param_kwargs), $param_names_tuple_expr)
@@ -108,7 +108,7 @@ function declare_parameters(model, struct_name)
108108
if name === $param_name
109109
return if hasfield(B, $param_name)
110110
getfield(getfield(this, :backing), $param_name)
111-
else
111+
else
112112
$param_value
113113
end
114114
end
@@ -118,7 +118,7 @@ function declare_parameters(model, struct_name)
118118
return getfield(getfield(this, :backing), name)
119119
))
120120
getproperty_expr.args[end].args[end] = Expr(:block, getproperty_body...)
121-
121+
122122
return Expr(:block, struct_expr, constructor_expr, propertynames_expr, getproperty_expr)
123123
end
124124

@@ -206,7 +206,7 @@ end
206206

207207
macro DAECompiler.declare_MTKConnector(mtk_component, ports...)
208208
# We do need to do run time eval, because we can't decide what to construct with just lexical information.
209-
# we need the values of the
209+
# we need the values of the
210210
:(Base.eval(@__MODULE__, $MTKConnector_AST($(esc(mtk_component)), $(esc.(ports)...))))
211211
end
212212

@@ -219,7 +219,7 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
219219
end
220220

221221
while !isnothing(MTK.get_parent(model))
222-
# Undo any call to structural_simplify
222+
# Undo any call to structural_simplify
223223
# (Should we give a warning here? They did waste CPU cycles simplfying it in first place)
224224
model = MTK.get_parent(model)
225225
end
@@ -239,11 +239,11 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
239239

240240

241241
struct_name = gensym(nameof(model))
242-
242+
243243
return quote
244244
$(declare_parameters(model, struct_name))
245245

246-
function (this::$struct_name)($(port_names...); dscope=$(_c(Scope))())
246+
function (this::$struct_name)($(map(port->:($(port)::Float64), port_names)...); dscope=$(_c(Scope))())
247247
$(declare_vars(model, :dscope))
248248
$(declare_derivatives(state))
249249
$(declare_equations(state, model, :dscope, ports))
@@ -258,4 +258,4 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
258258
end
259259

260260

261-
end # module
261+
end # module

ext/DAECompilerSciMLSensitivityExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ with one column per time step in `ts` and one one row per `variable`/`observed!`
3939
"""
4040
function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolution, syms::Vector{<:DAECompiler.ScopeRef}, ts=sol.t)
4141
us, du_dparams = extract_local_sensitivities(sol, ts)
42-
var_inds, obs_inds = DAECompiler.split_and_sort_syms(syms)
43-
4442
transformed_sys = DAECompiler.get_transformed_sys(sol)
43+
sys = DAECompiler.get_sys(transformed_sys)
44+
var_inds, obs_inds = DAECompiler.split_and_sort_syms(sys, syms)
45+
4546
dreconstruct! = get!(sol.prob.f.observed.derivative_cache, (var_inds, obs_inds, false)) do
4647
DAECompiler.compile_batched_reconstruct_derivatives(transformed_sys, var_inds, obs_inds, false, false;)
4748
end
48-
49+
4950
num_params = length(du_dparams)
5051
dout_vars_per_param = [similar(us, (length(var_inds), length(ts))) for _ in 1:num_params]
5152
dout_obs_per_param = [similar(us, (length(obs_inds), length(ts))) for _ in 1:num_params]
@@ -67,7 +68,7 @@ function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolutio
6768
end
6869

6970
return map(dout_vars_per_param, dout_obs_per_param) do dout_vars, dout_obs
70-
DAECompiler.join_syms(syms, dout_vars, dout_obs, (var_inds, obs_inds))
71+
DAECompiler.join_syms(sys, syms, dout_vars, dout_obs, (var_inds, obs_inds))
7172
end
7273
end
7374

src/analysis/compiler.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ has_any_genscope(sc::PartialStruct) = false # TODO
319319

320320
function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
321321
if isa(argt, Const)
322-
@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
322+
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
323323
return argt
324324
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
325325
return PartialScope(add_scope!(which))
@@ -362,7 +362,7 @@ function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_va
362362
end
363363

364364
function resolve_genscopes(names)
365-
new_names = OrderedDict{LevelKey, NameLevel}()
365+
new_names = OrderedDict{Any, NameLevel}()
366366
for (key, val) in collect(names)
367367
if val.children !== nothing
368368
@reset val.children = resolve_genscopes(val.children)
@@ -423,7 +423,7 @@ Perform the structural analysis on optimized code of `mi` and return `structure:
423423
end
424424
end
425425

426-
function refresh_identities(names::OrderedDict{LevelKey, NameLevel})
426+
function refresh_identities(names::OrderedDict{LevelKey, NameLevel}) where {LevelKey, NameLevel}
427427
new_names = OrderedDict{LevelKey, NameLevel}()
428428
for (key, val) in names
429429
if isa(key, Gen)
@@ -502,7 +502,7 @@ end
502502
eq_kind = VarEqKind[]
503503
warnings = UnsupportedIRException[]
504504

505-
names = OrderedDict{LevelKey, NameLevel}()
505+
names = OrderedDict{Any, NameLevel}()
506506

507507
nsysmscopes = 0
508508
ncallees = 0
@@ -1191,7 +1191,7 @@ function process_ipo_return!(ultimate_rt::PartialStruct, args...)
11911191
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
11921192
end
11931193

1194-
function get_variable_name(names::OrderedDict{LevelKey, NameLevel}, var_to_diff, var_idx)
1194+
function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
11951195
var_names = build_var_names(names, var_to_diff)
11961196
return var_names[var_idx]
11971197
end
@@ -1221,7 +1221,7 @@ function get_inline_backtrace(ir::IRCode, v::SSAValue)
12211221
return frames
12221222
end
12231223

1224-
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector{<:LevelKey})
1224+
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector) where {LevelKey, NameLevel}
12251225
for i = length(stack):-1:2
12261226
s = stack[i]
12271227
if !haskey(names, s)
@@ -1235,11 +1235,11 @@ end
12351235
is_valid_partial_scope(_) = false
12361236
is_valid_partial_scope(ps::PartialScope) = true
12371237
function is_valid_partial_scope(ps::PartialStruct)
1238-
if ps.typ === Scope
1238+
if ps.typ <: Scope
12391239
isa(ps.fields[2], Const) || return false
12401240
isa(ps.fields[2].val, Symbol) || return false
12411241
return is_valid_partial_scope(ps.fields[1])
1242-
elseif ps.typ === GenScope
1242+
elseif ps.typ <: GenScope
12431243
isa(ps.fields[1], Const) || return false
12441244
return is_valid_partial_scope(ps.fields[2])
12451245
else
@@ -1248,11 +1248,11 @@ function is_valid_partial_scope(ps::PartialStruct)
12481248
end
12491249

12501250
function sym_stack(ps::PartialStruct)
1251-
if ps.typ === Scope
1251+
if ps.typ <: Scope
12521252
sym = (ps.fields[2]::Const).val::Symbol
12531253
return pushfirst!(sym_stack(ps.fields[1]), sym)
12541254
else
1255-
@assert ps.typ === GenScope
1255+
@assert ps.typ <: GenScope
12561256
stack = sym_stack(ps.fields[2])
12571257
scope_identity = ((ps.fields[1]::Const).val)::Intrinsics.ScopeIdentity
12581258
stack[1] = Gen(scope_identity, stack[1])
@@ -1261,7 +1261,7 @@ function sym_stack(ps::PartialStruct)
12611261
end
12621262

12631263
sym_stack(ps::PartialScope) = LevelKey[ps]
1264-
function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
1264+
function record_scope!(ir::IRCode, names::OrderedDict, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
12651265
varssa::Vector, idx::Int, lens)
12661266

12671267
stack = sym_stack(scope)
@@ -1282,11 +1282,15 @@ function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scop
12821282
end
12831283

12841284
function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, val::NameLevel,
1285-
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
1285+
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}
12861286

12871287
haskey(names, key) || (names[key] = NameLevel())
12881288
existing = names[key]
1289-
for (offset, lens) in ((x->(only(findnz(mapping.var_coeffs[x].row)[1])), @o _.var),
1289+
function remap_var(x)
1290+
r = only(findnz(mapping.var_coeffs[x].row)[1]) - 1
1291+
return r
1292+
end
1293+
for (offset, lens) in ((remap_var, @o _.var),
12901294
(x->(x+obsoffset), @o _.obs),
12911295
(x->mapping.eqs[x], @o _.eq), (x->(x+epsoffset), @o _.eps))
12921296
if lens(val) !== nothing
@@ -1312,7 +1316,7 @@ function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, v
13121316
end
13131317

13141318
function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::Union{Scope, PartialStruct}, val::NameLevel,
1315-
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
1319+
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}
13161320

13171321
stack = sym_stack(key)
13181322
if isempty(stack)

src/analysis/debugging.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using StateSelection
22
using StateSelection.BipartiteGraphs
33

4-
function build_var_names(names::OrderedDict{LevelKey, NameLevel}, var_to_diff)
4+
function build_var_names(names::OrderedDict, var_to_diff)
55
var_names = OrderedDict{Int,String}()
66
build_var_names!(var_names, names, var_to_diff)
77
return var_names
88
end
9-
function build_var_names!(var_names, names::OrderedDict{LevelKey, NameLevel}, var_to_diff, prefix=String[])
9+
function build_var_names!(var_names, names::OrderedDict, var_to_diff, prefix=String[])
1010
for name in keys(names)
1111
name_path = join(vcat(prefix..., name), ".")
1212
level = names[name]

src/analysis/interpreter.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,13 @@ end
153153
Diffractor.disable_forward(interp::DAEInterpreter) = CC.NativeInterpreter(interp.world)
154154

155155
function CC.InferenceParams(::DAEInterpreter)
156-
return CC.InferenceParams(;
157-
unoptimize_throw_blocks=false,
158-
assume_bindings_static=true,
159-
ignore_recursion_hardlimit=true)
156+
args = (;
157+
assume_bindings_static=true,
158+
ignore_recursion_hardlimit=true)
159+
if VERSION < v"1.12.0-DEV.1017"
160+
args = (; unoptimize_throw_blocks=false, args...)
161+
end
162+
return CC.InferenceParams(; args...)
160163
end
161164
function CC.OptimizationParams(::DAEInterpreter)
162165
opt_params = CC.OptimizationParams(;
@@ -680,7 +683,7 @@ function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, t
680683
eq_mapping[idnum(template)] = idnum(arg)
681684
elseif CC.is_const_argtype(template)
682685
#@Core.Compiler.show (arg, template)
683-
@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
686+
#@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
684687
elseif isa(template, PartialScope)
685688
id = idnum(template)
686689
(id > length(applied_scopes)) && resize!(applied_scopes, id)
@@ -919,7 +922,7 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr
919922
argtypes = CC.collect_argtypes(interp, stmt.args, nothing, irsv)[2:end]
920923
callee_result = dae_result_for_inst(interp, inst)
921924
callee_result === nothing && return RT(nothing, (false, false))
922-
if isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
925+
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
923926
return RT(nothing, (false, false))
924927
end
925928
mapping = CalleeMapping(CC.optimizer_lattice(interp), argtypes, callee_result)
@@ -1030,14 +1033,11 @@ end
10301033
# -----
10311034

10321035
function typeinf_dae(@nospecialize(tt), world::UInt=get_world_counter();
1033-
method_table::Union{Nothing,MethodTable} = nothing,
10341036
ipo_analysis_mode::Bool = false)
1035-
interp = DAEInterpreter(world; method_table, ipo_analysis_mode)
1036-
match = Base._which(tt;
1037-
method_table=CC.method_table(interp),
1038-
world=get_inference_world(interp),
1039-
raise=false)
1040-
match === nothing && single_match_error(tt)
1037+
interp = DAEInterpreter(world; ipo_analysis_mode)
1038+
match = Base._methods_by_ftype(tt, 1, world)
1039+
isempty(match) && single_match_error(tt)
1040+
match = only(match)
10411041
mi = CC.specialize_method(match)
10421042
ci = CC.typeinf_ext(interp, mi, Core.Compiler.SOURCE_MODE_ABI)
10431043
return interp, ci

src/analysis/lattice.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ CC.widenconst(::PartialScope) = Scope
347347
CC.widenconst(pkv::PartialKeyValue) = widenconst(pkv.typ)
348348
CC.:(inc::Incidence, inc2) = CC.:(inc2, Float64) && !isa(inc2, Const)
349349

350-
function CC._uniontypes(x::Incidence, ts)
350+
function CC._uniontypes(x::Incidence, ts::Vector{Any})
351351
u = x.typ
352352
if isa(u, Union)
353353
CC.push!(ts, is_non_incidence_type(u.a) ? u.a : Incidence(u.a, x.row, x.eps))
@@ -462,6 +462,9 @@ function CC._getfield_tfunc(🥬::DAELattice, @nospecialize(s00), @nospecialize(
462462
return Union{}
463463
end
464464
rt = CC._getfield_tfunc(CC.widenlattice(🥬), s00.typ, name, setfield)
465+
if rt == Union{}
466+
return Union{}
467+
end
465468
if isempty(s00)
466469
return Incidence(rt)
467470
end

src/cache.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@ struct NameLevel
2424
obs::Union{Nothing, Int}
2525
eq::Union{Nothing, Int}
2626
eps::Union{Nothing, Int}
27-
children::Union{Nothing, OrderedDict{LevelKey, NameLevel}}
27+
# TODO: This should be an OrderedIdDict
28+
children::Union{Nothing, OrderedDict{Any, NameLevel}}
2829
end
2930
NameLevel() =
3031
NameLevel(nothing, nothing, nothing, nothing, nothing)
31-
NameLevel(children::OrderedDict{LevelKey, NameLevel}) =
32+
NameLevel(children::OrderedDict{Any, NameLevel}) =
3233
NameLevel(nothing, nothing, nothing, nothing, children)
3334

3435
struct UnsupportedIRException <: Exception
@@ -77,7 +78,7 @@ struct DAEIPOResult
7778
total_incidence::Vector{Any}
7879
eq_kind::Vector{VarEqKind}
7980
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{SSAValue, Int}}}}
80-
names::OrderedDict{LevelKey, NameLevel}
81+
names::OrderedDict{Any, NameLevel} # TODO: OrderedIdDict
8182
nobserved::Int
8283
neps::Int
8384
ic_nzc::Int

0 commit comments

Comments
 (0)