Skip to content

Commit 6ed4ee5

Browse files
authored
Merge pull request #2249 from SciML/myb/syscompose
Fix `systems` keyword argument handling in `System`
2 parents 265f9ec + 0521114 commit 6ed4ee5

File tree

7 files changed

+46
-77
lines changed

7 files changed

+46
-77
lines changed

src/systems/abstractsystem.jl

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -472,15 +472,18 @@ function namespace_defaults(sys)
472472
for (k, v) in pairs(defs))
473473
end
474474

475-
function namespace_equations(sys::AbstractSystem)
475+
function namespace_equations(sys::AbstractSystem, ivs = independent_variables(sys))
476476
eqs = equations(sys)
477477
isempty(eqs) && return Equation[]
478-
map(eq -> namespace_equation(eq, sys), eqs)
478+
map(eq -> namespace_equation(eq, sys; ivs), eqs)
479479
end
480480

481-
function namespace_equation(eq::Equation, sys, n = nameof(sys))
482-
_lhs = namespace_expr(eq.lhs, sys, n)
483-
_rhs = namespace_expr(eq.rhs, sys, n)
481+
function namespace_equation(eq::Equation,
482+
sys,
483+
n = nameof(sys);
484+
ivs = independent_variables(sys))
485+
_lhs = namespace_expr(eq.lhs, sys, n; ivs)
486+
_rhs = namespace_expr(eq.rhs, sys, n; ivs)
484487
_lhs ~ _rhs
485488
end
486489

@@ -490,30 +493,29 @@ function namespace_assignment(eq::Assignment, sys)
490493
Assignment(_lhs, _rhs)
491494
end
492495

493-
function namespace_expr(O, sys, n = nameof(sys))
494-
ivs = independent_variables(sys)
496+
function namespace_expr(O, sys, n = nameof(sys); ivs = independent_variables(sys))
495497
O = unwrap(O)
496498
if any(isequal(O), ivs)
497499
return O
498500
elseif istree(O)
499501
T = typeof(O)
500502
renamed = let sys = sys, n = n, T = T
501-
map(a -> namespace_expr(a, sys, n)::Any, arguments(O))
503+
map(a -> namespace_expr(a, sys, n; ivs)::Any, arguments(O))
502504
end
503505
if isvariable(O)
504506
# Use renamespace so the scope is correct, and make sure to use the
505507
# metadata from the rescoped variable
506508
rescoped = renamespace(n, O)
507509
similarterm(O, operation(rescoped), renamed,
508-
metadata = metadata(rescoped))::T
510+
metadata = metadata(rescoped))::T
509511
else
510512
similarterm(O, operation(O), renamed, metadata = metadata(O))::T
511513
end
512514
elseif isvariable(O)
513515
renamespace(n, O)
514516
elseif O isa Array
515517
let sys = sys, n = n
516-
map(o -> namespace_expr(o, sys, n), O)
518+
map(o -> namespace_expr(o, sys, n; ivs), O)
517519
end
518520
else
519521
O
@@ -661,20 +663,6 @@ function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
661663
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))
662664
end
663665

664-
struct AbstractSysToExpr
665-
sys::AbstractSystem
666-
states::Vector
667-
end
668-
AbstractSysToExpr(sys) = AbstractSysToExpr(sys, states(sys))
669-
function (f::AbstractSysToExpr)(O)
670-
!istree(O) && return toexpr(O)
671-
any(isequal(O), f.states) && return nameof(operation(O)) # variables
672-
if issym(operation(O))
673-
return build_expr(:call, Any[nameof(operation(O)); f.(arguments(O))])
674-
end
675-
return build_expr(:call, Any[operation(O); f.(arguments(O))])
676-
end
677-
678666
###
679667
### System utils
680668
###

src/systems/jumps/jumpsystem.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ function generate_rate_function(js::JumpSystem, rate)
181181
end
182182
rf = build_function(rate, states(js), parameters(js),
183183
get_iv(js),
184-
conv = states_to_sym(states(js)),
185184
expression = Val{true})
186185
end
187186

src/systems/optimization/optimizationsystem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ function generate_gradient(sys::OptimizationSystem, vs = states(sys), ps = param
124124
grad = calculate_gradient(sys)
125125
pre = get_preprocess_constants(grad)
126126
return build_function(grad, vs, ps; postprocess_fbody = pre,
127-
conv = AbstractSysToExpr(sys), kwargs...)
127+
kwargs...)
128128
end
129129

130130
function calculate_hessian(sys::OptimizationSystem)
@@ -140,14 +140,14 @@ function generate_hessian(sys::OptimizationSystem, vs = states(sys), ps = parame
140140
end
141141
pre = get_preprocess_constants(hess)
142142
return build_function(hess, vs, ps; postprocess_fbody = pre,
143-
conv = AbstractSysToExpr(sys), kwargs...)
143+
kwargs...)
144144
end
145145

146146
function generate_function(sys::OptimizationSystem, vs = states(sys), ps = parameters(sys);
147147
kwargs...)
148148
eqs = subs_constants(objective(sys))
149149
return build_function(eqs, vs, ps;
150-
conv = AbstractSysToExpr(sys), kwargs...)
150+
kwargs...)
151151
end
152152

153153
function namespace_objective(sys::AbstractSystem)

src/systems/systems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
function System(eqs::AbstractVector{<:Equation}, iv = nothing, args...; name = nothing,
22
kw...)
3-
ODESystem(eqs, iv, args...; name, checks = false)
3+
ODESystem(eqs, iv, args...; name, kw..., checks = false)
44
end
55

66
"""

src/utils.jl

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,30 +50,6 @@ end
5050

5151
@deprecate substitute_expr!(expr, s) substitute(expr, s)
5252

53-
function states_to_sym(states::Set)
54-
function _states_to_sym(O)
55-
if O isa Equation
56-
Expr(:(=), _states_to_sym(O.lhs), _states_to_sym(O.rhs))
57-
elseif istree(O)
58-
op = operation(O)
59-
args = arguments(O)
60-
if issym(op)
61-
O in states && return tosymbol(O)
62-
# dependent variables
63-
return build_expr(:call, Any[nameof(op); _states_to_sym.(args)])
64-
else
65-
canonical, O = canonicalexpr(O)
66-
return canonical ? O : build_expr(:call, Any[op; _states_to_sym.(args)])
67-
end
68-
elseif O isa Num
69-
return _states_to_sym(value(O))
70-
else
71-
return toexpr(O)
72-
end
73-
end
74-
end
75-
states_to_sym(states) = states_to_sym(Set(states))
76-
7753
function todict(d)
7854
eltype(d) <: Pair || throw(ArgumentError("The variable-value mapping must be a Dict."))
7955
d isa Dict ? d : Dict(d)

test/dde.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,26 +83,30 @@ sys = structural_simplify(sys)
8383
prob_mtk = SDDEProblem(sys, [x(t) => 1.0 + t], tspan; constant_lags = (τ,));
8484
@test_nowarn sol_mtk = solve(prob_mtk, RKMil())
8585

86-
8786
@variables t
8887
D = Differential(t)
8988
@parameters x(..) a
9089

91-
function oscillator(;name, k=1.0, τ=0.01)
90+
function oscillator(; name, k = 1.0, τ = 0.01)
9291
@parameters k=k τ=τ
9392
@variables x(..)=0.1 y(t)=0.1 jcn(t)=0.0 delx(t)
9493
eqs = [D(x(t)) ~ y,
95-
D(y) ~ -k*x(t-τ)+jcn,
96-
delx ~ x(t-τ)]
97-
return System(eqs; name=name)
94+
D(y) ~ -k * x(t - τ) + jcn,
95+
delx ~ x(t - τ)]
96+
return System(eqs; name = name)
9897
end
9998

100-
@named osc1 = oscillator(k=1.0, τ=0.01)
101-
@named osc2 = oscillator(k=2.0, τ=0.04)
99+
systems = @named begin
100+
osc1 = oscillator(k = 1.0, τ = 0.01)
101+
osc2 = oscillator(k = 2.0, τ = 0.04)
102+
end
102103
eqs = [osc1.jcn ~ osc2.delx,
103-
osc2.jcn ~ osc1.delx]
104+
osc2.jcn ~ osc1.delx]
104105
@named coupledOsc = System(eqs, t)
105-
@named coupledOsc = compose(coupledOsc, [osc1, osc2])
106-
sys = structural_simplify(coupledOsc)
107-
@test length(equations(sys)) == 4
108-
@test length(states(sys)) == 4
106+
@named coupledOsc = compose(coupledOsc, systems)
107+
@named coupledOsc2 = System(eqs, t; systems)
108+
for coupledOsc in [coupledOsc, coupledOsc2]
109+
local sys = structural_simplify(coupledOsc)
110+
@test length(equations(sys)) == 4
111+
@test length(states(sys)) == 4
112+
end

test/optimizationsystem.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,18 +207,20 @@ end
207207
prob = OptimizationProblem(combinedsys, u0, p, grad = true, hess = true, cons_j = true,
208208
cons_h = true)
209209
@test prob.f.sys === combinedsys
210-
@test_broken SciMLBase.successful_retcode(solve(prob, Ipopt.Optimizer(); print_level = 0))
211-
#=
212-
@test sol.minimum < -1e5
213-
214-
prob = OptimizationProblem(sys2, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0],
215-
grad = true, hess = true, cons_j = true, cons_h = true)
216-
@test prob.f.sys === sys2
217-
sol = solve(prob, IPNewton())
218-
@test sol.minimum < 1.0
219-
sol = solve(prob, Ipopt.Optimizer(); print_level = 0)
220-
@test sol.minimum < 1.0
221-
=#
210+
@test_broken SciMLBase.successful_retcode(solve(prob,
211+
Ipopt.Optimizer();
212+
print_level = 0))
213+
#=
214+
@test sol.minimum < -1e5
215+
216+
prob = OptimizationProblem(sys2, [x => 0.0, y => 0.0], [a => 1.0, b => 100.0],
217+
grad = true, hess = true, cons_j = true, cons_h = true)
218+
@test prob.f.sys === sys2
219+
sol = solve(prob, IPNewton())
220+
@test sol.minimum < 1.0
221+
sol = solve(prob, Ipopt.Optimizer(); print_level = 0)
222+
@test sol.minimum < 1.0
223+
=#
222224
end
223225

224226
@testset "metadata" begin

0 commit comments

Comments
 (0)