Skip to content

Commit ba2bfff

Browse files
test: test initialization on static array problems
1 parent 57e4778 commit ba2bfff

File tree

1 file changed

+53
-27
lines changed

1 file changed

+53
-27
lines changed

test/initializationsystem.jl

Lines changed: 53 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
22
using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq, JumpProcesses
3-
using ForwardDiff
3+
using ForwardDiff, StaticArrays
44
using SymbolicIndexingInterface, SciMLStructures
55
using SciMLStructures: Tunable
66
using ModelingToolkit: t_nounits as t, D_nounits as D, observed
@@ -594,22 +594,36 @@ end
594594
@parameters p q
595595
@brownian a b
596596
x = _x(t)
597-
597+
sarray_ctor = splat(SVector)
598598
# `System` constructor creates appropriate type with mtkbuild
599599
# `Problem` and `alg` create the problem to test and allow calling `init` with
600600
# the correct solver.
601601
# `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence
602602
# the system type (brownian vars to turn it into an SDE).
603-
@testset "$Problem with $(SciMLBase.parameterless_type(alg))" for (System, Problem, alg, rhss) in [
604-
(ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)),
605-
(ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]),
606-
(ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]),
607-
(ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])
608-
]
603+
@testset "$Problem with $(SciMLBase.parameterless_type(alg)) and $ctor ctor" for ((System, Problem, alg, rhss), (ctor, expectedT)) in Iterators.product(
604+
[
605+
(ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)),
606+
(ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]),
607+
(ModelingToolkit.System, DDEProblem,
608+
MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]),
609+
(ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])
610+
],
611+
[(identity, Any), (sarray_ctor, SVector)])
612+
u0_constructor = p_constructor = ctor
613+
if ctor !== identity
614+
Problem = Problem{false}
615+
end
609616
function test_parameter(prob, sym, val)
610617
if prob.u0 !== nothing
618+
@test prob.u0 isa expectedT
611619
@test init(prob, alg).ps[sym] val
612620
end
621+
@test prob.p.tunable isa expectedT
622+
initprob = prob.f.initialization_data.initializeprob
623+
if state_values(initprob) !== nothing
624+
@test state_values(initprob) isa expectedT
625+
end
626+
@test parameter_values(initprob).tunable isa expectedT
613627
@test solve(prob, alg).ps[sym] val
614628
end
615629
function test_initializesystem(sys, u0map, pmap, p, equation)
@@ -626,72 +640,72 @@ end
626640
@mtkbuild sys = System(
627641
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => missing], guesses = [p => 1.0])
628642
pmap[p] = 2q
629-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
643+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
630644
test_parameter(prob, p, 2.0)
631645
prob2 = remake(prob; u0 = u0map, p = pmap)
632-
prob2.ps[p] = 0.0
646+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
633647
test_parameter(prob2, p, 2.0)
634648
# `missing` default, provided guess
635649
@mtkbuild sys = System(
636650
[D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [p => 0.0])
637-
prob = Problem(sys, u0map, (0.0, 1.0))
651+
prob = Problem(sys, u0map, (0.0, 1.0); u0_constructor, p_constructor)
638652
test_parameter(prob, p, 2.0)
639653
test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y)
640654
prob2 = remake(prob; u0 = u0map)
641-
prob2.ps[p] = 0.0
655+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
642656
test_parameter(prob2, p, 2.0)
643657

644658
# `missing` to Problem, equation from default
645659
@mtkbuild sys = System(
646660
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
647661
pmap[p] = missing
648-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
662+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
649663
test_parameter(prob, p, 2.0)
650664
test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p)
651665
prob2 = remake(prob; u0 = u0map, p = pmap)
652-
prob2.ps[p] = 0.0
666+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
653667
test_parameter(prob2, p, 2.0)
654668
# `missing` to Problem, provided guess
655669
@mtkbuild sys = System(
656670
[D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0])
657-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
671+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
658672
test_parameter(prob, p, 2.0)
659673
test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p)
660674
prob2 = remake(prob; u0 = u0map, p = pmap)
661-
prob2.ps[p] = 0.0
675+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
662676
test_parameter(prob2, p, 2.0)
663677

664678
# No `missing`, default and guess
665679
@mtkbuild sys = System(
666680
[D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 0.0])
667681
delete!(pmap, p)
668-
prob = Problem(sys, u0map, (0.0, 1.0), pmap)
682+
prob = Problem(sys, u0map, (0.0, 1.0), pmap; u0_constructor, p_constructor)
669683
test_parameter(prob, p, 2.0)
670684
test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p)
671685
prob2 = remake(prob; u0 = u0map, p = pmap)
672-
prob2.ps[p] = 0.0
686+
prob2 = remake(prob2; p = setp_oop(prob2, p)(prob2, 0.0))
673687
test_parameter(prob2, p, 2.0)
674688

675689
# Default overridden by Problem, guess provided
676690
@mtkbuild sys = System(
677691
[D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
678692
_pmap = merge(pmap, Dict(p => q))
679-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
693+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
680694
test_parameter(prob, p, _pmap[q])
681695
test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p)
682696
# Problem dependent value with guess, no `missing`
683697
@mtkbuild sys = System(
684698
[D(x) ~ y * q + p + rhss[1], D(y) ~ x * p + q + rhss[2]], t; guesses = [p => 0.0])
685699
_pmap = merge(pmap, Dict(p => 3q))
686-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
700+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
687701
test_parameter(prob, p, 3pmap[q])
688702

689703
# Should not be solved for:
690704
# Override dependent default with direct value
691705
@mtkbuild sys = System(
692706
[D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0])
693707
_pmap = merge(pmap, Dict(p => 1.0))
694-
prob = Problem(sys, u0map, (0.0, 1.0), _pmap)
708+
prob = Problem(sys, u0map, (0.0, 1.0), _pmap; u0_constructor, p_constructor)
695709
@test prob.ps[p] 1.0
696710
initsys = prob.f.initialization_data.initializeprob.f.sys
697711
@test is_parameter(initsys, p)
@@ -700,7 +714,7 @@ end
700714
@parameters r::Int s::Int
701715
@mtkbuild sys = System(
702716
[D(x) ~ s * x + rhss[1], D(y) ~ y * r + rhss[2]], t; defaults = [s => 2r], guesses = [s => 1.0])
703-
prob = Problem(sys, u0map, (0.0, 1.0), [r => 1])
717+
prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]; u0_constructor, p_constructor)
704718
@test prob.ps[r] == 1
705719
@test prob.ps[s] == 2
706720
initsys = prob.f.initialization_data.initializeprob.f.sys
@@ -714,7 +728,7 @@ end
714728

715729
# Unsatisfiable initialization
716730
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0),
717-
[p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3])
731+
[p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3], u0_constructor, p_constructor)
718732
@test prob.f.initialization_data !== nothing
719733
@test solve(prob, alg).retcode == ReturnCode.InitialFailure
720734
cache = init(prob, alg)
@@ -791,8 +805,17 @@ end
791805

792806
prob_alg_combinations = zip(
793807
[NonlinearProblem, NonlinearLeastSquaresProblem], [nl_algs, nlls_algs])
794-
@testset "Parameter initialization" begin
808+
sarray_ctor = splat(SVector)
809+
@testset "Parameter initialization with ctor $ctor" for (ctor, expectedT) in [
810+
(identity, Any),
811+
(sarray_ctor, SVector)
812+
]
813+
u0_constructor = p_constructor = ctor
795814
function test_parameter(prob, alg, param, val)
815+
if prob.u0 !== nothing
816+
@test prob.u0 isa expectedT
817+
end
818+
@test prob.p.tunable isa expectedT
796819
integ = init(prob, alg)
797820
@test integ.ps[param]val rtol=1e-5
798821
# some algorithms are a little temperamental
@@ -818,19 +841,22 @@ end
818841
# guesses = [q => 1.0], initialization_eqs = [p^2 + q^2 + 2p * q ~ 0])
819842

820843
for (probT, algs) in prob_alg_combinations
821-
prob = probT(sys, [])
844+
if ctor != identity
845+
probT = probT{false}
846+
end
847+
prob = probT(sys, []; u0_constructor, p_constructor)
822848
@test prob.f.initialization_data !== nothing
823849
@test prob.f.initialization_data.initializeprobmap === nothing
824850
for alg in algs
825851
test_parameter(prob, alg, q, -2.0)
826852
end
827853

828854
# `update_initializeprob!` works
829-
prob.ps[p] = -2.0
855+
prob = remake(prob; p = setp_oop(prob, p)(prob, -2.0))
830856
for alg in algs
831857
test_parameter(prob, alg, q, 2.0)
832858
end
833-
prob.ps[p] = 2.0
859+
prob = remake(prob; p = setp_oop(prob, p)(prob, 2.0))
834860

835861
# `remake` works
836862
prob2 = remake(prob; p = [p => -2.0])

0 commit comments

Comments
 (0)