Skip to content

Commit ab2788f

Browse files
fixup! feat: support callable parameters
1 parent 6d2aed8 commit ab2788f

File tree

2 files changed

+44
-21
lines changed

2 files changed

+44
-21
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ ChainRulesCore = "1"
7777
Combinatorics = "1"
7878
Compat = "3.42, 4"
7979
ConstructionBase = "1"
80+
DataInterpolations = "6.4"
8081
DataStructures = "0.17, 0.18"
8182
DeepDiffs = "1"
8283
DiffEqBase = "6.103.0"
@@ -131,6 +132,7 @@ julia = "1.9"
131132
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
132133
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
133134
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
135+
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
134136
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
135137
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
136138
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -156,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
156158
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
157159

158160
[targets]
159-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
161+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

test/split_parameters.jl

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
4+
using DataInterpolations
45
using BlockArrays: BlockedArray
56
using ModelingToolkit: t_nounits as t, D_nounits as D
67
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
@@ -222,24 +223,44 @@ S = get_sensitivity(closed_loop, :u)
222223
end
223224

224225
@testset "Callable parameters" begin
225-
_f1(x) = 2x
226-
struct Foo end
227-
(::Foo)(x) = 3x
228-
@variables x(t)
229-
@parameters fn(..) = _f1
230-
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t, [x], [fn])
231-
@test is_parameter(sys, fn)
232-
@test ModelingToolkit.defaults(sys)[fn] == _f1
233-
234-
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
235-
@test_broken @inferred prob.ps[fn]
236-
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
237-
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
238-
@test sol.u[end][] exp(2.0)
239-
240-
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
241-
@test_broken @inferred prob.ps[fn]
242-
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
243-
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
244-
@test sol.u[end][] exp(3.0)
226+
@testset "As FunctionWrapper" begin
227+
_f1(x) = 2x
228+
struct Foo end
229+
(::Foo)(x) = 3x
230+
@variables x(t)
231+
@parameters fn(::Real) = _f1
232+
@mtkbuild sys = ODESystem(D(x) ~ fn(t), t)
233+
@test is_parameter(sys, fn)
234+
@test ModelingToolkit.defaults(sys)[fn] == _f1
235+
236+
getter = getp(sys, fn)
237+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
238+
@inferred getter(prob)
239+
# cannot be inferred better since `FunctionWrapper` is only known to return `Real`
240+
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
241+
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)
242+
@test sol.u[end][] 2.0
243+
244+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
245+
@inferred getter(prob)
246+
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
247+
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
248+
@test sol.u[end][] 2.5
249+
end
250+
251+
@testset "Concrete function type" begin
252+
ts = 0.0:0.1:1.0
253+
interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true)
254+
@variables x(t)
255+
@parameters (fn::typeof(interp))(..)
256+
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t)
257+
@test is_parameter(sys, fn)
258+
getter = getp(sys, fn)
259+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp])
260+
@inferred getter(prob)
261+
@inferred prob.f(prob.u0, prob.p, prob.tspan[1])
262+
@test_nowarn sol = solve(prob, Tsit5())
263+
@test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true)
264+
@test_nowarn sol = solve(prob)
265+
end
245266
end

0 commit comments

Comments
 (0)