Skip to content

Commit 3845403

Browse files
test: fix MTK remake tests, use SCCNonlinearProblem codegen
1 parent e78e0d5 commit 3845403

File tree

1 file changed

+26
-36
lines changed

1 file changed

+26
-36
lines changed

test/downstream/modelingtoolkit_remake.jl

Lines changed: 26 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using Optimization
66
using OptimizationOptimJL
77
using ForwardDiff
88
using SciMLStructures
9+
using Test
910

1011
probs = []
1112
syss = []
@@ -67,30 +68,21 @@ push!(probs, OptimizationProblem(optsys, u0, p))
6768
k = ShiftIndex(t)
6869
@mtkbuild discsys = DiscreteSystem(
6970
[x ~ x(k - 1) * ρ + y(k - 2), y ~ y(k - 1) * σ - z(k - 2), z ~ z(k - 1) * β + x(k - 2)],
70-
t)
71+
t; defaults = [x => 1.0, y => 1.0, z => 1.0])
7172
# Roundabout method to avoid having to specify values for previous timestep
72-
fn = DiscreteFunction(discsys)
73-
ps = ModelingToolkit.MTKParameters(discsys, p)
74-
discu0 = Dict([u0..., x(k - 1) => 0.0, y(k - 1) => 0.0, z(k - 1) => 0.0])
73+
discprob = DiscreteProblem(discsys, [], (0, 10), p)
74+
for (var, v) in u0
75+
discprob[var] = v
76+
discprob[var(k-1)] = 0.0
77+
end
7578
push!(syss, discsys)
76-
push!(probs, DiscreteProblem(fn, getindex.((discu0,), unknowns(discsys)), (0, 10), ps))
77-
78-
# TODO: Rewrite this example when the MTK codegen is merged
79-
@named sys1 = NonlinearSystem(
80-
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
81-
sys1 = complete(sys1)
82-
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
83-
sys2 = complete(sys2)
84-
@named fullsys = NonlinearSystem(
79+
push!(probs, discprob)
80+
81+
@mtkbuild sys = NonlinearSystem(
8582
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
8683
[x, y, z], [σ, β, ρ])
87-
fullsys = complete(fullsys)
88-
89-
prob1 = NonlinearProblem(sys1, u0, p)
90-
prob2 = NonlinearProblem(sys2, u0, prob1.p)
91-
sccprob = SCCNonlinearProblem(
92-
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
93-
push!(syss, fullsys)
84+
sccprob = SCCNonlinearProblem(sys, u0, p)
85+
push!(syss, sys)
9486
push!(probs, sccprob)
9587

9688
for (sys, prob) in zip(syss, probs)
@@ -273,7 +265,9 @@ end
273265
function SciMLBase.detect_cycles(
274266
::ModelingToolkit.AbstractSystem, varmap::Dict{Any, Any}, vars)
275267
for sym in vars
276-
if symbolic_type(ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)) !=
268+
newval = ModelingToolkit.fixpoint_sub(sym, varmap; maxiters = 10)
269+
vs = ModelingToolkit.vars(newval)
270+
if !isempty(vars) && any(in(Set(vars)), vs)
277271
NotSymbolic()
278272
return true
279273
end
@@ -296,15 +290,9 @@ end
296290
end
297291

298292
@testset "SCCNonlinearProblem" begin
299-
@named sys1 = NonlinearSystem(
300-
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2], [x, y], [σ, β, ρ])
301-
sys1 = complete(sys1)
302-
@named sys2 = NonlinearSystem([0 ~ z^2 - 4z + 4], [z], [])
303-
sys2 = complete(sys2)
304-
@named fullsys = NonlinearSystem(
293+
@mtkbuild fullsys = NonlinearSystem(
305294
[0 ~ x^3 * β + y^3 * ρ - σ, 0 ~ x^2 + 2x * y + y^2, 0 ~ z^2 - 4z + 4],
306295
[x, y, z], [σ, β, ρ])
307-
fullsys = complete(fullsys)
308296

309297
u0 = [x => 1.0,
310298
y => 0.0,
@@ -314,15 +302,17 @@ end
314302
ρ => 10.0,
315303
β => 8 / 3]
316304

317-
prob1 = NonlinearProblem(sys1, u0, p)
318-
prob2 = NonlinearProblem(sys2, u0, prob1.p)
319-
sccprob = SCCNonlinearProblem(
320-
[prob1, prob2], [Returns(nothing), Returns(nothing)], prob1.p, true; sys = fullsys)
305+
sccprob = SCCNonlinearProblem(fullsys, u0, p)
321306

322307
sccprob2 = remake(sccprob; u0 = 2ones(3))
323308
@test state_values(sccprob2) 2ones(3)
324-
@test sccprob2.probs[1].u0 2ones(2)
325-
@test sccprob2.probs[2].u0 2ones(1)
309+
prob1, prob2 = if length(state_values(sccprob2.probs[1])) == 1
310+
sccprob2.probs[2], sccprob2.probs[1]
311+
else
312+
sccprob2.probs[1], sccprob2.probs[2]
313+
end
314+
@test prob1.u0 2ones(2)
315+
@test prob2.u0 2ones(1)
326316
@test sccprob2.explicitfuns! !== missing
327317
@test sccprob2.f.sys !== missing
328318

@@ -333,9 +323,9 @@ end
333323
@test_throws ["parameters_alias", "SCCNonlinearProblem"] remake(
334324
sccprob; parameters_alias = false, p ==> 2.0])
335325

336-
newp = remake_buffer(sys1, prob1.p, [σ], [3.0])
326+
newp = remake_buffer(sccprob.f.sys, sccprob.p, [σ], [3.0])
337327
sccprob4 = remake(sccprob; parameters_alias = false, p = newp,
338-
probs = [remake(prob1; p ==> 3.0]), prob2])
328+
probs = [remake(sccprob.probs[1]; p ==> 3.0]), sccprob.probs[2]])
339329
@test !sccprob4.parameters_alias
340330
@test sccprob4.p !== sccprob4.probs[1].p
341331
@test sccprob4.p !== sccprob4.probs[2].p

0 commit comments

Comments
 (0)