Skip to content

Commit 0136759

Browse files
fixup! feat: support callable parameters
1 parent 6d98bff commit 0136759

File tree

4 files changed

+16
-7
lines changed

4 files changed

+16
-7
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
2424
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
2525
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
2626
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
27+
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
2728
FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf"
2829
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
2930
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -91,6 +92,7 @@ ExprTools = "0.1.10"
9192
Expronicon = "0.8"
9293
FindFirstFunctions = "1"
9394
ForwardDiff = "0.10.3"
95+
FunctionWrappers = "1.1"
9496
FunctionWrappersWrappers = "0.1"
9597
Graphs = "1.5.2"
9698
InteractiveUtils = "1"

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ using Base: RefValue
3838
using Combinatorics
3939
import Distributions
4040
import FunctionWrappersWrappers
41+
import FunctionWrappers: FunctionWrapper
4142
using URIs: URI
4243
using SciMLStructures
4344
using Compat

src/systems/index_cache.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,8 @@ function IndexCache(sys::AbstractSystem)
109109
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
110110
nonnumeric_buffers = Dict{Any, Set{Union{BasicSymbolic, CallWithMetadata}}}()
111111

112-
function insert_by_type!(buffers::Dict{Any, S}, sym) where {S}
112+
function insert_by_type!(buffers::Dict{Any, S}, sym, ctype) where {S}
113113
sym = unwrap(sym)
114-
ctype = symtype(sym)
115114
buf = get!(buffers, ctype, S())
116115
push!(buf, sym)
117116
end
@@ -144,7 +143,7 @@ function IndexCache(sys::AbstractSystem)
144143
clocks = get!(() -> Set{Int}(), disc_param_callbacks, sym)
145144
push!(clocks, i)
146145
else
147-
insert_by_type!(constant_buffers, sym)
146+
insert_by_type!(constant_buffers, sym, symtype(sym))
148147
end
149148
end
150149
end
@@ -199,6 +198,9 @@ function IndexCache(sys::AbstractSystem)
199198
for p in parameters(sys)
200199
p = unwrap(p)
201200
ctype = symtype(p)
201+
if ctype <: FnType
202+
ctype = fntype_to_function_type(ctype)
203+
end
202204
haskey(disc_idxs, p) && continue
203205
haskey(constant_buffers, ctype) && p in constant_buffers[ctype] && continue
204206
insert_by_type!(
@@ -214,7 +216,8 @@ function IndexCache(sys::AbstractSystem)
214216
else
215217
nonnumeric_buffers
216218
end,
217-
p
219+
p,
220+
ctype,
218221
)
219222
end
220223

@@ -485,3 +488,6 @@ function get_buffer_template(ic::IndexCache, pidx::ParameterIndex)
485488
error("Unhandled portion $portion")
486489
end
487490
end
491+
492+
fntype_to_function_type(::Type{FnType{A, R, T}}) where {A, R, T} = T
493+
fntype_to_function_type(::Type{FnType{A, R, Nothing}}) where {A, R} = FunctionWrapper{R, A}

src/systems/parameter_buffer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ symconvert(::Type{Symbolics.Struct{T}}, x) where {T} = convert(T, x)
22
symconvert(::Type{T}, x) where {T} = convert(T, x)
33
symconvert(::Type{Real}, x::Integer) = convert(Float64, x)
44
symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x))
5-
function symconvert(::Type{T}, x) where {T <: FnType}
6-
isempty(methods(x)) ? error("Expected value $x to be a callable") : x
7-
end
85

96
struct MTKParameters{T, D, C, N}
107
tunable::T
@@ -155,6 +152,9 @@ function MTKParameters(
155152
if symbolic_type(val) !== NotSymbolic()
156153
error("Could not evaluate value of parameter $sym. Missing values for variables in expression $val.")
157154
end
155+
if ctype <: FnType
156+
ctype = fntype_to_function_type(ctype)
157+
end
158158
val = symconvert(ctype, val)
159159
done = set_value(sym, val)
160160
if !done && Symbolics.isarraysymbolic(sym)

0 commit comments

Comments
 (0)