Skip to content

Commit 5454834

Browse files
committed
refactor: relax type constraints to allow callable parameters in pdeps
1 parent 0916a02 commit 5454834

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

src/systems/abstractsystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3090,7 +3090,7 @@ function process_parameter_dependencies(pdeps, ps)
30903090
end
30913091
for p in pdeps]
30923092
end
3093-
lhss = BasicSymbolic[]
3093+
lhss = []
30943094
for p in pdeps
30953095
if !isparameter(p.lhs)
30963096
error("LHS of parameter dependency must be a single parameter. Found $(p.lhs).")
@@ -3101,6 +3101,7 @@ function process_parameter_dependencies(pdeps, ps)
31013101
end
31023102
push!(lhss, p.lhs)
31033103
end
3104+
lhss = map(identity, lhss)
31043105
pdeps = topsort_equations(pdeps, union(ps, lhss))
31053106
ps = filter(ps) do p
31063107
!any(isequal(p), lhss)

src/systems/index_cache.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ const UnknownIndexMap = Dict{
3939
const TunableIndexMap = Dict{BasicSymbolic,
4040
Union{Int, UnitRange{Int}, Base.ReshapedArray{Int, N, UnitRange{Int}} where {N}}}
4141

42-
struct IndexCache
42+
struct IndexCache{D}
4343
unknown_idx::UnknownIndexMap
4444
# sym => (bufferidx, idx_in_buffer)
4545
discrete_idx::Dict{BasicSymbolic, DiscreteIndex}
@@ -49,7 +49,7 @@ struct IndexCache
4949
constant_idx::ParamIndexMap
5050
nonnumeric_idx::NonnumericMap
5151
observed_syms::Set{BasicSymbolic}
52-
dependent_pars::Set{BasicSymbolic}
52+
dependent_pars::Set{D}
5353
discrete_buffer_sizes::Vector{Vector{BufferTemplate}}
5454
tunable_buffer_size::BufferTemplate
5555
constant_buffer_sizes::Vector{BufferTemplate}
@@ -275,7 +275,15 @@ function IndexCache(sys::AbstractSystem)
275275
end
276276
end
277277

278-
dependent_pars = Set{BasicSymbolic}()
278+
pdeps = parameter_dependencies(sys)
279+
280+
D = if isempty(pdeps)
281+
BasicSymbolic
282+
else
283+
mapreduce(typeof, promote_type, getproperty.(pdeps, :lhs))
284+
end
285+
dependent_pars = Set{D}()
286+
279287
for eq in parameter_dependencies(sys)
280288
sym = eq.lhs
281289
ttsym = default_toterm(sym)
@@ -289,7 +297,7 @@ function IndexCache(sys::AbstractSystem)
289297
end
290298
end
291299

292-
return IndexCache(
300+
return IndexCache{D}(
293301
unk_idxs,
294302
disc_idxs,
295303
callback_to_clocks,

0 commit comments

Comments
 (0)