Skip to content

Commit 4346094

Browse files
committed
test: add type inference tests
1 parent 2da4f8a commit 4346094

File tree

2 files changed

+74
-1
lines changed

2 files changed

+74
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
119119
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
120120
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
121121
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
122+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
122123
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
123124
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
124125
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -137,4 +138,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
137138
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
138139

139140
[targets]
140-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg"]
141+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

test/mtkparameters.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters
33
using SymbolicIndexingInterface
44
using SciMLStructures: SciMLStructures, canonicalize, Tunable, Discrete, Constants
55
using ForwardDiff
6+
using JET
67

78
@parameters a b c d::Integer e[1:3] f[1:3, 1:3]::Int g::Vector{AbstractFloat} h::String
89
@named sys = ODESystem(
@@ -121,3 +122,74 @@ ps = MTKParameters(sys, [p => 1.0, q => 2.0, r => 3.0])
121122
newps = remake_buffer(sys, ps, Dict(p => 1.0f0))
122123
@test newps.tunable[1] isa Vector{Float32}
123124
@test newps.tunable[1] == [1.0f0, 2.0f0, 3.0f0]
125+
126+
# JET tests
127+
128+
# scalar parameters only
129+
function level1()
130+
@parameters p1=0.5 [tunable = true] p2 = 1 [tunable=true] p3 = 3 [tunable = false] p4=3 [tunable = true] y0=1
131+
@variables x(t)=2 y(t)=y0
132+
D = Differential(t)
133+
134+
eqs = [D(x) ~ p1 * x - p2 * x * y
135+
D(y) ~ -p3 * y + p4 * x * y]
136+
137+
sys = structural_simplify(complete(ODESystem(
138+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
139+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
140+
end
141+
142+
# scalar and vector parameters
143+
function level2()
144+
@parameters p1=0.5 [tunable = true] (p23[1:2]=[1, 3.0]) [tunable = true] p4=3 [tunable = false] y0=1
145+
@variables x(t)=2 y(t)=y0
146+
D = Differential(t)
147+
148+
eqs = [D(x) ~ p1 * x - p23[1] * x * y
149+
D(y) ~ -p23[2] * y + p4 * x * y]
150+
151+
sys = structural_simplify(complete(ODESystem(
152+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
153+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
154+
end
155+
156+
# scalar and vector parameters with different scalar types
157+
function level3()
158+
@parameters p1=0.5 [tunable = true] (p23[1:2]=[1, 3.0]) [tunable = true] p4::Int=3 [tunable = true] y0::Int=1
159+
@variables x(t)=2 y(t)=y0
160+
D = Differential(t)
161+
162+
eqs = [D(x) ~ p1 * x - p23[1] * x * y
163+
D(y) ~ -p23[2] * y + p4 * x * y]
164+
165+
sys = structural_simplify(complete(ODESystem(
166+
eqs, t, tspan = (0, 3.0), name = :sys, parameter_dependencies = [y0 => 2p4])))
167+
prob = ODEProblem{true, SciMLBase.FullSpecialize}(sys)
168+
end
169+
170+
@testset "level$i" for (i, prob) in enumerate([level1(), level2(), level3()])
171+
ps = prob.p
172+
@testset "Type stability of $portion" for portion in [Tunable(), Discrete(), Constants()]
173+
@test_call canonicalize(portion, ps)
174+
# @inferred canonicalize(portion, ps)
175+
broken =
176+
(i [2,3] && portion == Tunable())
177+
178+
# broken because the size of a vector of vectors can't be determined at compile time
179+
@test_opt broken=broken target_modules = (ModelingToolkit,) canonicalize(
180+
portion, ps)
181+
182+
buffer, repack, alias = canonicalize(portion, ps)
183+
184+
@test_call SciMLStructures.replace(portion, ps, ones(length(buffer)))
185+
@inferred SciMLStructures.replace(portion, ps, ones(length(buffer)))
186+
@test_opt target_modules=(ModelingToolkit,) SciMLStructures.replace(
187+
portion, ps, ones(length(buffer)))
188+
189+
@test_call target_modules = (ModelingToolkit,) SciMLStructures.replace!(
190+
portion, ps, ones(length(buffer)))
191+
@inferred SciMLStructures.replace!(portion, ps, ones(length(buffer)))
192+
@test_opt target_modules=(ModelingToolkit,) SciMLStructures.replace!(
193+
portion, ps, ones(length(buffer)))
194+
end
195+
end

0 commit comments

Comments
 (0)