Skip to content

Commit 951e128

Browse files
authored
Merge branch 'master' into speed
2 parents 4346094 + 709148e commit 951e128

File tree

8 files changed

+179
-47
lines changed

8 files changed

+179
-47
lines changed

src/systems/abstractsystem.jl

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -746,38 +746,64 @@ end
746746
abstract type SymScope end
747747

748748
struct LocalScope <: SymScope end
749-
function LocalScope(sym::Union{Num, Symbolic})
749+
function LocalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
750750
apply_to_variables(sym) do sym
751-
setmetadata(sym, SymScope, LocalScope())
751+
if istree(sym) && operation(sym) === getindex
752+
args = arguments(sym)
753+
a1 = setmetadata(args[1], SymScope, LocalScope())
754+
similarterm(sym, operation(sym), [a1, args[2:end]...])
755+
else
756+
setmetadata(sym, SymScope, LocalScope())
757+
end
752758
end
753759
end
754760

755761
struct ParentScope <: SymScope
756762
parent::SymScope
757763
end
758-
function ParentScope(sym::Union{Num, Symbolic})
764+
function ParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
759765
apply_to_variables(sym) do sym
760-
setmetadata(sym, SymScope,
761-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
766+
if istree(sym) && operation(sym) === getindex
767+
args = arguments(sym)
768+
a1 = setmetadata(args[1], SymScope,
769+
ParentScope(getmetadata(value(args[1]), SymScope, LocalScope())))
770+
similarterm(sym, operation(sym), [a1, args[2:end]...])
771+
else
772+
setmetadata(sym, SymScope,
773+
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
774+
end
762775
end
763776
end
764777

765778
struct DelayParentScope <: SymScope
766779
parent::SymScope
767780
N::Int
768781
end
769-
function DelayParentScope(sym::Union{Num, Symbolic}, N)
782+
function DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}, N)
770783
apply_to_variables(sym) do sym
771-
setmetadata(sym, SymScope,
772-
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
784+
if istree(sym) && operation(sym) == getindex
785+
args = arguments(sym)
786+
a1 = setmetadata(args[1], SymScope,
787+
DelayParentScope(getmetadata(value(args[1]), SymScope, LocalScope()), N))
788+
similarterm(sym, operation(sym), [a1, args[2:end]...])
789+
else
790+
setmetadata(sym, SymScope,
791+
DelayParentScope(getmetadata(value(sym), SymScope, LocalScope()), N))
792+
end
773793
end
774794
end
775-
DelayParentScope(sym::Union{Num, Symbolic}) = DelayParentScope(sym, 1)
795+
DelayParentScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}}) = DelayParentScope(sym, 1)
776796

777797
struct GlobalScope <: SymScope end
778-
function GlobalScope(sym::Union{Num, Symbolic})
798+
function GlobalScope(sym::Union{Num, Symbolic, Symbolics.Arr{Num}})
779799
apply_to_variables(sym) do sym
780-
setmetadata(sym, SymScope, GlobalScope())
800+
if istree(sym) && operation(sym) == getindex
801+
args = arguments(sym)
802+
a1 = setmetadata(args[1], SymScope, GlobalScope())
803+
similarterm(sym, operation(sym), [a1, args[2:end]...])
804+
else
805+
setmetadata(sym, SymScope, GlobalScope())
806+
end
781807
end
782808
end
783809

@@ -793,6 +819,11 @@ function renamespace(sys, x)
793819
return similarterm(x, operation(x),
794820
Any[renamespace(sys, only(arguments(x)))])::T
795821
end
822+
if istree(x) && operation(x) === getindex
823+
args = arguments(x)
824+
return similarterm(
825+
x, operation(x), vcat(renamespace(sys, args[1]), args[2:end]))::T
826+
end
796827
let scope = getmetadata(x, SymScope, LocalScope())
797828
if scope isa LocalScope
798829
rename(x, renamespace(getname(sys), getname(x)))::T
@@ -849,7 +880,8 @@ function namespace_assignment(eq::Assignment, sys)
849880
Assignment(_lhs, _rhs)
850881
end
851882

852-
function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys))
883+
function namespace_expr(
884+
O, sys, n = nameof(sys); ivs = independent_variables(sys))
853885
O = unwrap(O)
854886
if any(isequal(O), ivs)
855887
return O
@@ -1500,8 +1532,7 @@ function default_to_parentscope(v)
15001532
uv isa Symbolic || return v
15011533
apply_to_variables(v) do sym
15021534
if !hasmetadata(uv, SymScope)
1503-
setmetadata(sym, SymScope,
1504-
ParentScope(getmetadata(value(sym), SymScope, LocalScope())))
1535+
ParentScope(sym)
15051536
else
15061537
sym
15071538
end

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -876,31 +876,34 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
876876
parammap = Dict(unwrap.(parameters(sys)) .=> parammap)
877877
end
878878
end
879-
clockedparammap = Dict()
880-
defs = ModelingToolkit.get_defaults(sys)
881-
for v in ps
882-
v = unwrap(v)
883-
is_discrete_domain(v) || continue
884-
op = operation(v)
885-
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
886-
haskey(parammap, v)
887-
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
879+
880+
if has_discrete_subsystems(sys) && get_discrete_subsystems(sys) !== nothing
881+
clockedparammap = Dict()
882+
defs = ModelingToolkit.get_defaults(sys)
883+
for v in ps
884+
v = unwrap(v)
885+
is_discrete_domain(v) || continue
886+
op = operation(v)
887+
if !isa(op, Symbolics.Operator) && parammap != SciMLBase.NullParameters() &&
888+
haskey(parammap, v)
889+
error("Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v, provide the condition for $(Shift(iv, -1)(v)).")
890+
end
891+
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
892+
if parammap != SciMLBase.NullParameters() &&
893+
(val = get(parammap, shiftedv, nothing)) !== nothing
894+
clockedparammap[v] = val
895+
elseif op isa Shift
896+
root = arguments(v)[1]
897+
haskey(defs, root) || error("Initial condition for $v not provided.")
898+
clockedparammap[v] = defs[root]
899+
end
888900
end
889-
shiftedv = StructuralTransformations.simplify_shifts(Shift(iv, -1)(v))
890-
if parammap != SciMLBase.NullParameters() &&
891-
(val = get(parammap, shiftedv, nothing)) !== nothing
892-
clockedparammap[v] = val
893-
elseif op isa Shift
894-
root = arguments(v)[1]
895-
haskey(defs, root) || error("Initial condition for $v not provided.")
896-
clockedparammap[v] = defs[root]
901+
parammap = if parammap == SciMLBase.NullParameters()
902+
clockedparammap
903+
else
904+
merge(parammap, clockedparammap)
897905
end
898906
end
899-
parammap = if parammap == SciMLBase.NullParameters()
900-
clockedparammap
901-
else
902-
merge(parammap, clockedparammap)
903-
end
904907
# TODO: make it work with clocks
905908
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
906909
if sys isa ODESystem && build_initializeprob &&
@@ -931,7 +934,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
931934
if has_index_cache(sys) && get_index_cache(sys) !== nothing
932935
u0, defs = get_u0(sys, trueinit, parammap; symbolic_u0)
933936
check_eqs_u0(eqs, dvs, u0; kwargs...)
934-
p = MTKParameters(sys, parammap, trueinit)
937+
p = if parammap === nothing ||
938+
parammap == SciMLBase.NullParameters() && isempty(defs)
939+
nothing
940+
else
941+
MTKParameters(sys, parammap, trueinit)
942+
end
935943
else
936944
u0, p, defs = get_u0_p(sys,
937945
trueinit,
@@ -1592,7 +1600,6 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
15921600
if !iscomplete(sys)
15931601
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`")
15941602
end
1595-
15961603
if isempty(u0map) && get_initializesystem(sys) !== nothing
15971604
isys = get_initializesystem(sys)
15981605
elseif isempty(u0map) && get_initializesystem(sys) === nothing
@@ -1620,9 +1627,9 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16201627
@warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
16211628
end
16221629

1623-
parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1624-
[get_iv(sys) => t] :
1625-
merge(todict(parammap), Dict(get_iv(sys) => t))
1630+
parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ?
1631+
[get_iv(sys) => t] :
1632+
merge(todict(parammap), Dict(get_iv(sys) => t))
16261633

16271634
if neqs == nunknown
16281635
NonlinearProblem(isys, guesses, parammap)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ function generate_function(
241241
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
242242

243243
p = reorder_parameters(sys, value.(ps))
244-
@show p ps
245244
return build_function(rhss, value.(dvs), p...; postprocess_fbody = pre,
246245
states = sol_states, kwargs...)
247246
end
@@ -395,7 +394,6 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
395394
eqs = equations(sys)
396395
dvs = unknowns(sys)
397396
ps = full_parameters(sys)
398-
399397
if has_index_cache(sys) && get_index_cache(sys) !== nothing
400398
u0, defs = get_u0(sys, u0map, parammap)
401399
check_eqs_u0(eqs, dvs, u0; kwargs...)

src/systems/parameter_buffer.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,32 @@ function MTKParameters(
4343
p = merge(defs, p)
4444
p = merge(Dict(unwrap(k) => v for (k, v) in p),
4545
Dict(default_toterm(unwrap(k)) => v for (k, v) in p))
46-
p = Dict(k => fixpoint_sub(v, p) for (k, v) in p)
46+
p = Dict(unwrap(k) => fixpoint_sub(v, p) for (k, v) in p)
4747
for (sym, _) in p
4848
if istree(sym) && operation(sym) === getindex &&
4949
first(arguments(sym)) in all_ps
5050
error("Scalarized parameter values ($sym) are not supported. Instead of `[p[1] => 1.0, p[2] => 2.0]` use `[p => [1.0, 2.0]]`")
5151
end
5252
end
5353

54+
missing_params = Set()
55+
for idxmap in (ic.tunable_idx, ic.discrete_idx, ic.constant_idx, ic.nonnumeric_idx)
56+
for sym in keys(idxmap)
57+
sym isa Symbol && continue
58+
haskey(p, sym) && continue
59+
hasname(sym) && haskey(p, getname(sym)) && continue
60+
ttsym = default_toterm(sym)
61+
haskey(p, ttsym) && continue
62+
hasname(ttsym) && haskey(p, getname(ttsym)) && continue
63+
64+
istree(sym) && operation(sym) === getindex && haskey(p, arguments(sym)[1]) &&
65+
continue
66+
push!(missing_params, sym)
67+
end
68+
end
69+
70+
isempty(missing_params) || throw(MissingVariablesError(collect(missing_params)))
71+
5472
tunable_buffer = Tuple(Vector{temp.type}(undef, temp.length)
5573
for temp in ic.tunable_buffer_sizes)
5674
disc_buffer = Tuple(Vector{temp.type}(undef, temp.length)

test/input_output_handling.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ end
2020
@named sys = ODESystem([D(x) ~ -x + u], t) # both u and x are unbound
2121
@named sys1 = ODESystem([D(x) ~ -x + v[1] + v[2]], t) # both v and x are unbound
2222
@named sys2 = ODESystem([D(x) ~ -sys.x], t, systems = [sys]) # this binds sys.x in the context of sys2, sys2.x is still unbound
23-
@named sys21 = ODESystem([D(x) ~ -sys.x], t, systems = [sys1]) # this binds sys.x in the context of sys2, sys2.x is still unbound
23+
@named sys21 = ODESystem([D(x) ~ -sys1.x], t, systems = [sys1]) # this binds sys.x in the context of sys2, sys2.x is still unbound
2424
@named sys3 = ODESystem([D(x) ~ -sys.x + sys.u], t, systems = [sys]) # This binds both sys.x and sys.u
25-
@named sys31 = ODESystem([D(x) ~ -sys.x + sys1.v[1]], t, systems = [sys1]) # This binds both sys.x and sys1.v[1]
25+
@named sys31 = ODESystem([D(x) ~ -sys1.x + sys1.v[1]], t, systems = [sys1]) # This binds both sys.x and sys1.v[1]
2626

2727
@named sys4 = ODESystem([D(x) ~ -sys.x, u ~ sys.u], t, systems = [sys]) # This binds both sys.x and sys3.u, this system is one layer deeper than the previous. u is directly forwarded to sys.u, and in this case sys.u is bound while u is not
2828

@@ -43,7 +43,7 @@ end
4343
@test is_bound(sys2, sys.x)
4444
@test !is_bound(sys2, sys.u)
4545
@test !is_bound(sys2, sys2.sys.u)
46-
@test is_bound(sys21, sys.x)
46+
@test is_bound(sys21, sys1.x)
4747
@test !is_bound(sys21, sys1.v[1])
4848
@test !is_bound(sys21, sys1.v[2])
4949
@test is_bound(sys31, sys1.v[1])

test/mtkparameters.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,20 @@ newps = remake_buffer(sys, ps, Dict(p => 1.0f0))
123123
@test newps.tunable[1] isa Vector{Float32}
124124
@test newps.tunable[1] == [1.0f0, 2.0f0, 3.0f0]
125125

126+
# Issue#2624
127+
@parameters p d
128+
@variables X(t)
129+
eqs = [D(X) ~ p - d * X]
130+
@mtkbuild sys = ODESystem(eqs, t)
131+
132+
u0 = [X => 1.0]
133+
tspan = (0.0, 100.0)
134+
ps = [p => 1.0] # Value for `d` is missing
135+
136+
@test_throws ModelingToolkit.MissingVariablesError ODEProblem(sys, u0, tspan, ps)
137+
@test_nowarn ODEProblem(sys, u0, tspan, [ps..., d => 1.0])
138+
139+
126140
# JET tests
127141

128142
# scalar parameters only

test/odesystem.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1066,3 +1066,57 @@ prob = SteadyStateProblem(sys, u0, p)
10661066
@test prob isa SteadyStateProblem
10671067
prob = SteadyStateProblem(ODEProblem(sys, u0, (0.0, 10.0), p))
10681068
@test prob isa SteadyStateProblem
1069+
1070+
# Issue#2344
1071+
using ModelingToolkitStandardLibrary.Blocks
1072+
1073+
function FML2(; name)
1074+
@parameters begin
1075+
k2[1:1] = [1.0]
1076+
end
1077+
systems = @named begin
1078+
constant = Constant(k = k2[1])
1079+
end
1080+
@variables begin
1081+
x(t) = 0
1082+
end
1083+
eqs = [
1084+
D(x) ~ constant.output.u + k2[1]
1085+
]
1086+
ODESystem(eqs, t; systems, name)
1087+
end
1088+
1089+
@mtkbuild model = FML2()
1090+
1091+
@test isequal(ModelingToolkit.defaults(model)[model.constant.k], model.k2[1])
1092+
@test_nowarn ODEProblem(model, [], (0.0, 10.0))
1093+
1094+
# Issue#2477
1095+
function RealExpression(; name, y)
1096+
vars = @variables begin
1097+
u(t)
1098+
end
1099+
eqns = [
1100+
u ~ y
1101+
]
1102+
sys = ODESystem(eqns, t, vars, []; name)
1103+
end
1104+
1105+
function RealExpressionSystem(; name)
1106+
vars = @variables begin
1107+
x(t)
1108+
z(t)[1:1]
1109+
end # doing a collect on z doesn't work either.
1110+
@named e1 = RealExpression(y = x) # This works perfectly.
1111+
@named e2 = RealExpression(y = z[1]) # This bugs. However, `full_equations(e2)` works as expected.
1112+
systems = [e1, e2]
1113+
ODESystem(Equation[], t, Iterators.flatten(vars), []; systems, name)
1114+
end
1115+
1116+
@named sys = RealExpressionSystem()
1117+
sys = complete(sys)
1118+
@test Set(equations(sys)) == Set([sys.e1.u ~ sys.x, sys.e2.u ~ sys.z[1]])
1119+
tearing_state = TearingState(expand_connections(sys))
1120+
ts_vars = tearing_state.fullvars
1121+
orig_vars = unknowns(sys)
1122+
@test isempty(setdiff(ts_vars, orig_vars))

test/variable_scope.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,13 @@ ps = ModelingToolkit.getname.(parameters(level3))
7373
@test isequal(ps[4], :level2₊level0₊d)
7474
@test isequal(ps[5], :level1₊level0₊e)
7575
@test isequal(ps[6], :f)
76+
77+
# Issue@2252
78+
# Tests from PR#2354
79+
@parameters xx[1:2]
80+
arr_p = [ParentScope(xx[1]), xx[2]]
81+
arr0 = ODESystem(Equation[], t, [], arr_p; name = :arr0)
82+
arr1 = ODESystem(Equation[], t, [], []; name = :arr1) arr0
83+
arr_ps = ModelingToolkit.getname.(parameters(arr1))
84+
@test isequal(arr_ps[1], Symbol("xx"))
85+
@test isequal(arr_ps[2], Symbol("arr0₊xx"))

0 commit comments

Comments
 (0)