Skip to content

Commit 9968b46

Browse files
committed
test: add tests for changing data in ParametrizedInterpolation
1 parent 4557dd9 commit 9968b46

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,13 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
4040
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
4141
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
4242
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
43+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4344
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4445
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
4546
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
47+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
48+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4649
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4750

4851
[targets]
49-
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "OrdinaryDiffEqDefault", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations"]
52+
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "OrdinaryDiffEqDefault", "Optimization", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations", "SciMLStructures", "SymbolicIndexingInterface"]

test/Blocks/sources.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damp
66
smooth_triangular, triangular, square
77
using OrdinaryDiffEq: ReturnCode.Success
88
using DataInterpolations
9+
using SymbolicIndexingInterface
10+
using SciMLStructures: SciMLStructures, Tunable
11+
using Optimization
912

1013
@testset "Constant" begin
1114
@named src = Constant(k = 2)
@@ -480,8 +483,8 @@ end
480483

481484
@testset "ParametrizedInterpolation" begin
482485
@variables y(t) = 0
483-
@parameters u[1:15] = rand(15)
484-
@parameters x[1:15] = 0:14.0
486+
u = rand(15)
487+
x = 0:14.0
485488

486489
@testset "LinearInterpolation" begin
487490
@named i = ParametrizedInterpolation(LinearInterpolation, u, x)
@@ -494,6 +497,41 @@ end
494497
sol = solve(prob)
495498

496499
@test SciMLBase.successful_retcode(sol)
500+
501+
prob2 = remake(prob, p=[i.data => ones(15)])
502+
sol2 = solve(prob2)
503+
504+
@test SciMLBase.successful_retcode(sol2)
505+
@test all(only.(sol2.u) .≈ sol2.t) # the solution for y' = 1 is y(t) = t
506+
507+
set_data! = setp(prob2, i.data)
508+
set_data!(prob2, zeros(15))
509+
sol3 = solve(prob2)
510+
@test SciMLBase.successful_retcode(sol3)
511+
@test iszero(sol3)
512+
513+
function loss(x, p)
514+
prob0, set_data! = p
515+
ps = parameter_values(prob0)
516+
arr, repack, alias = SciMLStructures.canonicalize(Tunable(), ps)
517+
T = promote_type(eltype(x), eltype(arr))
518+
promoted_ps = SciMLStructures.replace(Tunable(), ps, T.(arr))
519+
prob = remake(prob0; p = promoted_ps)
520+
521+
set_data!(prob, x)
522+
sol = solve(prob)
523+
sum(abs2.(only.(sol.u) .- sol.t))
524+
end
525+
526+
set_data! = setp(prob, i.data)
527+
of = OptimizationFunction(loss, AutoForwardDiff())
528+
op = OptimizationProblem(of, u, (prob, set_data!), lb = zeros(15), ub = fill(2.0, 15))
529+
530+
# check that type changing works
531+
@test length(ForwardDiff.gradient(x -> of(x, (prob, set_data!)), u)) == 15
532+
533+
r = solve(op, Optimization.LBFGS(), maxiters = 1000)
534+
@test of(r.u, (prob, set_data!)) < of(u, (prob, set_data!))
497535
end
498536

499537
@testset "BSplineInterpolation" begin

0 commit comments

Comments
 (0)