Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
682 changes: 542 additions & 140 deletions Manifest.toml

Large diffs are not rendered by default.

27 changes: 4 additions & 23 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ uuid = "32805668-c3d0-42c2-aafd-0d0a9857a104"
version = "1.21.0"
authors = ["JuliaHub, Inc. and other contributors"]

[workspace]
projects = ["test"]

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
Expand Down Expand Up @@ -30,12 +33,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[sources]
ModelingToolkitStandardLibrary = {rev = "ox/dae_compatible5", url = "https://github.com/CedarEDA/ModelingToolkitStandardLibrary.jl"}
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}

Expand All @@ -54,12 +56,10 @@ Cthulhu = "2.10.1"
DiffEqBase = "6.149.2"
Diffractor = "0.2.7"
ForwardDiff = "0.10.36"
ModelingToolkitStandardLibrary = "2.6.0"
NonlinearSolve = "3.5.0"
OrderedCollections = "1.6.3"
PrecompileTools = "1"
Preferences = "1.4"
Roots = "2.0.22"
SciMLBase = "2.24.0"
SciMLSensitivity = "7.47"
StateSelection = "0.2.0"
Expand All @@ -68,24 +68,5 @@ Sundials = "4.19"
SymbolicIndexingInterface = "0.3"
julia = "1.11"

[extras]
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[preferences.LinearSolve]
LoadMKL_JLL = false

[targets]
test = ["ControlSystemsBase", "DataInterpolations", "FiniteDiff", "FiniteDifferences", "IfElse", "InteractiveUtils", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "SafeTestsets", "Sundials", "Test", "Roots", "StaticArrays"]
18 changes: 9 additions & 9 deletions ext/DAECompilerModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ function declare_parameters(model, struct_name)
backing::B
end
)


constructor_expr =:(
@generated function _check_parameter_names(::Type{$struct_name}, param_kwargs::NamedTuple)
unexpected_parameters = setdiff(fieldnames(param_kwargs), $param_names_tuple_expr)
Expand Down Expand Up @@ -108,7 +108,7 @@ function declare_parameters(model, struct_name)
if name === $param_name
return if hasfield(B, $param_name)
getfield(getfield(this, :backing), $param_name)
else
else
$param_value
end
end
Expand All @@ -118,7 +118,7 @@ function declare_parameters(model, struct_name)
return getfield(getfield(this, :backing), name)
))
getproperty_expr.args[end].args[end] = Expr(:block, getproperty_body...)

return Expr(:block, struct_expr, constructor_expr, propertynames_expr, getproperty_expr)
end

Expand Down Expand Up @@ -206,7 +206,7 @@ end

macro DAECompiler.declare_MTKConnector(mtk_component, ports...)
# We do need to do run time eval, because we can't decide what to construct with just lexical information.
# we need the values of the
# we need the values of the
:(Base.eval(@__MODULE__, $MTKConnector_AST($(esc(mtk_component)), $(esc.(ports)...))))
end

Expand All @@ -219,7 +219,7 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
end

while !isnothing(MTK.get_parent(model))
# Undo any call to structural_simplify
# Undo any call to structural_simplify
# (Should we give a warning here? They did waste CPU cycles simplfying it in first place)
model = MTK.get_parent(model)
end
Expand All @@ -239,11 +239,11 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)


struct_name = gensym(nameof(model))

return quote
$(declare_parameters(model, struct_name))

function (this::$struct_name)($(port_names...); dscope=$(_c(Scope))())
function (this::$struct_name)($(map(port->:($(port)::Float64), port_names)...); dscope=$(_c(Scope))())
$(declare_vars(model, :dscope))
$(declare_derivatives(state))
$(declare_equations(state, model, :dscope, ports))
Expand All @@ -258,4 +258,4 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
end


end # module
end # module
9 changes: 5 additions & 4 deletions ext/DAECompilerSciMLSensitivityExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ with one column per time step in `ts` and one one row per `variable`/`observed!`
"""
function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolution, syms::Vector{<:DAECompiler.ScopeRef}, ts=sol.t)
us, du_dparams = extract_local_sensitivities(sol, ts)
var_inds, obs_inds = DAECompiler.split_and_sort_syms(syms)

transformed_sys = DAECompiler.get_transformed_sys(sol)
sys = DAECompiler.get_sys(transformed_sys)
var_inds, obs_inds = DAECompiler.split_and_sort_syms(sys, syms)

dreconstruct! = get!(sol.prob.f.observed.derivative_cache, (var_inds, obs_inds, false)) do
DAECompiler.compile_batched_reconstruct_derivatives(transformed_sys, var_inds, obs_inds, false, false;)
end

num_params = length(du_dparams)
dout_vars_per_param = [similar(us, (length(var_inds), length(ts))) for _ in 1:num_params]
dout_obs_per_param = [similar(us, (length(obs_inds), length(ts))) for _ in 1:num_params]
Expand All @@ -67,7 +68,7 @@ function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolutio
end

return map(dout_vars_per_param, dout_obs_per_param) do dout_vars, dout_obs
DAECompiler.join_syms(syms, dout_vars, dout_obs, (var_inds, obs_inds))
DAECompiler.join_syms(sys, syms, dout_vars, dout_obs, (var_inds, obs_inds))
end
end

Expand Down
32 changes: 18 additions & 14 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ has_any_genscope(sc::PartialStruct) = false # TODO

function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
if isa(argt, Const)
@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
return argt
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
return PartialScope(add_scope!(which))
Expand Down Expand Up @@ -362,7 +362,7 @@ function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_va
end

function resolve_genscopes(names)
new_names = OrderedDict{LevelKey, NameLevel}()
new_names = OrderedDict{Any, NameLevel}()
for (key, val) in collect(names)
if val.children !== nothing
@reset val.children = resolve_genscopes(val.children)
Expand Down Expand Up @@ -423,7 +423,7 @@ Perform the structural analysis on optimized code of `mi` and return `structure:
end
end

function refresh_identities(names::OrderedDict{LevelKey, NameLevel})
function refresh_identities(names::OrderedDict{LevelKey, NameLevel}) where {LevelKey, NameLevel}
new_names = OrderedDict{LevelKey, NameLevel}()
for (key, val) in names
if isa(key, Gen)
Expand Down Expand Up @@ -502,7 +502,7 @@ end
eq_kind = VarEqKind[]
warnings = UnsupportedIRException[]

names = OrderedDict{LevelKey, NameLevel}()
names = OrderedDict{Any, NameLevel}()

nsysmscopes = 0
ncallees = 0
Expand Down Expand Up @@ -1191,7 +1191,7 @@ function process_ipo_return!(ultimate_rt::PartialStruct, args...)
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
end

function get_variable_name(names::OrderedDict{LevelKey, NameLevel}, var_to_diff, var_idx)
function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
var_names = build_var_names(names, var_to_diff)
return var_names[var_idx]
end
Expand Down Expand Up @@ -1221,7 +1221,7 @@ function get_inline_backtrace(ir::IRCode, v::SSAValue)
return frames
end

function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector{<:LevelKey})
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector) where {LevelKey, NameLevel}
for i = length(stack):-1:2
s = stack[i]
if !haskey(names, s)
Expand All @@ -1235,11 +1235,11 @@ end
is_valid_partial_scope(_) = false
is_valid_partial_scope(ps::PartialScope) = true
function is_valid_partial_scope(ps::PartialStruct)
if ps.typ === Scope
if ps.typ <: Scope
isa(ps.fields[2], Const) || return false
isa(ps.fields[2].val, Symbol) || return false
return is_valid_partial_scope(ps.fields[1])
elseif ps.typ === GenScope
elseif ps.typ <: GenScope
isa(ps.fields[1], Const) || return false
return is_valid_partial_scope(ps.fields[2])
else
Expand All @@ -1248,11 +1248,11 @@ function is_valid_partial_scope(ps::PartialStruct)
end

function sym_stack(ps::PartialStruct)
if ps.typ === Scope
if ps.typ <: Scope
sym = (ps.fields[2]::Const).val::Symbol
return pushfirst!(sym_stack(ps.fields[1]), sym)
else
@assert ps.typ === GenScope
@assert ps.typ <: GenScope
stack = sym_stack(ps.fields[2])
scope_identity = ((ps.fields[1]::Const).val)::Intrinsics.ScopeIdentity
stack[1] = Gen(scope_identity, stack[1])
Expand All @@ -1261,7 +1261,7 @@ function sym_stack(ps::PartialStruct)
end

sym_stack(ps::PartialScope) = LevelKey[ps]
function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
function record_scope!(ir::IRCode, names::OrderedDict, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
varssa::Vector, idx::Int, lens)

stack = sym_stack(scope)
Expand All @@ -1282,11 +1282,15 @@ function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scop
end

function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, val::NameLevel,
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}

haskey(names, key) || (names[key] = NameLevel())
existing = names[key]
for (offset, lens) in ((x->(only(findnz(mapping.var_coeffs[x].row)[1])), @o _.var),
function remap_var(x)
r = only(findnz(mapping.var_coeffs[x].row)[1]) - 1
return r
end
for (offset, lens) in ((remap_var, @o _.var),
(x->(x+obsoffset), @o _.obs),
(x->mapping.eqs[x], @o _.eq), (x->(x+epsoffset), @o _.eps))
if lens(val) !== nothing
Expand All @@ -1312,7 +1316,7 @@ function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, v
end

function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::Union{Scope, PartialStruct}, val::NameLevel,
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}

stack = sym_stack(key)
if isempty(stack)
Expand Down
4 changes: 2 additions & 2 deletions src/analysis/debugging.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using StateSelection
using StateSelection.BipartiteGraphs

function build_var_names(names::OrderedDict{LevelKey, NameLevel}, var_to_diff)
function build_var_names(names::OrderedDict, var_to_diff)
var_names = OrderedDict{Int,String}()
build_var_names!(var_names, names, var_to_diff)
return var_names
end
function build_var_names!(var_names, names::OrderedDict{LevelKey, NameLevel}, var_to_diff, prefix=String[])
function build_var_names!(var_names, names::OrderedDict, var_to_diff, prefix=String[])
for name in keys(names)
name_path = join(vcat(prefix..., name), ".")
level = names[name]
Expand Down
26 changes: 13 additions & 13 deletions src/analysis/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,13 @@ end
Diffractor.disable_forward(interp::DAEInterpreter) = CC.NativeInterpreter(interp.world)

function CC.InferenceParams(::DAEInterpreter)
return CC.InferenceParams(;
unoptimize_throw_blocks=false,
assume_bindings_static=true,
ignore_recursion_hardlimit=true)
args = (;
assume_bindings_static=true,
ignore_recursion_hardlimit=true)
if VERSION < v"1.12.0-DEV.1017"
args = (; unoptimize_throw_blocks=false, args...)
end
return CC.InferenceParams(; args...)
end
function CC.OptimizationParams(::DAEInterpreter)
opt_params = CC.OptimizationParams(;
Expand Down Expand Up @@ -680,7 +683,7 @@ function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, t
eq_mapping[idnum(template)] = idnum(arg)
elseif CC.is_const_argtype(template)
#@Core.Compiler.show (arg, template)
@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
#@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
elseif isa(template, PartialScope)
id = idnum(template)
(id > length(applied_scopes)) && resize!(applied_scopes, id)
Expand Down Expand Up @@ -919,7 +922,7 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr
argtypes = CC.collect_argtypes(interp, stmt.args, nothing, irsv)[2:end]
callee_result = dae_result_for_inst(interp, inst)
callee_result === nothing && return RT(nothing, (false, false))
if isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
return RT(nothing, (false, false))
end
mapping = CalleeMapping(CC.optimizer_lattice(interp), argtypes, callee_result)
Expand Down Expand Up @@ -1030,14 +1033,11 @@ end
# -----

function typeinf_dae(@nospecialize(tt), world::UInt=get_world_counter();
method_table::Union{Nothing,MethodTable} = nothing,
ipo_analysis_mode::Bool = false)
interp = DAEInterpreter(world; method_table, ipo_analysis_mode)
match = Base._which(tt;
method_table=CC.method_table(interp),
world=get_inference_world(interp),
raise=false)
match === nothing && single_match_error(tt)
interp = DAEInterpreter(world; ipo_analysis_mode)
match = Base._methods_by_ftype(tt, 1, world)
isempty(match) && single_match_error(tt)
match = only(match)
mi = CC.specialize_method(match)
ci = CC.typeinf_ext(interp, mi, Core.Compiler.SOURCE_MODE_ABI)
return interp, ci
Expand Down
5 changes: 4 additions & 1 deletion src/analysis/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ CC.widenconst(::PartialScope) = Scope
CC.widenconst(pkv::PartialKeyValue) = widenconst(pkv.typ)
CC.:⊑(inc::Incidence, inc2) = CC.:⊑(inc2, Float64) && !isa(inc2, Const)

function CC._uniontypes(x::Incidence, ts)
function CC._uniontypes(x::Incidence, ts::Vector{Any})
u = x.typ
if isa(u, Union)
CC.push!(ts, is_non_incidence_type(u.a) ? u.a : Incidence(u.a, x.row, x.eps))
Expand Down Expand Up @@ -462,6 +462,9 @@ function CC._getfield_tfunc(🥬::DAELattice, @nospecialize(s00), @nospecialize(
return Union{}
end
rt = CC._getfield_tfunc(CC.widenlattice(🥬), s00.typ, name, setfield)
if rt == Union{}
return Union{}
end
if isempty(s00)
return Incidence(rt)
end
Expand Down
7 changes: 4 additions & 3 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ struct NameLevel
obs::Union{Nothing, Int}
eq::Union{Nothing, Int}
eps::Union{Nothing, Int}
children::Union{Nothing, OrderedDict{LevelKey, NameLevel}}
# TODO: This should be an OrderedIdDict
children::Union{Nothing, OrderedDict{Any, NameLevel}}
end
NameLevel() =
NameLevel(nothing, nothing, nothing, nothing, nothing)
NameLevel(children::OrderedDict{LevelKey, NameLevel}) =
NameLevel(children::OrderedDict{Any, NameLevel}) =
NameLevel(nothing, nothing, nothing, nothing, children)

struct UnsupportedIRException <: Exception
Expand Down Expand Up @@ -77,7 +78,7 @@ struct DAEIPOResult
total_incidence::Vector{Any}
eq_kind::Vector{VarEqKind}
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{SSAValue, Int}}}}
names::OrderedDict{LevelKey, NameLevel}
names::OrderedDict{Any, NameLevel} # TODO: OrderedIdDict
nobserved::Int
neps::Int
ic_nzc::Int
Expand Down
Loading
Loading