|  | 
| 1 | 1 | using ModelingToolkit, Test | 
| 2 | 2 | using ModelingToolkitStandardLibrary.Blocks | 
| 3 | 3 | using OrdinaryDiffEq | 
|  | 4 | +using DataInterpolations | 
| 4 | 5 | using BlockArrays: BlockedArray | 
| 5 | 6 | using ModelingToolkit: t_nounits as t, D_nounits as D | 
| 6 | 7 | using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION | 
| @@ -222,24 +223,44 @@ S = get_sensitivity(closed_loop, :u) | 
| 222 | 223 | end | 
| 223 | 224 | 
 | 
| 224 | 225 | @testset "Callable parameters" begin | 
| 225 |  | -    _f1(x) = 2x | 
| 226 |  | -    struct Foo end | 
| 227 |  | -    (::Foo)(x) = 3x | 
| 228 |  | -    @variables x(t) | 
| 229 |  | -    @parameters fn(..) = _f1 | 
| 230 |  | -    @mtkbuild sys = ODESystem(D(x) ~ fn(x), t, [x], [fn]) | 
| 231 |  | -    @test is_parameter(sys, fn) | 
| 232 |  | -    @test ModelingToolkit.defaults(sys)[fn] == _f1 | 
| 233 |  | - | 
| 234 |  | -    prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) | 
| 235 |  | -    @test_broken @inferred prob.ps[fn] | 
| 236 |  | -    @test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) | 
| 237 |  | -    sol = solve(prob; abstol = 1e-10, reltol = 1e-10) | 
| 238 |  | -    @test sol.u[end][] ≈ exp(2.0) | 
| 239 |  | - | 
| 240 |  | -    prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()]) | 
| 241 |  | -    @test_broken @inferred prob.ps[fn] | 
| 242 |  | -    @test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) | 
| 243 |  | -    sol = solve(prob; abstol = 1e-10, reltol = 1e-10) | 
| 244 |  | -    @test sol.u[end][] ≈ exp(3.0) | 
|  | 226 | +    @testset "As FunctionWrapper" begin | 
|  | 227 | +        _f1(x) = 2x | 
|  | 228 | +        struct Foo end | 
|  | 229 | +        (::Foo)(x) = 3x | 
|  | 230 | +        @variables x(t) | 
|  | 231 | +        @parameters fn(::Real) = _f1 | 
|  | 232 | +        @mtkbuild sys = ODESystem(D(x) ~ fn(t), t) | 
|  | 233 | +        @test is_parameter(sys, fn) | 
|  | 234 | +        @test ModelingToolkit.defaults(sys)[fn] == _f1 | 
|  | 235 | + | 
|  | 236 | +        getter = getp(sys, fn) | 
|  | 237 | +        prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0)) | 
|  | 238 | +        @inferred getter(prob) | 
|  | 239 | +        # cannot be inferred better since `FunctionWrapper` is only known to return `Real` | 
|  | 240 | +        @inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1]) | 
|  | 241 | +        sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10) | 
|  | 242 | +        @test sol.u[end][] ≈ 2.0 | 
|  | 243 | + | 
|  | 244 | +        prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()]) | 
|  | 245 | +        @inferred getter(prob) | 
|  | 246 | +        @inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1]) | 
|  | 247 | +        sol = solve(prob; abstol = 1e-10, reltol = 1e-10) | 
|  | 248 | +        @test sol.u[end][] ≈ 2.5 | 
|  | 249 | +    end | 
|  | 250 | + | 
|  | 251 | +    @testset "Concrete function type" begin | 
|  | 252 | +        ts = 0.0:0.1:1.0 | 
|  | 253 | +        interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true) | 
|  | 254 | +        @variables x(t) | 
|  | 255 | +        @parameters (fn::typeof(interp))(..) | 
|  | 256 | +        @mtkbuild sys = ODESystem(D(x) ~ fn(x), t) | 
|  | 257 | +        @test is_parameter(sys, fn) | 
|  | 258 | +        getter = getp(sys, fn) | 
|  | 259 | +        prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp]) | 
|  | 260 | +        @inferred getter(prob) | 
|  | 261 | +        @inferred prob.f(prob.u0, prob.p, prob.tspan[1]) | 
|  | 262 | +        @test_nowarn sol = solve(prob, Tsit5()) | 
|  | 263 | +        @test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true) | 
|  | 264 | +        @test_nowarn sol = solve(prob) | 
|  | 265 | +    end | 
| 245 | 266 | end | 
0 commit comments