Skip to content

Commit 8af217d

Browse files
test: rewrite hybrid system tests to not use MTK
1 parent 51f554b commit 8af217d

File tree

2 files changed

+338
-1
lines changed

2 files changed

+338
-1
lines changed

test/downstream/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
3+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
34
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
45
JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
56
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"

test/downstream/comprehensive_indexing.jl

Lines changed: 337 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization,
22
OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase,
3-
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, Test
3+
SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface,
4+
DiffEqCallbacks, Test
45
using ModelingToolkit: t_nounits as t, D_nounits as D
56

67
# Sets rnd number.
@@ -528,3 +529,338 @@ end
528529
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w])
529530
@test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y])
530531
end
532+
533+
@testset "Discrete save indexing" begin
534+
struct NumSymbolCache{S}
535+
sc::S
536+
end
537+
SymbolicIndexingInterface.symbolic_container(s::NumSymbolCache) = s.sc
538+
function SymbolicIndexingInterface.is_observed(s::NumSymbolCache, x)
539+
return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) &&
540+
!is_parameter(s, x) && !is_independent_variable(s, x)
541+
end
542+
function SymbolicIndexingInterface.observed(s::NumSymbolCache, x)
543+
res = ModelingToolkit.build_function(x,
544+
sort(variable_symbols(s); by = Base.Fix1(variable_index, s)),
545+
sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)),
546+
independent_variable_symbols(s)[]; expression = Val(false))
547+
if res isa Tuple
548+
return let oopfn = res[1], iipfn = res[2]
549+
fn(out, u, p, t) = iipfn(out, u, p, t)
550+
fn(u, p, t) = oopfn(u, p, t)
551+
fn
552+
end
553+
else
554+
return res
555+
end
556+
end
557+
function SymbolicIndexingInterface.parameter_observed(s::NumSymbolCache, x)
558+
res = ModelingToolkit.build_function(x,
559+
sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)),
560+
independent_variable_symbols(s)[]; expression = Val(false))
561+
if res isa Tuple
562+
return let oopfn = res[1], iipfn = res[2]
563+
fn(out, p, t) = iipfn(out, p, t)
564+
fn(p, t) = oopfn(p, t)
565+
fn
566+
end
567+
else
568+
return res
569+
end
570+
end
571+
function SymbolicIndexingInterface.get_all_timeseries_indexes(s::NumSymbolCache, x)
572+
if symbolic_type(x) == NotSymbolic()
573+
x = ModelingToolkit.unwrap.(x)
574+
else
575+
x = ModelingToolkit.unwrap(x)
576+
end
577+
vars = ModelingToolkit.vars(x)
578+
return mapreduce(union, vars; init = Set()) do sym
579+
if is_variable(s, sym)
580+
Set([ContinuousTimeseries()])
581+
elseif is_parameter(s, sym) && is_timeseries_parameter(s, sym)
582+
Set([timeseries_parameter_index(s, sym).timeseries_idx])
583+
else
584+
Set()
585+
end
586+
end
587+
end
588+
function SymbolicIndexingInterface.with_updated_parameter_timeseries_values(
589+
::NumSymbolCache, p::Vector{Float64}, args...)
590+
for (idx, buf) in args
591+
if idx == 1
592+
p[1:2] .= buf
593+
else
594+
p[3:4] .= buf
595+
end
596+
end
597+
598+
return p
599+
end
600+
function SciMLBase.create_parameter_timeseries_collection(s::NumSymbolCache, ps, tspan)
601+
trem = rem(tspan[1], 0.1, RoundDown)
602+
if trem > 0
603+
trem = 0.1 - trem
604+
end
605+
dea1 = DiffEqArray(Vector{Float64}[], (tspan[1] + trem):0.1:tspan[2])
606+
dea2 = DiffEqArray(Vector{Float64}[], Float64[])
607+
return ParameterTimeseriesCollection((dea1, dea2), deepcopy(ps))
608+
end
609+
function SciMLBase.get_saveable_values(::NumSymbolCache, p::Vector{Float64}, tsidx)
610+
if tsidx == 1
611+
return p[1:2]
612+
else
613+
return p[3:4]
614+
end
615+
end
616+
617+
@variables x(t) ud1(t) ud2(t) xd1(t) xd2(t)
618+
@parameters kp
619+
sc = SymbolCache([x],
620+
Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5),
621+
t;
622+
timeseries_parameters = Dict(
623+
ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2),
624+
ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2)))
625+
sys = NumSymbolCache(sc)
626+
627+
function f!(du, u, p, t)
628+
du .= u .* t .+ p[5] * sum(u)
629+
end
630+
fn = ODEFunction(f!; sys = sys)
631+
prob = ODEProblem(fn, [1.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0, 5.0])
632+
cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true,
633+
save_positions = (false, false)) do integ
634+
integ.p[1:2] .+= exp(-integ.t)
635+
SciMLBase.save_discretes!(integ, 1)
636+
end
637+
function affect2!(integ)
638+
integ.p[3:4] .+= only(integ.u)
639+
SciMLBase.save_discretes!(integ, 2)
640+
end
641+
cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false),
642+
initialize = (c, u, t, integ) -> affect2!(integ))
643+
sol = solve(deepcopy(prob), Tsit5(); callback = CallbackSet(cb1, cb2))
644+
645+
ud1val = getindex.(sol.discretes.collection[1].u, 1)
646+
xd1val = getindex.(sol.discretes.collection[1].u, 2)
647+
ud2val = getindex.(sol.discretes.collection[2].u, 1)
648+
xd2val = getindex.(sol.discretes.collection[2].u, 2)
649+
650+
for (sym, timeseries_index, val, buffer, isobs, check_inference) in [(ud1,
651+
1,
652+
ud1val,
653+
zeros(length(ud1val)),
654+
false,
655+
true)
656+
([ud1, xd1],
657+
1,
658+
vcat.(ud1val,
659+
xd1val),
660+
map(
661+
_ -> zeros(2),
662+
ud1val),
663+
false,
664+
true)
665+
((ud2, xd2),
666+
2,
667+
tuple.(ud2val,
668+
xd2val),
669+
map(
670+
_ -> zeros(2),
671+
ud2val),
672+
false,
673+
true)
674+
(ud2 + xd2,
675+
2,
676+
ud2val .+
677+
xd2val,
678+
zeros(length(ud2val)),
679+
true,
680+
true)
681+
(
682+
[ud2 + xd2,
683+
ud2 * xd2],
684+
2,
685+
vcat.(
686+
ud2val .+
687+
xd2val,
688+
ud2val .*
689+
xd2val),
690+
map(
691+
_ -> zeros(2),
692+
ud2val),
693+
true,
694+
true)
695+
(
696+
(ud1 + xd1,
697+
ud1 * xd1),
698+
1,
699+
tuple.(
700+
ud1val .+
701+
xd1val,
702+
ud1val .*
703+
xd1val),
704+
map(
705+
_ -> zeros(2),
706+
ud1val),
707+
true,
708+
true)]
709+
getter = getp(sys, sym)
710+
if check_inference
711+
@inferred getter(sol)
712+
@inferred getter(deepcopy(buffer), sol)
713+
if !isobs
714+
@inferred getter(parameter_values(sol))
715+
if !(eltype(val) <: Number)
716+
@inferred getter(deepcopy(buffer[1]), parameter_values(sol))
717+
end
718+
end
719+
end
720+
721+
@test getter(sol) == val
722+
if eltype(val) <: Number
723+
target = val
724+
else
725+
target = collect.(val)
726+
end
727+
tmp = deepcopy(buffer)
728+
getter(tmp, sol)
729+
@test tmp == target
730+
731+
if !isobs
732+
@test getter(parameter_values(sol)) == val[end]
733+
if !(eltype(val) <: Number)
734+
target = collect(val[end])
735+
tmp = deepcopy(buffer)[end]
736+
getter(tmp, parameter_values(sol))
737+
@test tmp == target
738+
end
739+
end
740+
741+
for subidx in [
742+
1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5]
743+
if check_inference
744+
@inferred getter(sol, subidx)
745+
if !isa(val[subidx], Number)
746+
@inferred getter(deepcopy(buffer[subidx]), sol, subidx)
747+
end
748+
end
749+
@test getter(sol, subidx) == val[subidx]
750+
tmp = deepcopy(buffer[subidx])
751+
if val[subidx] isa Number
752+
continue
753+
end
754+
target = val[subidx]
755+
if eltype(target) <: Number
756+
target = collect(target)
757+
else
758+
target = collect.(target)
759+
end
760+
getter(tmp, sol, subidx)
761+
@test tmp == target
762+
end
763+
end
764+
765+
for sym in [
766+
[ud1, xd1, ud2],
767+
(ud2, xd1, xd2),
768+
ud1 + ud2,
769+
[ud1 + ud2, ud1 * xd1],
770+
(ud1 + ud2, ud1 * xd1)]
771+
getter = getp(sys, sym)
772+
@test_throws Exception getter(sol)
773+
@test_throws Exception getter([], sol)
774+
for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2]
775+
@test_throws Exception getter(sol, subidx)
776+
@test_throws Exception getter([], sol, subidx)
777+
end
778+
end
779+
780+
kpval = sol.prob.p[5]
781+
xval = getindex.(sol.u)
782+
783+
for (sym, val_is_timeseries, val, check_inference) in [
784+
(kp, false, kpval, true),
785+
([kp, kp], false, [kpval, kpval], true),
786+
((kp, kp), false, (kpval, kpval), true),
787+
(ud2, true, ud2val, true),
788+
([ud2, kp], true, vcat.(ud2val, kpval), false),
789+
((ud1, kp), true, tuple.(ud1val, kpval), false),
790+
([kp, x], true, vcat.(kpval, xval), false),
791+
((kp, x), true, tuple.(kpval, xval), false),
792+
(2ud2, true, 2 .* ud2val, true),
793+
([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false),
794+
((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false)
795+
]
796+
getter = getu(sys, sym)
797+
if check_inference
798+
@inferred getter(sol)
799+
end
800+
@test getter(sol) == val
801+
reference = val_is_timeseries ? val : xval
802+
for subidx in [
803+
1, CartesianIndex(2), :, rand(Bool, length(reference)),
804+
rand(eachindex(reference), 4), 2:6
805+
]
806+
if check_inference
807+
@inferred getter(sol, subidx)
808+
end
809+
target = if val_is_timeseries
810+
val[subidx]
811+
else
812+
val
813+
end
814+
@test getter(sol, subidx) == target
815+
end
816+
end
817+
818+
_xval = xval[1]
819+
_ud1val = ud1val[1]
820+
_ud2val = ud2val[1]
821+
_xd1val = xd1val[1]
822+
_xd2val = xd2val[1]
823+
integ = init(prob, Tsit5(); callback = CallbackSet(cb1, cb2))
824+
for (sym, val, check_inference) in [
825+
([x, ud1], [_xval, _ud1val], false),
826+
((x, ud1), (_xval, _ud1val), true),
827+
(x + ud2, _xval + _ud2val, true),
828+
([2x, 3xd1], [2_xval, 3_xd1val], true),
829+
((2x, 3xd2), (2_xval, 3_xd2val), true)
830+
]
831+
getter = getu(sys, sym)
832+
@test_throws Exception getter(sol)
833+
for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2]
834+
@test_throws Exception getter(sol, subidx)
835+
end
836+
837+
if check_inference
838+
@inferred getter(integ)
839+
end
840+
@test getter(integ) == val
841+
end
842+
843+
xinterp = sol(0.1:0.1:0.3, idxs = x)
844+
xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs = x)
845+
ud1interp = ud1val[2:4]
846+
ud2interp = ud2val[2:4]
847+
848+
c1 = SciMLBase.Clock(0.1)
849+
c2 = SciMLBase.SolverStepClock
850+
for (sym, t, val) in [
851+
(x, c1[2], xinterp[1]),
852+
(x, c1[2:4], xinterp),
853+
([x, ud1], c1[2], [xinterp[1], ud1interp[1]]),
854+
([x, ud1], c1[2:4], vcat.(xinterp, ud1interp)),
855+
(x, c2[2], xinterp2[1]),
856+
(x, c2[2:4], xinterp2),
857+
([x, ud2], c2[2], [xinterp2[1], ud2interp[1]]),
858+
([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp))
859+
]
860+
res = sol(t, idxs = sym)
861+
if res isa DiffEqArray
862+
res = res.u
863+
end
864+
@test res == val
865+
end
866+
end

0 commit comments

Comments
 (0)