Skip to content

Commit 50c1e16

Browse files
committed
Merge branch 'master' into myb/alias
2 parents acd1df0 + 5f9d51b commit 50c1e16

File tree

8 files changed

+147
-34
lines changed

8 files changed

+147
-34
lines changed

src/systems/abstractsystem.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,17 @@ function Base.getproperty(sys::AbstractSystem, name::Symbol)
222222
throw(error("Variable $name does not exist"))
223223
end
224224

225+
function Base.setproperty!(sys::AbstractSystem, prop::Symbol, val)
226+
if (pa = Sym{Parameter{Real}}(prop); pa in parameters(sys))
227+
sys.default_p[pa] = value(val)
228+
# comparing a Sym returns a symbolic expression
229+
elseif (st = Sym{Real}(prop); any(s->s.name==st.name, states(sys)))
230+
sys.default_u0[st] = value(val)
231+
else
232+
setfield!(sys, prop, val)
233+
end
234+
end
235+
225236
function renamespace(namespace, x)
226237
if x isa Num
227238
renamespace(namespace, value(x))
@@ -239,12 +250,12 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
239250

240251
function namespace_default_u0(sys)
241252
d_u0 = default_u0(sys)
242-
Dict(states(sys, k) => d_u0[k] for k in keys(d_u0))
253+
Dict(states(sys, k) => namespace_expr(d_u0[k], nameof(sys), independent_variable(sys)) for k in keys(d_u0))
243254
end
244255

245256
function namespace_default_p(sys)
246257
d_p = default_p(sys)
247-
Dict(parameters(sys, k) => d_p[k] for k in keys(d_p))
258+
Dict(parameters(sys, k) => namespace_expr(d_p[k], nameof(sys), independent_variable(sys)) for k in keys(d_p))
248259
end
249260

250261
function namespace_equations(sys::AbstractSystem)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -192,39 +192,53 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
192192
sparse = false, simplify=false,
193193
kwargs...) where {iip}
194194

195-
idx = iip ? 2 : 1
196-
f = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)[idx]
195+
f_oop, f_iip = generate_function(sys, dvs, ps; expression=Val{true}, kwargs...)
196+
fsym = gensym(:f)
197+
_f = quote
198+
$fsym(u,p,t) = $f_oop(u,p,t)
199+
$fsym(du,u,p,t) = $f_iip(du,u,p,t)
200+
end
201+
202+
tgradsym = gensym(:tgrad)
197203
if tgrad
198-
_tgrad = generate_tgrad(sys, dvs, ps;
204+
tgrad_oop, tgrad_iip = generate_tgrad(sys, dvs, ps;
199205
simplify=simplify,
200-
expression=Val{true}, kwargs...)[idx]
206+
expression=Val{true}, kwargs...)
207+
_tgrad = quote
208+
$tgradsym(u,p,t) = $tgrad_oop(u,p,t)
209+
$tgradsym(J,u,p,t) = $tgrad_iip(J,u,p,t)
210+
end
201211
else
202-
_tgrad = :nothing
212+
_tgrad = :($tgradsym = nothing)
203213
end
204214

215+
jacsym = gensym(:jac)
205216
if jac
206-
_jac = generate_jacobian(sys, dvs, ps;
217+
jac_oop,jac_iip = generate_jacobian(sys, dvs, ps;
207218
sparse=sparse, simplify=simplify,
208-
expression=Val{true}, kwargs...)[idx]
219+
expression=Val{true}, kwargs...)
220+
_jac = quote
221+
$jacsym(u,p,t) = $jac_oop(u,p,t)
222+
$jacsym(J,u,p,t) = $jac_iip(J,u,p,t)
223+
end
209224
else
210-
_jac = :nothing
225+
_jac = :($jacsym = nothing)
211226
end
212227

213228
M = calculate_massmatrix(sys)
214229

215230
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
216231

217232
jp_expr = sparse ? :(similar($(get_jac(sys)[]),Float64)) : :nothing
218-
219233
ex = quote
220-
f = $f
221-
tgrad = $_tgrad
222-
jac = $_jac
234+
$_f
235+
$_tgrad
236+
$_jac
223237
M = $_M
224238
ODEFunction{$iip}(
225-
f,
226-
jac = jac,
227-
tgrad = tgrad,
239+
$fsym,
240+
jac = $jacsym,
241+
tgrad = $tgradsym,
228242
mass_matrix = M,
229243
jac_prototype = $jp_expr,
230244
syms = $(Symbol.(states(sys))),
@@ -252,9 +266,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem,u0map,parammap;
252266
u0 = nothing
253267
end
254268

269+
defp = default_p(sys)
255270
if !(parammap isa DiffEqBase.NullParameters)
256271
parammap′ = lower_mapnames(parammap)
257-
p = varmap_to_vars(parammap′,ps; defaults=default_p(sys))
272+
p = varmap_to_vars(parammap′,ps; defaults=defp)
273+
elseif !isempty(defp)
274+
p = varmap_to_vars(Dict(),ps; defaults=defp)
258275
else
259276
p = ps
260277
end

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,3 +224,5 @@ function flatten(sys::ODESystem)
224224
)
225225
end
226226
end
227+
228+
ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,13 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem,u0map,paramm
211211
ps = parameters(sys)
212212
u0map′ = lower_mapnames(u0map)
213213
u0 = varmap_to_vars(u0map′,dvs; defaults=default_u0(sys))
214+
defp = default_p(sys)
214215

215216
if !(parammap isa DiffEqBase.NullParameters)
216217
parammap′ = lower_mapnames(parammap)
217-
p = varmap_to_vars(parammap′,ps; defaults=default_p(sys))
218+
p = varmap_to_vars(parammap′,ps; defaults=defp)
219+
elseif !isempty(defp)
220+
p = varmap_to_vars(Dict(),ps; defaults=defp)
218221
else
219222
p = ps
220223
end

src/variables.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,10 @@ applicable.
213213
function varmap_to_vars(varmap::Dict, varlist; defaults=Dict())
214214
varmap = merge(defaults, varmap) # prefers the `varmap`
215215
varmap = Dict(value(k)=>value(varmap[k]) for k in keys(varmap))
216+
# resolve symbolic parameter expressions
217+
for (p, v) in pairs(varmap)
218+
varmap[p] = fixpoint_sub(v, varmap)
219+
end
216220
T′ = eltype(values(varmap))
217221
T = Base.isconcretetype(T′) ? T′ : Base.promote_typeof(values(varmap)...)
218222
out = Vector{T}(undef, length(varlist))

test/odesystem.jl

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,38 @@ generate_function(de, [x,y,z], [σ,ρ,β])
3535
jac_expr = generate_jacobian(de)
3636
jac = calculate_jacobian(de)
3737
jacfun = eval(jac_expr[2])
38-
# iip
39-
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
40-
du = zeros(3)
41-
u = collect(1:3)
42-
p = collect(4:6)
43-
f(du, u, p, 0.1)
44-
@test du == [4, 0, -16]
45-
J = zeros(3, 3)
46-
jacfun(J, u, p, t)
47-
# oop
48-
f = ODEFunction(de, [x,y,z], [σ,ρ,β])
49-
du = @SArray zeros(3)
50-
u = SVector(1:3...)
51-
p = SVector(4:6...)
52-
@test f(u, p, 0.1) === @SArray [4, 0, -16]
38+
39+
for f in [
40+
ODEFunction(de, [x,y,z], [σ,ρ,β], tgrad = true, jac = true),
41+
eval(ODEFunctionExpr(de, [x,y,z], [σ,ρ,β], tgrad = true, jac = true)),
42+
]
43+
# iip
44+
du = zeros(3)
45+
u = collect(1:3)
46+
p = collect(4:6)
47+
f.f(du, u, p, 0.1)
48+
@test du == [4, 0, -16]
49+
50+
# oop
51+
du = @SArray zeros(3)
52+
u = SVector(1:3...)
53+
p = SVector(4:6...)
54+
@test f.f(u, p, 0.1) === @SArray [4, 0, -16]
55+
56+
# iip vs oop
57+
du = zeros(3)
58+
g = similar(du)
59+
J = zeros(3, 3)
60+
u = collect(1:3)
61+
p = collect(4:6)
62+
f.f(du, u, p, 0.1)
63+
@test du == f(u, p, 0.1)
64+
f.tgrad(g, u, p, t)
65+
@test g == f.tgrad(u, p, t)
66+
f.jac(J, u, p, t)
67+
@test J == f.jac(u, p, t)
68+
end
69+
5370

5471
eqs = [D(x) ~ σ*(y-x),
5572
D(y) ~ x*-z)-y*t,
@@ -59,6 +76,8 @@ ModelingToolkit.calculate_tgrad(de)
5976

6077
tgrad_oop, tgrad_iip = eval.(ModelingToolkit.generate_tgrad(de))
6178

79+
u = SVector(1:3...)
80+
p = SVector(4:6...)
6281
@test tgrad_oop(u,p,t) == [0.0,-u[2],0.0]
6382
du = zeros(3)
6483
tgrad_iip(du,u,p,t)
@@ -254,3 +273,11 @@ sys = ODESystem(eqs, t)
254273
@test isequal(ModelingToolkit.get_iv(sys), t)
255274
@test isequal(states(sys), [x1, x2])
256275
@test isempty(parameters(sys))
276+
277+
# one equation ODESystem test
278+
@parameters t r
279+
@variables x(t)
280+
D = Differential(t)
281+
eq = D(x) ~ r*x
282+
ode = ODESystem(eq)
283+
@test equations(ode) == [eq]

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using SafeTestsets, Test
22

3+
@safetestset "Symbolic parameters test" begin include("symbolic_parameters.jl") end
34
@safetestset "Parsing Test" begin include("variable_parsing.jl") end
45
@safetestset "Differentiation Test" begin include("derivatives.jl") end
56
@safetestset "Simplify Test" begin include("simplify.jl") end

test/symbolic_parameters.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
using ModelingToolkit
2+
using NonlinearSolve
3+
using Test
4+
5+
@variables x y z u
6+
@parameters σ ρ β
7+
8+
eqs = [0 ~ σ*(y-x),
9+
0 ~ x*-z)-y,
10+
0 ~ x*y - β*z]
11+
12+
par = [
13+
σ => 1,
14+
ρ => 0.1+σ,
15+
β => ρ*1.1
16+
]
17+
u0 = Pair{Num, Any}[
18+
x => u,
19+
y => u,
20+
z => u-0.1,
21+
]
22+
ns = NonlinearSystem(eqs, [x,y,z],[σ,ρ,β], name=:ns, default_p=par, default_u0=u0)
23+
ns.y = u*1.1
24+
ModelingToolkit.default_p(ns)
25+
resolved = ModelingToolkit.varmap_to_vars(Dict(), parameters(ns), defaults=ModelingToolkit.default_p(ns))
26+
@test resolved == [1, 0.1+1, (0.1+1)*1.1]
27+
28+
prob = NonlinearProblem(ns, [u=>1.0], Pair[])
29+
@test prob.u0 == [1.0, 1.1, 0.9]
30+
@show sol = solve(prob,NewtonRaphson())
31+
32+
@variables a
33+
@parameters b
34+
top = NonlinearSystem([0 ~ -a + ns.x+b], [a], [b], systems=[ns], name=:top)
35+
top.b = ns.σ*0.5
36+
top.ns.x = u*0.5
37+
38+
res = ModelingToolkit.varmap_to_vars(Dict(), parameters(top), defaults=ModelingToolkit.default_p(top))
39+
@test res == [0.5, 1, 0.1+1, (0.1+1)*1.1]
40+
41+
prob = NonlinearProblem(top, [states(ns, u)=>1.0, a=>1.0], Pair[])
42+
@test prob.u0 == [1.0, 0.5, 1.1, 0.9]
43+
@show sol = solve(prob,NewtonRaphson())
44+
45+
# test NullParameters+defaults
46+
prob = NonlinearProblem(top, [states(ns, u)=>1.0, a=>1.0])
47+
@test prob.u0 == [1.0, 0.5, 1.1, 0.9]
48+
@show sol = solve(prob,NewtonRaphson())

0 commit comments

Comments
 (0)