Skip to content

Commit 75fd953

Browse files
feat: support callable parameters
1 parent 9611e18 commit 75fd953

File tree

5 files changed

+47
-13
lines changed

5 files changed

+47
-13
lines changed

src/ModelingToolkit.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ using Symbolics: _parse_vars, value, @derivatives, get_variables,
6363
VariableSource, getname, variable, Connection, connect,
6464
NAMESPACE_SEPARATOR, set_scalar_metadata, setdefaultval,
6565
initial_state, transition, activeState, entry, hasnode,
66-
ticksInState, timeInState, fixpoint_sub, fast_substitute
66+
ticksInState, timeInState, fixpoint_sub, fast_substitute,
67+
CallWithMetadata
6768
const NAMESPACE_SEPARATOR_SYMBOL = Symbol(NAMESPACE_SEPARATOR)
6869
import Symbolics: rename, get_variables!, _solve, hessian_sparsity,
6970
jacobian_sparsity, isaffine, islinear, _iszero, _isone,

src/systems/index_cache.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ struct DiscreteIndex
3232
end
3333

3434
const ParamIndexMap = Dict{BasicSymbolic, Tuple{Int, Int}}
35+
const NonnumericMap = Dict{
36+
Union{BasicSymbolic, Symbolics.CallWithMetadata}, Tuple{Int, Int}}
3537
const UnknownIndexMap = Dict{
3638
BasicSymbolic, Union{Int, UnitRange{Int}, AbstractArray{Int}}}
3739
const TunableIndexMap = Dict{BasicSymbolic,
@@ -45,20 +47,20 @@ struct IndexCache
4547
callback_to_clocks::Dict{Any, Vector{Int}}
4648
tunable_idx::TunableIndexMap
4749
constant_idx::ParamIndexMap
48-
nonnumeric_idx::ParamIndexMap
50+
nonnumeric_idx::NonnumericMap
4951
observed_syms::Set{BasicSymbolic}
5052
dependent_pars::Set{BasicSymbolic}
5153
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5254
tunable_buffer_size::BufferTemplate
5355
constant_buffer_sizes::Vector{BufferTemplate}
5456
nonnumeric_buffer_sizes::Vector{BufferTemplate}
55-
symbol_to_variable::Dict{Symbol, BasicSymbolic}
57+
symbol_to_variable::Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}
5658
end
5759

5860
function IndexCache(sys::AbstractSystem)
5961
unks = solved_unknowns(sys)
6062
unk_idxs = UnknownIndexMap()
61-
symbol_to_variable = Dict{Symbol, BasicSymbolic}()
63+
symbol_to_variable = Dict{Symbol, Union{BasicSymbolic, CallWithMetadata}}()
6264

6365
let idx = 1
6466
for sym in unks
@@ -105,12 +107,12 @@ function IndexCache(sys::AbstractSystem)
105107

106108
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
107109
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
108-
nonnumeric_buffers = Dict{Any, Set{BasicSymbolic}}()
110+
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
109111

110-
function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
112+
function insert_by_type!(buffers::Dict{Any, S}, sym) where {S}
111113
sym = unwrap(sym)
112114
ctype = symtype(sym)
113-
buf = get!(buffers, ctype, Set{BasicSymbolic}())
115+
buf = get!(buffers, ctype, S())
114116
push!(buf, sym)
115117
end
116118

@@ -216,8 +218,8 @@ function IndexCache(sys::AbstractSystem)
216218
)
217219
end
218220

219-
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
220-
idxs = ParamIndexMap()
221+
function get_buffer_sizes_and_idxs(T, buffers::Dict)
222+
idxs = T()
221223
buffer_sizes = BufferTemplate[]
222224
for (i, (T, buf)) in enumerate(buffers)
223225
for (j, p) in enumerate(buf)
@@ -229,13 +231,18 @@ function IndexCache(sys::AbstractSystem)
229231
idxs[rp] = (i, j)
230232
idxs[rttp] = (i, j)
231233
end
234+
if T <: Symbolics.FnType
235+
T = Any
236+
end
232237
push!(buffer_sizes, BufferTemplate(T, length(buf)))
233238
end
234239
return idxs, buffer_sizes
235240
end
236241

237-
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
238-
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
242+
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(
243+
ParamIndexMap, constant_buffers)
244+
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(
245+
NonnumericMap, nonnumeric_buffers)
239246

240247
tunable_idxs = TunableIndexMap()
241248
tunable_buffer_size = 0
@@ -397,7 +404,8 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
397404
for temp in ic.discrete_buffer_sizes)
398405
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
399406
for temp in ic.constant_buffer_sizes)
400-
nonnumeric_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
407+
nonnumeric_buf = Tuple(Union{BasicSymbolic, CallWithMetadata}[unwrap(variable(:DEF))
408+
for _ in 1:(temp.length)]
401409
for temp in ic.nonnumeric_buffer_sizes)
402410
for p in ps
403411
p = unwrap(p)

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
22
symconvert(::Type{T}, x) where {T} = convert(T, x)
33
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
44
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
5+
function symconvert(::Type{T}, x) where {T <: FnType}
6+
isempty(methods(x)) ? error("Expected value $x to be a callable") : x
7+
end
58

69
struct MTKParameters{T, D, C, N}
710
tunable::T

src/systems/systemstructure.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,8 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
661661
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
662662
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
663663
end
664-
ps = [setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
664+
ps = [sym isa CallWithMetadata ? sym :
665+
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
665666
for sym in get_ps(sys)]
666667
@set! sys.ps = ps
667668
else

test/split_parameters.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,24 @@ S = get_sensitivity(closed_loop, :u)
219219
@test ps[ParameterIndex(Tunable(), 1:8)] == collect(1.0:8.0) .+ 0.5
220220
@test ps[ParameterIndex(Discrete(), (2, 1, 2, 2))] == 5
221221
end
222+
223+
@testset "Callable parameters" begin
224+
_f1(x) = 2x
225+
struct Foo end
226+
(::Foo)(x) = 3x
227+
@variables x(t)
228+
@parameters fn(..) = _f1
229+
@mtkbuild sys = ODESystem(D(x) ~ fn(x), t, [x], [fn])
230+
@test is_parameter(sys, fn)
231+
@test ModelingToolkit.defaults(sys)[fn] == _f1
232+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0))
233+
@test_broken @inferred prob.ps[fn]
234+
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
235+
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
236+
@test sol.u[end][] exp(2.0)
237+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [fn => Foo()])
238+
@test_broken @inferred prob.ps[fn]
239+
@test_broken @inferred prob.f(prob.u0, prob.p, prob.tspan[1])
240+
sol = solve(prob; abstol = 1e-10, reltol = 1e-10)
241+
@test sol.u[end][] exp(3.0)
242+
end

0 commit comments

Comments
 (0)