Skip to content

Commit 1714a6b

Browse files
committed
Generalize symbol type for debug scopes
1 parent 20d1ed4 commit 1714a6b

File tree

14 files changed

+663
-234
lines changed

14 files changed

+663
-234
lines changed

Manifest.toml

Lines changed: 542 additions & 140 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Project.toml

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ uuid = "32805668-c3d0-42c2-aafd-0d0a9857a104"
33
version = "1.21.0"
44
authors = ["JuliaHub, Inc. and other contributors"]
55

6+
[workspace]
7+
projects = ["test"]
8+
69
[deps]
710
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
811
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
@@ -30,12 +33,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
3033
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
3134
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3235
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
33-
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3436
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3537
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
38+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
3639

3740
[sources]
38-
ModelingToolkitStandardLibrary = {rev = "ox/dae_compatible5", url = "https://github.com/CedarEDA/ModelingToolkitStandardLibrary.jl"}
3941
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
4042
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}
4143

@@ -54,12 +56,10 @@ Cthulhu = "2.10.1"
5456
DiffEqBase = "6.149.2"
5557
Diffractor = "0.2.7"
5658
ForwardDiff = "0.10.36"
57-
ModelingToolkitStandardLibrary = "2.6.0"
5859
NonlinearSolve = "3.5.0"
5960
OrderedCollections = "1.6.3"
6061
PrecompileTools = "1"
6162
Preferences = "1.4"
62-
Roots = "2.0.22"
6363
SciMLBase = "2.24.0"
6464
SciMLSensitivity = "7.47"
6565
StateSelection = "0.2.0"
@@ -68,24 +68,5 @@ Sundials = "4.19"
6868
SymbolicIndexingInterface = "0.3"
6969
julia = "1.11"
7070

71-
[extras]
72-
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
73-
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
74-
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
75-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
76-
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
77-
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
78-
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
79-
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
80-
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
81-
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
82-
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
83-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
84-
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
85-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
86-
8771
[preferences.LinearSolve]
8872
LoadMKL_JLL = false
89-
90-
[targets]
91-
test = ["ControlSystemsBase", "DataInterpolations", "FiniteDiff", "FiniteDifferences", "IfElse", "InteractiveUtils", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "SafeTestsets", "Sundials", "Test", "Roots", "StaticArrays"]

ext/DAECompilerSciMLSensitivityExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ with one column per time step in `ts` and one one row per `variable`/`observed!`
3939
"""
4040
function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolution, syms::Vector{<:DAECompiler.ScopeRef}, ts=sol.t)
4141
us, du_dparams = extract_local_sensitivities(sol, ts)
42-
var_inds, obs_inds = DAECompiler.split_and_sort_syms(syms)
43-
4442
transformed_sys = DAECompiler.get_transformed_sys(sol)
43+
sys = DAECompiler.get_sys(transformed_sys)
44+
var_inds, obs_inds = DAECompiler.split_and_sort_syms(sys, syms)
45+
4546
dreconstruct! = get!(sol.prob.f.observed.derivative_cache, (var_inds, obs_inds, false)) do
4647
DAECompiler.compile_batched_reconstruct_derivatives(transformed_sys, var_inds, obs_inds, false, false;)
4748
end
48-
49+
4950
num_params = length(du_dparams)
5051
dout_vars_per_param = [similar(us, (length(var_inds), length(ts))) for _ in 1:num_params]
5152
dout_obs_per_param = [similar(us, (length(obs_inds), length(ts))) for _ in 1:num_params]
@@ -67,7 +68,7 @@ function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolutio
6768
end
6869

6970
return map(dout_vars_per_param, dout_obs_per_param) do dout_vars, dout_obs
70-
DAECompiler.join_syms(syms, dout_vars, dout_obs, (var_inds, obs_inds))
71+
DAECompiler.join_syms(sys, syms, dout_vars, dout_obs, (var_inds, obs_inds))
7172
end
7273
end
7374

src/analysis/compiler.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_va
362362
end
363363

364364
function resolve_genscopes(names)
365-
new_names = OrderedDict{LevelKey, NameLevel}()
365+
new_names = OrderedDict{Any, NameLevel}()
366366
for (key, val) in collect(names)
367367
if val.children !== nothing
368368
@reset val.children = resolve_genscopes(val.children)
@@ -423,7 +423,7 @@ Perform the structural analysis on optimized code of `mi` and return `structure:
423423
end
424424
end
425425

426-
function refresh_identities(names::OrderedDict{LevelKey, NameLevel})
426+
function refresh_identities(names::OrderedDict{LevelKey, NameLevel}) where {LevelKey, NameLevel}
427427
new_names = OrderedDict{LevelKey, NameLevel}()
428428
for (key, val) in names
429429
if isa(key, Gen)
@@ -502,7 +502,7 @@ end
502502
eq_kind = VarEqKind[]
503503
warnings = UnsupportedIRException[]
504504

505-
names = OrderedDict{LevelKey, NameLevel}()
505+
names = OrderedDict{Any, NameLevel}()
506506

507507
nsysmscopes = 0
508508
ncallees = 0
@@ -1191,7 +1191,7 @@ function process_ipo_return!(ultimate_rt::PartialStruct, args...)
11911191
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
11921192
end
11931193

1194-
function get_variable_name(names::OrderedDict{LevelKey, NameLevel}, var_to_diff, var_idx)
1194+
function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
11951195
var_names = build_var_names(names, var_to_diff)
11961196
return var_names[var_idx]
11971197
end
@@ -1221,7 +1221,7 @@ function get_inline_backtrace(ir::IRCode, v::SSAValue)
12211221
return frames
12221222
end
12231223

1224-
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector{<:LevelKey})
1224+
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector) where {LevelKey, NameLevel}
12251225
for i = length(stack):-1:2
12261226
s = stack[i]
12271227
if !haskey(names, s)
@@ -1235,11 +1235,11 @@ end
12351235
is_valid_partial_scope(_) = false
12361236
is_valid_partial_scope(ps::PartialScope) = true
12371237
function is_valid_partial_scope(ps::PartialStruct)
1238-
if ps.typ === Scope
1238+
if ps.typ <: Scope
12391239
isa(ps.fields[2], Const) || return false
12401240
isa(ps.fields[2].val, Symbol) || return false
12411241
return is_valid_partial_scope(ps.fields[1])
1242-
elseif ps.typ === GenScope
1242+
elseif ps.typ <: GenScope
12431243
isa(ps.fields[1], Const) || return false
12441244
return is_valid_partial_scope(ps.fields[2])
12451245
else
@@ -1248,11 +1248,11 @@ function is_valid_partial_scope(ps::PartialStruct)
12481248
end
12491249

12501250
function sym_stack(ps::PartialStruct)
1251-
if ps.typ === Scope
1251+
if ps.typ <: Scope
12521252
sym = (ps.fields[2]::Const).val::Symbol
12531253
return pushfirst!(sym_stack(ps.fields[1]), sym)
12541254
else
1255-
@assert ps.typ === GenScope
1255+
@assert ps.typ <: GenScope
12561256
stack = sym_stack(ps.fields[2])
12571257
scope_identity = ((ps.fields[1]::Const).val)::Intrinsics.ScopeIdentity
12581258
stack[1] = Gen(scope_identity, stack[1])
@@ -1261,7 +1261,7 @@ function sym_stack(ps::PartialStruct)
12611261
end
12621262

12631263
sym_stack(ps::PartialScope) = LevelKey[ps]
1264-
function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
1264+
function record_scope!(ir::IRCode, names::OrderedDict, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
12651265
varssa::Vector, idx::Int, lens)
12661266

12671267
stack = sym_stack(scope)
@@ -1282,11 +1282,15 @@ function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scop
12821282
end
12831283

12841284
function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, val::NameLevel,
1285-
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
1285+
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}
12861286

12871287
haskey(names, key) || (names[key] = NameLevel())
12881288
existing = names[key]
1289-
for (offset, lens) in ((x->(only(findnz(mapping.var_coeffs[x].row)[1])), @o _.var),
1289+
function remap_var(x)
1290+
r = only(findnz(mapping.var_coeffs[x].row)[1]) - 1
1291+
return r
1292+
end
1293+
for (offset, lens) in ((remap_var, @o _.var),
12901294
(x->(x+obsoffset), @o _.obs),
12911295
(x->mapping.eqs[x], @o _.eq), (x->(x+epsoffset), @o _.eps))
12921296
if lens(val) !== nothing
@@ -1312,7 +1316,7 @@ function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, v
13121316
end
13131317

13141318
function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::Union{Scope, PartialStruct}, val::NameLevel,
1315-
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
1319+
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}
13161320

13171321
stack = sym_stack(key)
13181322
if isempty(stack)

src/analysis/debugging.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
using StateSelection
22
using StateSelection.BipartiteGraphs
33

4-
function build_var_names(names::OrderedDict{LevelKey, NameLevel}, var_to_diff)
4+
function build_var_names(names::OrderedDict, var_to_diff)
55
var_names = OrderedDict{Int,String}()
66
build_var_names!(var_names, names, var_to_diff)
77
return var_names
88
end
9-
function build_var_names!(var_names, names::OrderedDict{LevelKey, NameLevel}, var_to_diff, prefix=String[])
9+
function build_var_names!(var_names, names::OrderedDict, var_to_diff, prefix=String[])
1010
for name in keys(names)
1111
name_path = join(vcat(prefix..., name), ".")
1212
level = names[name]

src/analysis/interpreter.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -680,7 +680,7 @@ function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, t
680680
eq_mapping[idnum(template)] = idnum(arg)
681681
elseif CC.is_const_argtype(template)
682682
#@Core.Compiler.show (arg, template)
683-
@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
683+
#@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
684684
elseif isa(template, PartialScope)
685685
id = idnum(template)
686686
(id > length(applied_scopes)) && resize!(applied_scopes, id)
@@ -919,7 +919,7 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr
919919
argtypes = CC.collect_argtypes(interp, stmt.args, nothing, irsv)[2:end]
920920
callee_result = dae_result_for_inst(interp, inst)
921921
callee_result === nothing && return RT(nothing, (false, false))
922-
if isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
922+
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
923923
return RT(nothing, (false, false))
924924
end
925925
mapping = CalleeMapping(CC.optimizer_lattice(interp), argtypes, callee_result)
@@ -1030,14 +1030,11 @@ end
10301030
# -----
10311031

10321032
function typeinf_dae(@nospecialize(tt), world::UInt=get_world_counter();
1033-
method_table::Union{Nothing,MethodTable} = nothing,
10341033
ipo_analysis_mode::Bool = false)
1035-
interp = DAEInterpreter(world; method_table, ipo_analysis_mode)
1036-
match = Base._which(tt;
1037-
method_table=CC.method_table(interp),
1038-
world=get_inference_world(interp),
1039-
raise=false)
1040-
match === nothing && single_match_error(tt)
1034+
interp = DAEInterpreter(world; ipo_analysis_mode)
1035+
match = Base._methods_by_ftype(tt, 1, world)
1036+
isempty(match) && single_match_error(tt)
1037+
match = only(match)
10411038
mi = CC.specialize_method(match)
10421039
ci = CC.typeinf_ext(interp, mi, Core.Compiler.SOURCE_MODE_ABI)
10431040
return interp, ci

src/cache.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@ struct NameLevel
2424
obs::Union{Nothing, Int}
2525
eq::Union{Nothing, Int}
2626
eps::Union{Nothing, Int}
27-
children::Union{Nothing, OrderedDict{LevelKey, NameLevel}}
27+
# TODO: This should be an OrderedIdDict
28+
children::Union{Nothing, OrderedDict{Any, NameLevel}}
2829
end
2930
NameLevel() =
3031
NameLevel(nothing, nothing, nothing, nothing, nothing)
31-
NameLevel(children::OrderedDict{LevelKey, NameLevel}) =
32+
NameLevel(children::OrderedDict{Any, NameLevel}) =
3233
NameLevel(nothing, nothing, nothing, nothing, children)
3334

3435
struct UnsupportedIRException <: Exception
@@ -77,7 +78,7 @@ struct DAEIPOResult
7778
total_incidence::Vector{Any}
7879
eq_kind::Vector{VarEqKind}
7980
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{SSAValue, Int}}}}
80-
names::OrderedDict{LevelKey, NameLevel}
81+
names::OrderedDict{Any, NameLevel} # TODO: OrderedIdDict
8182
nobserved::Int
8283
neps::Int
8384
ic_nzc::Int

src/irodesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ struct IRODESystem
158158
fallback_interp::AbstractInterpreter = Core.Compiler.NativeInterpreter(),
159159
debug_config = (;),
160160
ipo_analysis_mode = false,
161-
world::UInt=get_world_counter())
161+
world::UInt=Base.tls_world_age())
162162
debug_config = DebugConfig(debug_config, tt)
163163
@may_timeit debug_config "typeinf_dae" interp, ci = typeinf_dae(tt, world; ipo_analysis_mode)
164164
mi = ci.def
@@ -183,7 +183,7 @@ mutable struct IRTransformationState <: TransformationState{IRODESystem}
183183
ir::IRCode
184184
callback_func::Function
185185
structure::SystemStructure
186-
const names::OrderedDict{LevelKey, NameLevel}
186+
const names::OrderedDict{Any, NameLevel}
187187
const nobserved::Int
188188
const neps::Int
189189
const ic_nzc::Int

src/runtime.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,22 +83,22 @@ module Intrinsics
8383

8484
abstract type AbstractScope; end
8585

86-
struct Scope <: AbstractScope
86+
struct Scope{T} <: AbstractScope
8787
parent::AbstractScope
88-
name::Symbol
89-
Scope() = new()
90-
Scope(s::AbstractScope, sym::Symbol) = new(s, sym)
88+
name::T
89+
Scope() = new{Union{}}()
90+
Scope(s::AbstractScope, name::T) where {T} = new{T}(s, name)
9191
end
9292
(scope::Scope)(s::Symbol) = Scope(scope, s)
93-
# Scope(), but will less function calls, so marginally easier on the compiler
93+
# Scope(), but with less function calls, so marginally easier on the compiler
9494
const root_scope = Scope()
9595

9696
mutable struct ScopeIdentity; end
9797

98-
struct GenScope <: AbstractScope
98+
struct GenScope{T} <: AbstractScope
9999
identity::ScopeIdentity
100-
sc::Scope
101-
GenScope(sc::Scope) = new(ScopeIdentity(), sc)
100+
sc::Scope{T}
101+
GenScope(sc::Scope{T}) where {T} = new{T}(ScopeIdentity(), sc)
102102
end
103103
GenScope(parent::AbstractScope, name::Symbol) =
104104
GenScope(Scope(parent, name))

0 commit comments

Comments
 (0)