Skip to content

Commit 7e8cbee

Browse files
committed
test: add tests for changing data in ParametrizedInterpolation
1 parent e3d63c2 commit 7e8cbee

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,12 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3939
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
4040
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
4141
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
42+
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
4243
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
4344
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
45+
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
46+
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4447
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4548

4649
[targets]
47-
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations"]
50+
test = ["Aqua", "LinearAlgebra", "OrdinaryDiffEq", "Optimization", "SafeTestsets", "Test", "ControlSystemsBase", "DataInterpolations", "SciMLStructures", "SymbolicIndexingInterface"]

test/Blocks/sources.jl

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using ModelingToolkitStandardLibrary.Blocks: smooth_sin, smooth_cos, smooth_damp
66
smooth_triangular, triangular, square
77
using OrdinaryDiffEq: ReturnCode.Success
88
using DataInterpolations
9+
using SymbolicIndexingInterface
10+
using SciMLStructures: SciMLStructures, Tunable
11+
using Optimization
912

1013
@testset "Constant" begin
1114
@named src = Constant(k = 2)
@@ -479,8 +482,8 @@ end
479482

480483
@testset "ParametrizedInterpolation" begin
481484
@variables y(t) = 0
482-
@parameters u[1:15] = rand(15)
483-
@parameters x[1:15] = 0:14.0
485+
u = rand(15)
486+
x = 0:14.0
484487

485488
@testset "LinearInterpolation" begin
486489
@named i = ParametrizedInterpolation(LinearInterpolation, u, x)
@@ -493,6 +496,41 @@ end
493496
sol = solve(prob)
494497

495498
@test SciMLBase.successful_retcode(sol)
499+
500+
prob2 = remake(prob, p=[i.data => ones(15)])
501+
sol2 = solve(prob2)
502+
503+
@test SciMLBase.successful_retcode(sol2)
504+
@test all(only.(sol2.u) .≈ sol2.t) # the solution for y' = 1 is y(t) = t
505+
506+
set_data! = setp(prob2, i.data)
507+
set_data!(prob2, zeros(15))
508+
sol3 = solve(prob2)
509+
@test SciMLBase.successful_retcode(sol3)
510+
@test iszero(sol3)
511+
512+
function loss(x, p)
513+
prob0, set_data! = p
514+
ps = parameter_values(prob0)
515+
arr, repack, alias = SciMLStructures.canonicalize(Tunable(), ps)
516+
T = promote_type(eltype(x), eltype(arr))
517+
promoted_ps = SciMLStructures.replace(Tunable(), ps, T.(arr))
518+
prob = remake(prob0; p = promoted_ps)
519+
520+
set_data!(prob, x)
521+
sol = solve(prob)
522+
sum(abs2.(only.(sol.u) .- sol.t))
523+
end
524+
525+
set_data! = setp(prob, i.data)
526+
of = OptimizationFunction(loss, AutoForwardDiff())
527+
op = OptimizationProblem(of, u, (prob, set_data!), lb = zeros(15), ub = fill(2.0, 15))
528+
529+
# check that type changing works
530+
@test length(ForwardDiff.gradient(x -> of(x, (prob, set_data!)), u)) == 15
531+
532+
r = solve(op, Optimization.LBFGS(), maxiters = 1000)
533+
@test of(r.u, (prob, set_data!)) < of(u, (prob, set_data!))
496534
end
497535

498536
@testset "BSplineInterpolation" begin

0 commit comments

Comments
 (0)