Skip to content

Commit 77591ca

Browse files
refactor: store symbols instead of hashes in IndexMap
1 parent 2699582 commit 77591ca

File tree

3 files changed

+145
-122
lines changed

3 files changed

+145
-122
lines changed

src/systems/abstractsystem.jl

Lines changed: 39 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -347,25 +347,22 @@ end
347347

348348
#Treat the result as a vector of symbols always
349349
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
350-
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
351-
return unwrap(sym) in 1:length(variable_symbols(sys))
350+
sym = unwrap(sym)
351+
if sym isa Int # [x, 1] coerces 1 to a Num
352+
return sym in 1:length(variable_symbols(sys))
352353
end
353-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
354-
ic = get_index_cache(sys)
355-
h = getsymbolhash(sym)
356-
return haskey(ic.unknown_idx, h) ||
357-
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
358-
(istree(sym) && operation(sym) === getindex &&
359-
is_variable(sys, first(arguments(sym))))
360-
else
361-
return any(isequal(sym), variable_symbols(sys)) ||
362-
hasname(sym) && is_variable(sys, getname(sym))
354+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
355+
return is_variable(ic, sym) || istree(sym) && operation(sym) === getindex &&
356+
is_variable(ic, first(arguments(sym)))
363357
end
358+
return any(isequal(sym), variable_symbols(sys)) ||
359+
hasname(sym) && is_variable(sys, getname(sym))
364360
end
365361

366362
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
363+
sym = unwrap(sym)
367364
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
368-
return haskey(ic.unknown_idx, hash(sym))
365+
return is_variable(ic, sym)
369366
end
370367
return any(isequal(sym), getname.(variable_symbols(sys))) ||
371368
count('', string(sym)) == 1 &&
@@ -374,21 +371,19 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
374371
end
375372

376373
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
377-
if unwrap(sym) isa Int
378-
return unwrap(sym)
374+
sym = unwrap(sym)
375+
if sym isa Int
376+
return sym
379377
end
380-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
381-
ic = get_index_cache(sys)
382-
h = getsymbolhash(sym)
383-
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
384-
385-
h = getsymbolhash(default_toterm(sym))
386-
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
387-
sym = unwrap(sym)
388-
istree(sym) && operation(sym) === getindex || return nothing
389-
idx = variable_index(sys, first(arguments(sym)))
390-
idx === nothing && return nothing
391-
return idx[arguments(sym)[(begin + 1):end]...]
378+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
379+
return if (idx = variable_index(ic, sym)) !== nothing
380+
idx
381+
elseif istree(sym) && operation(sym) === getindex &&
382+
(idx = variable_index(ic, first(arguments(sym)))) !== nothing
383+
idx[arguments(sym)[begin + 1:end]...]
384+
else
385+
nothing
386+
end
392387
end
393388
idx = findfirst(isequal(sym), variable_symbols(sys))
394389
if idx === nothing && hasname(sym)
@@ -399,7 +394,7 @@ end
399394

400395
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
401396
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
402-
return get(ic.unknown_idx, hash(sym), nothing)
397+
return variable_index(ic, sym)
403398
end
404399
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
405400
if idx !== nothing
@@ -418,30 +413,21 @@ function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
418413
end
419414

420415
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
416+
sym = unwrap(sym)
417+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
418+
return is_parameter(ic, sym) || istree(sym) && operation(sym) === getindex &&
419+
is_parameter(ic, first(arguments(sym)))
420+
end
421421
if unwrap(sym) isa Int
422422
return unwrap(sym) in 1:length(parameter_symbols(sys))
423423
end
424-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
425-
ic = get_index_cache(sys)
426-
h = getsymbolhash(sym)
427-
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
428-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
429-
haskey(ic.nonnumeric_idx, h)
430-
true
431-
else
432-
h = getsymbolhash(default_toterm(sym))
433-
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
434-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
435-
haskey(ic.nonnumeric_idx, h)
436-
end
437-
end
438424
return any(isequal(sym), parameter_symbols(sys)) ||
439425
hasname(sym) && is_parameter(sys, getname(sym))
440426
end
441427

442428
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
443429
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
444-
return ParameterIndex(ic, sym) !== nothing
430+
return is_parameter(ic, sym)
445431
end
446432
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
447433
count('', string(sym)) == 1 &&
@@ -450,20 +436,21 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol
450436
end
451437

452438
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
453-
if unwrap(sym) isa Int
454-
return unwrap(sym)
455-
end
456-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
457-
ic = get_index_cache(sys)
458-
return if (idx = ParameterIndex(ic, sym)) !== nothing
459-
idx
460-
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
439+
sym = unwrap(sym)
440+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
441+
return if (idx = parameter_index(ic, sym)) !== nothing
461442
idx
443+
elseif istree(sym) && operation(sym) === getindex &&
444+
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
445+
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[begin+1:end]...))
462446
else
463447
nothing
464448
end
465449
end
466450

451+
if sym isa Int
452+
return sym
453+
end
467454
idx = findfirst(isequal(sym), parameter_symbols(sys))
468455
if idx === nothing && hasname(sym)
469456
idx = parameter_index(sys, getname(sym))
@@ -473,7 +460,7 @@ end
473460

474461
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
475462
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
476-
return ParameterIndex(ic, sym)
463+
return parameter_index(ic, sym)
477464
end
478465
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
479466
if idx !== nothing

src/systems/index_cache.jl

Lines changed: 92 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
abstract type SymbolHash end
2-
3-
function getsymbolhash(sym)
4-
sym = unwrap(sym)
5-
hasmetadata(sym, SymbolHash) ? getmetadata(sym, SymbolHash) : hash(sym)
6-
end
7-
81
struct BufferTemplate
92
type::DataType
103
length::Int
@@ -18,38 +11,38 @@ struct ParameterIndex{P, I}
1811
idx::I
1912
end
2013

21-
const IndexMap = Dict{UInt, Tuple{Int, Int}}
14+
const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
15+
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
2216

2317
struct IndexCache
24-
unknown_idx::Dict{UInt, Union{Int, UnitRange{Int}}}
25-
discrete_idx::IndexMap
26-
param_idx::IndexMap
27-
constant_idx::IndexMap
28-
dependent_idx::IndexMap
29-
nonnumeric_idx::IndexMap
18+
unknown_idx::UnknownIndexMap
19+
discrete_idx::ParamIndexMap
20+
tunable_idx::ParamIndexMap
21+
constant_idx::ParamIndexMap
22+
dependent_idx::ParamIndexMap
23+
nonnumeric_idx::ParamIndexMap
3024
discrete_buffer_sizes::Vector{BufferTemplate}
31-
param_buffer_sizes::Vector{BufferTemplate}
25+
tunable_buffer_sizes::Vector{BufferTemplate}
3226
constant_buffer_sizes::Vector{BufferTemplate}
3327
dependent_buffer_sizes::Vector{BufferTemplate}
3428
nonnumeric_buffer_sizes::Vector{BufferTemplate}
3529
end
3630

3731
function IndexCache(sys::AbstractSystem)
3832
unks = solved_unknowns(sys)
39-
unk_idxs = Dict{UInt, Union{Int, UnitRange{Int}}}()
33+
unk_idxs = UnknownIndexMap()
4034
let idx = 1
4135
for sym in unks
42-
h = getsymbolhash(sym)
36+
usym = unwrap(sym)
4337
sym_idx = if Symbolics.isarraysymbolic(sym)
4438
idx:(idx + length(sym) - 1)
4539
else
4640
idx
4741
end
48-
unk_idxs[h] = sym_idx
42+
unk_idxs[usym] = sym_idx
4943

5044
if hasname(sym)
51-
h = hash(getname(sym))
52-
unk_idxs[h] = sym_idx
45+
unk_idxs[getname(usym)] = sym_idx
5346
end
5447
idx += length(sym)
5548
end
@@ -120,17 +113,15 @@ function IndexCache(sys::AbstractSystem)
120113
end
121114

122115
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
123-
idxs = IndexMap()
116+
idxs = ParamIndexMap()
124117
buffer_sizes = BufferTemplate[]
125118
for (i, (T, buf)) in enumerate(buffers)
126119
for (j, p) in enumerate(buf)
127-
h = getsymbolhash(p)
128-
idxs[h] = (i, j)
129-
h = getsymbolhash(default_toterm(p))
130-
idxs[h] = (i, j)
120+
idxs[p] = (i, j)
121+
idxs[default_toterm(p)] = (i, j)
131122
if hasname(p)
132-
h = hash(getname(p))
133-
idxs[h] = (i, j)
123+
idxs[getname(p)] = (i, j)
124+
idxs[getname(default_toterm(p))] = (i, j)
134125
end
135126
end
136127
push!(buffer_sizes, BufferTemplate(T, length(buf)))
@@ -139,39 +130,87 @@ function IndexCache(sys::AbstractSystem)
139130
end
140131

141132
disc_idxs, discrete_buffer_sizes = get_buffer_sizes_and_idxs(disc_buffers)
142-
param_idxs, param_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
133+
tunable_idxs, tunable_buffer_sizes = get_buffer_sizes_and_idxs(tunable_buffers)
143134
const_idxs, const_buffer_sizes = get_buffer_sizes_and_idxs(constant_buffers)
144135
dependent_idxs, dependent_buffer_sizes = get_buffer_sizes_and_idxs(dependent_buffers)
145136
nonnumeric_idxs, nonnumeric_buffer_sizes = get_buffer_sizes_and_idxs(nonnumeric_buffers)
146137

147138
return IndexCache(
148139
unk_idxs,
149140
disc_idxs,
150-
param_idxs,
141+
tunable_idxs,
151142
const_idxs,
152143
dependent_idxs,
153144
nonnumeric_idxs,
154145
discrete_buffer_sizes,
155-
param_buffer_sizes,
146+
tunable_buffer_sizes,
156147
const_buffer_sizes,
157148
dependent_buffer_sizes,
158149
nonnumeric_buffer_sizes
159150
)
160151
end
161152

153+
function SymbolicIndexingInterface.is_variable(ic::IndexCache, sym)
154+
return check_index_map(ic.unknown_idx, sym) !== nothing
155+
end
156+
157+
function SymbolicIndexingInterface.variable_index(ic::IndexCache, sym)
158+
return check_index_map(ic.unknown_idx, sym)
159+
end
160+
161+
function SymbolicIndexingInterface.is_parameter(ic::IndexCache, sym)
162+
return check_index_map(ic.tunable_idx, sym) !== nothing ||
163+
check_index_map(ic.discrete_idx, sym) !== nothing ||
164+
check_index_map(ic.constant_idx, sym) !== nothing ||
165+
check_index_map(ic.nonnumeric_idx, sym) !== nothing ||
166+
check_index_map(ic.dependent_idx, sym) !== nothing
167+
end
168+
169+
function SymbolicIndexingInterface.parameter_index(ic::IndexCache, sym)
170+
return if (idx = check_index_map(ic.tunable_idx, sym)) !== nothing
171+
ParameterIndex(SciMLStructures.Tunable(), idx)
172+
elseif (idx = check_index_map(ic.discrete_idx, sym)) !== nothing
173+
ParameterIndex(SciMLStructures.Discrete(), idx)
174+
elseif (idx = check_index_map(ic.constant_idx, sym)) !== nothing
175+
ParameterIndex(SciMLStructures.Constants(), idx)
176+
elseif (idx = check_index_map(ic.nonnumeric_idx, sym)) !== nothing
177+
ParameterIndex(NONNUMERIC_PORTION, idx)
178+
elseif (idx = check_index_map(ic.dependent_idx, sym)) !== nothing
179+
ParameterIndex(DEPENDENT_PORTION, idx)
180+
else
181+
nothing
182+
end
183+
end
184+
185+
function check_index_map(idxmap, sym)
186+
if (idx = get(idxmap, sym, nothing)) !== nothing
187+
return idx
188+
elseif hasname(sym) && (idx = get(idxmap, getname(sym), nothing)) !== nothing
189+
return idx
190+
end
191+
dsym = default_toterm(sym)
192+
isequal(sym, dsym) && return nothing
193+
if (idx = get(idxmap, dsym, nothing)) !== nothing
194+
idx
195+
elseif hasname(dsym) && (idx = get(idxmap, getname(dsym), nothing)) !== nothing
196+
idx
197+
else
198+
nothing
199+
end
200+
end
201+
162202
function ParameterIndex(ic::IndexCache, p, sub_idx = ())
163203
p = unwrap(p)
164-
h = p isa Symbol ? hash(p) : getsymbolhash(p)
165-
return if haskey(ic.param_idx, h)
166-
ParameterIndex(SciMLStructures.Tunable(), (ic.param_idx[h]..., sub_idx...))
167-
elseif haskey(ic.discrete_idx, h)
168-
ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[h]..., sub_idx...))
169-
elseif haskey(ic.constant_idx, h)
170-
ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[h]..., sub_idx...))
171-
elseif haskey(ic.dependent_idx, h)
172-
ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[h]..., sub_idx...))
173-
elseif haskey(ic.nonnumeric_idx, h)
174-
ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[h]..., sub_idx...))
204+
return if haskey(ic.tunable_idx, p)
205+
ParameterIndex(SciMLStructures.Tunable(), (ic.tunable_idx[p]..., sub_idx...))
206+
elseif haskey(ic.discrete_idx, p)
207+
ParameterIndex(SciMLStructures.Discrete(), (ic.discrete_idx[p]..., sub_idx...))
208+
elseif haskey(ic.constant_idx, p)
209+
ParameterIndex(SciMLStructures.Constants(), (ic.constant_idx[p]..., sub_idx...))
210+
elseif haskey(ic.dependent_idx, p)
211+
ParameterIndex(DEPENDENT_PORTION, (ic.dependent_idx[p]..., sub_idx...))
212+
elseif haskey(ic.nonnumeric_idx, p)
213+
ParameterIndex(NONNUMERIC_PORTION, (ic.nonnumeric_idx[p]..., sub_idx...))
175214
elseif istree(p) && operation(p) === getindex
176215
_p, sub_idx... = arguments(p)
177216
ParameterIndex(ic, _p, sub_idx)
@@ -182,7 +221,7 @@ end
182221

183222
function discrete_linear_index(ic::IndexCache, idx::ParameterIndex)
184223
idx.portion isa SciMLStructures.Discrete || error("Discrete variable index expected")
185-
ind = sum(temp.length for temp in ic.param_buffer_sizes; init = 0)
224+
ind = sum(temp.length for temp in ic.tunable_buffer_sizes; init = 0)
186225
ind += sum(
187226
temp.length for temp in Iterators.take(ic.discrete_buffer_sizes, idx.idx[1] - 1);
188227
init = 0)
@@ -202,7 +241,7 @@ end
202241

203242
function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
204243
param_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
205-
for temp in ic.param_buffer_sizes)
244+
for temp in ic.tunable_buffer_sizes)
206245
disc_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
207246
for temp in ic.discrete_buffer_sizes)
208247
const_buf = Tuple(BasicSymbolic[unwrap(variable(:DEF)) for _ in 1:(temp.length)]
@@ -213,21 +252,20 @@ function reorder_parameters(ic::IndexCache, ps; drop_missing = false)
213252
for temp in ic.nonnumeric_buffer_sizes)
214253

215254
for p in ps
216-
h = getsymbolhash(p)
217-
if haskey(ic.discrete_idx, h)
218-
i, j = ic.discrete_idx[h]
255+
if haskey(ic.discrete_idx, p)
256+
i, j = ic.discrete_idx[p]
219257
disc_buf[i][j] = unwrap(p)
220-
elseif haskey(ic.param_idx, h)
221-
i, j = ic.param_idx[h]
258+
elseif haskey(ic.tunable_idx, p)
259+
i, j = ic.tunable_idx[p]
222260
param_buf[i][j] = unwrap(p)
223-
elseif haskey(ic.constant_idx, h)
224-
i, j = ic.constant_idx[h]
261+
elseif haskey(ic.constant_idx, p)
262+
i, j = ic.constant_idx[p]
225263
const_buf[i][j] = unwrap(p)
226-
elseif haskey(ic.dependent_idx, h)
227-
i, j = ic.dependent_idx[h]
264+
elseif haskey(ic.dependent_idx, p)
265+
i, j = ic.dependent_idx[p]
228266
dep_buf[i][j] = unwrap(p)
229-
elseif haskey(ic.nonnumeric_idx, h)
230-
i, j = ic.nonnumeric_idx[h]
267+
elseif haskey(ic.nonnumeric_idx, p)
268+
i, j = ic.nonnumeric_idx[p]
231269
nonnumeric_buf[i][j] = unwrap(p)
232270
else
233271
error("Invalid parameter $p")

0 commit comments

Comments
 (0)