Skip to content

Commit 90e6398

Browse files
Merge pull request #2995 from AayushSabharwal/as/callable-params
feat: support callable parameters
2 parents 9aadc71 + 7579312 commit 90e6398

File tree

9 files changed

+111
-23
lines changed

9 files changed

+111
-23
lines changed

Project.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
2424
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
2525
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
2626
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
27+
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
2728
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2829
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2930
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -76,6 +77,7 @@ ChainRulesCore = "1"
7677
Combinatorics = "1"
7778
Compat = "3.42, 4"
7879
ConstructionBase = "1"
80+
DataInterpolations = "6.4"
7981
DataStructures = "0.17, 0.18"
8082
DeepDiffs = "1"
8183
DiffEqBase = "6.103.0"
@@ -91,6 +93,7 @@ ExprTools = "0.1.10"
9193
Expronicon = "0.8"
9294
FindFirstFunctions = "1"
9395
ForwardDiff = "0.10.3"
96+
FunctionWrappers = "1.1"
9497
FunctionWrappersWrappers = "0.1"
9598
Graphs = "1.5.2"
9699
InteractiveUtils = "1"
@@ -118,8 +121,8 @@ SparseArrays = "1"
118121
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
119122
StaticArrays = "0.10, 0.11, 0.12, 1.0"
120123
SymbolicIndexingInterface = "0.3.29"
121-
SymbolicUtils = "3.2"
122-
Symbolics = "6.3"
124+
SymbolicUtils = "3.7"
125+
Symbolics = "6.12"
123126
URIs = "1"
124127
UnPack = "0.1, 1.0"
125128
Unitful = "1.1"
@@ -129,6 +132,7 @@ julia = "1.9"
129132
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
130133
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
131134
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
135+
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
132136
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
133137
DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
134138
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -154,4 +158,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
154158
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
155159

156160
[targets]
157-
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
161+
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]

src/ModelingToolkit.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using Base: RefValue
3838
using Combinatorics
3939
import Distributions
4040
import FunctionWrappersWrappers
41+
import FunctionWrappers: FunctionWrapper
4142
using URIs: URI
4243
using SciMLStructures
4344
using Compat
@@ -63,7 +64,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
6364
VariableSource, getname, variable, Connection, connect,
6465
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
6566
initial_state, transition, activeState, entry, hasnode,
66-
ticksInState, timeInState, fixpoint_sub, fast_substitute
67+
ticksInState, timeInState, fixpoint_sub, fast_substitute,
68+
CallWithMetadata, CallWithParent
6769
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
6870
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
6971
jacobian_sparsity, isaffine, islinear, _iszero, _isone,

src/parameters.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,16 @@ function isparameter(x)
2626
end
2727
end
2828

29+
function iscalledparameter(x)
30+
x = unwrap(x)
31+
return isparameter(getmetadata(x, CallWithParent, nothing))
32+
end
33+
34+
function getcalledparameter(x)
35+
x = unwrap(x)
36+
return getmetadata(x, CallWithParent)
37+
end
38+
2939
"""
3040
toparam(s)
3141

src/systems/index_cache.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct DiscreteIndex
3232
end
3333

3434
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
35+
const NonnumericMap = Dict{
36+
Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}}
3537
const UnknownIndexMap = Dict{
3638
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3739
const TunableIndexMap = Dict{BasicSymbolic,
@@ -45,20 +47,20 @@ struct IndexCache
4547
callback_to_clocks::Dict{Any, Vector{Int}}
4648
tunable_idx::TunableIndexMap
4749
constant_idx::ParamIndexMap
48-
nonnumeric_idx::ParamIndexMap
50+
nonnumeric_idx::NonnumericMap
4951
observed_syms::Set{BasicSymbolic}
5052
dependent_pars::Set{BasicSymbolic}
5153
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5254
tunable_buffer_size::BufferTemplate
5355
constant_buffer_sizes::Vector{BufferTemplate}
5456
nonnumeric_buffer_sizes::Vector{BufferTemplate}
55-
symbol_to_variable::Dict{Symbol, BasicSymbolic}
57+
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
5658
end
5759

5860
function IndexCache(sys::AbstractSystem)
5961
unks = solved_unknowns(sys)
6062
unk_idxs = UnknownIndexMap()
61-
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
63+
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
6264

6365
let idx = 1
6466
for sym in unks
@@ -105,12 +107,11 @@ function IndexCache(sys::AbstractSystem)
105107

106108
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
107109
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
108-
nonnumeric_buffers = Dict{Any, Set{BasicSymbolic}}()
110+
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
109111

110-
function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
112+
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
111113
sym = unwrap(sym)
112-
ctype = symtype(sym)
113-
buf = get!(buffers, ctype, Set{BasicSymbolic}())
114+
buf = get!(buffers, ctype, S())
114115
push!(buf, sym)
115116
end
116117

@@ -142,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
142143
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
143144
push!(clocks, i)
144145
else
145-
insert_by_type!(constant_buffers, sym)
146+
insert_by_type!(constant_buffers, sym, symtype(sym))
146147
end
147148
end
148149
end
@@ -197,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
197198
for p in parameters(sys)
198199
p = unwrap(p)
199200
ctype = symtype(p)
201+
if ctype <: FnType
202+
ctype = fntype_to_function_type(ctype)
203+
end
200204
haskey(disc_idxs, p) && continue
201205
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
202206
insert_by_type!(
@@ -212,12 +216,13 @@ function IndexCache(sys::AbstractSystem)
212216
else
213217
nonnumeric_buffers
214218
end,
215-
p
219+
p,
220+
ctype
216221
)
217222
end
218223

219-
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
220-
idxs = ParamIndexMap()
224+
function get_buffer_sizes_and_idxs(T, buffers::Dict)
225+
idxs = T()
221226
buffer_sizes = BufferTemplate[]
222227
for (i, (T, buf)) in enumerate(buffers)
223228
for (j, p) in enumerate(buf)
@@ -229,13 +234,18 @@ function IndexCache(sys::AbstractSystem)
229234
idxs[rp] = (i, j)
230235
idxs[rttp] = (i, j)
231236
end
237+
if T <: Symbolics.FnType
238+
T = Any
239+
end
232240
push!(buffer_sizes, BufferTemplate(T, length(buf)))
233241
end
234242
return idxs, buffer_sizes
235243
end
236244

237-
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
238-
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
245+
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(
246+
ParamIndexMap, constant_buffers)
247+
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(
248+
NonnumericMap, nonnumeric_buffers)
239249

240250
tunable_idxs = TunableIndexMap()
241251
tunable_buffer_size = 0
@@ -401,7 +411,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
401411
for temp in ic.discrete_buffer_sizes)
402412
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
403413
for temp in ic.constant_buffer_sizes)
404-
nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
414+
nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF))
415+
for _ in 1:(temp.length)]
405416
for temp in ic.nonnumeric_buffer_sizes)
406417
for p in ps
407418
p = unwrap(p)
@@ -481,3 +492,7 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
481492
error("Unhandled portion $portion")
482493
end
483494
end
495+
496+
fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
497+
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}
498+
fntype_to_function_type(::Type{FnType{A, R}}) where {A, R} = FunctionWrapper{R, A}

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ function MTKParameters(
152152
if symbolic_type(val) !== NotSymbolic()
153153
error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
154154
end
155+
if ctype <: FnType
156+
ctype = fntype_to_function_type(ctype)
157+
end
155158
val = symconvert(ctype, val)
156159
done = set_value(sym, val)
157160
if !done && Symbolics.isarraysymbolic(sym)

src/systems/systemstructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
661661
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
662662
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
663663
end
664-
ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
664+
ps = [sym isa CallWithMetadata ? sym :
665+
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
665666
for sym in get_ps(sys)]
666667
@set! sys.ps = ps
667668
else

src/utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,11 @@ end
371371
vars(exprs::Num; op = Differential) = vars(unwrap(exprs); op)
372372
vars(exprs::Symbolics.Arr; op = Differential) = vars(unwrap(exprs); op)
373373
function vars(exprs; op = Differential)
374-
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
374+
if hasmethod(iterate, Tuple{typeof(exprs)})
375+
foldl((x, y) -> vars!(x, unwrap(y); op = op), exprs; init = Set())
376+
else
377+
vars!(Set(), unwrap(exprs); op)
378+
end
375379
end
376380
vars(eq::Equation; op = Differential) = vars!(Set(), eq; op = op)
377381
function vars!(vars, eq::Equation; op = Differential)
@@ -479,7 +483,11 @@ end
479483

480484
function collect_var!(unknowns, parameters, var, iv)
481485
isequal(var, iv) && return nothing
482-
if isparameter(var) || (iscall(var) && isparameter(operation(var)))
486+
if iscalledparameter(var)
487+
callable = getcalledparameter(var)
488+
push!(parameters, callable)
489+
collect_vars!(unknowns, parameters, arguments(var), iv)
490+
elseif isparameter(var) || (iscall(var) && isparameter(operation(var)))
483491
push!(parameters, var)
484492
elseif !isconstant(var)
485493
push!(unknowns, var)

test/odesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,15 @@ eqs = [D(x) ~ σ(t - 1) * (y - x),
137137
D(y) ~ x *- z) - y,
138138
D(z) ~ x * y - β * z * κ]
139139
@named de = ODESystem(eqs, t)
140-
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ(t - 1), ρ, β))
140+
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ, ρ, β))
141141
f = eval(generate_function(de, [x, y, z], [σ, ρ, β])[2])
142142
du = [0.0, 0.0, 0.0]
143143
f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3], 5.0)
144144
@test du [11, -3, -7]
145145

146146
eqs = [D(x) ~ x + 10σ(t - 1) + 100σ(t - 2) + 1000σ(t^2)]
147147
@named de = ODESystem(eqs, t)
148-
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ(t - 2), σ(t^2), σ(t - 1)))
148+
test_diffeq_inference("many internal iv-varying", de, t, (x,), (σ,))
149149
f = eval(generate_function(de, [x], [σ])[2])
150150
du = [0.0]
151151
f(du, [1.0], [t -> t + 2], 5.0)

test/split_parameters.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
using ModelingToolkit, Test
22
using ModelingToolkitStandardLibrary.Blocks
33
using OrdinaryDiffEq
4+
using DataInterpolations
45
using BlockArrays: BlockedArray
56
using ModelingToolkit: t_nounits as t, D_nounits as D
67
using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION
78
using SciMLStructures: Tunable, Discrete, Constants
89
using StaticArrays: SizedVector
10+
using SymbolicIndexingInterface: is_parameter, getp
911

1012
x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]
1113

@@ -219,3 +221,46 @@ S = get_sensitivity(closed_loop, :u)
219221
@test ps[ParameterIndex(Tunable(), 1:8)] == collect(1.0:8.0) .+ 0.5
220222
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] == 5
221223
end
224+
225+
@testset "Callable parameters" begin
226+
@testset "As FunctionWrapper" begin
227+
_f1(x) = 2x
228+
struct Foo end
229+
(::Foo)(x) = 3x
230+
@variables x(t)
231+
@parameters fn(::Real) = _f1
232+
@mtkbuild sys = ODESystem(D(x) ~ fn(t), t)
233+
@test is_parameter(sys, fn)
234+
@test ModelingToolkit.defaults(sys)[fn] == _f1
235+
236+
getter = getp(sys, fn)
237+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
238+
@inferred getter(prob)
239+
# cannot be inferred better since `FunctionWrapper` is only known to return `Real`
240+
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
241+
sol = solve(prob, Tsit5(); abstol = 1e-10, reltol = 1e-10)
242+
@test sol.u[end][] 2.0
243+
244+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
245+
@inferred getter(prob)
246+
@inferred Vector{<:Real} prob.f(prob.u0, prob.p, prob.tspan[1])
247+
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
248+
@test sol.u[end][] 2.5
249+
end
250+
251+
@testset "Concrete function type" begin
252+
ts = 0.0:0.1:1.0
253+
interp = LinearInterpolation(ts .^ 2, ts; extrapolate = true)
254+
@variables x(t)
255+
@parameters (fn::typeof(interp))(..)
256+
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t)
257+
@test is_parameter(sys, fn)
258+
getter = getp(sys, fn)
259+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => interp])
260+
@inferred getter(prob)
261+
@inferred prob.f(prob.u0, prob.p, prob.tspan[1])
262+
@test_nowarn sol = solve(prob, Tsit5())
263+
@test_nowarn prob.ps[fn] = LinearInterpolation(ts .^ 3, ts; extrapolate = true)
264+
@test_nowarn sol = solve(prob)
265+
end
266+
end

0 commit comments

Comments
 (0)