Skip to content

Commit 2ed1320

Browse files
committed
improve observed performance by combining dicts
1 parent 2cb9486 commit 2ed1320

File tree

1 file changed

+33
-43
lines changed

1 file changed

+33
-43
lines changed

src/symbolicindexing.jl

Lines changed: 33 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -422,28 +422,23 @@ function SII.observed(nw::Network, snis)
422422
isscalar = snis isa SymbolicIndex
423423
_snis = isscalar ? (snis,) : snis
424424

425-
# mapping i -> index in state
426-
stateidx = Dict{Int, Int}()
427-
# mapping i -> index in p
428-
paraidx = Dict{Int, Int}()
429-
# mapping i -> index in output
430-
outidx = Dict{Int, Int}()
431-
# mapping i -> index in aggbuf
432-
aggidx = Dict{Int, Int}()
425+
# mapping i -> (:u, j) / (:p, j) / (:out, j) / (:agg, j)
426+
arraymapping = Dict{Int, Tuple{Symbol, Int}}()
433427
# mapping i -> f(fullstate, p, t) (component observables)
434428
obsfuns = Dict{Int, Function}()
429+
435430
for (i, sni) in enumerate(_snis)
436431
if SII.is_variable(nw, sni)
437-
stateidx[i] = SII.variable_index(nw, sni)
432+
arraymapping[i] = (:u, SII.variable_index(nw, sni))
438433
elseif SII.is_parameter(nw, sni)
439-
paraidx[i] = SII.parameter_index(nw, sni)
434+
arraymapping[i] = (:p, SII.parameter_index(nw, sni))
440435
else
441436
cf = getcomp(nw, sni)
442437

443438
@argcheck sni.subidx isa Symbol "Observed musst be referenced by symbol, got $sni"
444439
if (idx=findfirst(isequal(sni.subidx), outsym_flat(cf))) != nothing # output
445440
_range = getcompoutrange(nw, sni)
446-
outidx[i] = _range[idx]
441+
arraymapping[i] = (:out, _range[idx])
447442
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
448443
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
449444
obsfuns[i] = let obsidx = idx # otherwise $idx is boxed everywhere in function
@@ -452,13 +447,13 @@ function SII.observed(nw::Network, snis)
452447
elseif hasinsym(cf) && sni.subidx insym_all(cf) # found in input
453448
if sni isa SymbolicVertexIndex
454449
idx = findfirst(isequal(sni.subidx), insym_all(cf))
455-
aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx]
450+
arraymapping[i] = (:agg, nw.im.v_aggr[resolvecompidx(nw, sni)][idx])
456451
elseif sni isa SymbolicEdgeIndex
457452
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
458453
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
459-
outidx[i] = nw.im.v_out[edge.src][idx]
454+
arraymapping[i] = (:out, nw.im.v_out[edge.src][idx])
460455
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
461-
outidx[i] = nw.im.v_out[edge.dst][idx]
456+
arraymapping[i] = (:out, nw.im.v_out[edge.dst][idx])
462457
else
463458
error()
464459
end
@@ -470,52 +465,47 @@ function SII.observed(nw::Network, snis)
470465
end
471466
end
472467
end
473-
initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns)
468+
needsbuf = any(m -> m[1] (:out, :agg), arraymapping) || !isempty(obsfuns)
474469

475470
if isscalar
476471
(u, p, t) -> begin
477-
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs)
478-
if !isempty(stateidx)
479-
idx = only(stateidx).second
480-
u[idx]
481-
elseif !isempty(paraidx)
482-
idx = only(paraidx).second
483-
p[idx]
484-
elseif !isempty(outidx)
485-
idx = only(outidx).second
486-
outbuf[idx]
487-
elseif !isempty(aggidx)
488-
idx = only(aggidx).second
489-
aggbuf[idx]
472+
if needsbuf
473+
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs=true)
474+
end
475+
if !isempty(arraymapping)
476+
type, idx = only(arraymapping)
477+
type == :u && return u[idx]
478+
type == :p && return p[idx]
479+
type == :out && return outbuf[idx]
480+
type == :agg && return aggbuf[idx]
490481
else
491482
obsf = only(obsfuns).second
492-
obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
483+
return obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
493484
end
494485
end
495486
else
496487
# make tuple to have concretely typed obsf
497488
obsfunstup = zip(keys(obsfuns), values(obsfuns)) |> Tuple
498489
(u, p, t, out=similar(u, length(_snis))) -> begin
499-
if any(!isempty, (outidx, aggidx, obsfuns))
500-
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs)
490+
if needsbuf
491+
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs=true)
501492
end
502493

503-
for (i, statei) in stateidx
504-
out[i] = u[statei]
505-
end
506-
for (i, parai) in paraidx
507-
out[i] = p[parai]
508-
end
509-
for (i, outi) in outidx
510-
out[i] = outbuf[outi]
511-
end
512-
for (i, aggi) in aggidx
513-
out[i] = aggbuf[aggi]
494+
for (i, (type, idx)) in arraymapping
495+
if type == :u
496+
out[i] = u[idx]
497+
elseif type == :p
498+
out[i] = p[idx]
499+
elseif type == :out
500+
out[i] = outbuf[idx]
501+
elseif type == :agg
502+
out[i] = aggbuf[idx]
503+
end
514504
end
515505
for (i, obsf) in obsfunstup
516506
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
517507
end
518-
out
508+
return out
519509
end
520510
end
521511
end

0 commit comments

Comments
 (0)