Skip to content

Commit 6c912e6

Browse files
committed
further improve performance of observed calls
1 parent 2ed1320 commit 6c912e6

File tree

2 files changed

+43
-42
lines changed

2 files changed

+43
-42
lines changed

src/symbolicindexing.jl

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ function observed_symbols(nw::Network)
412412
return syms
413413
end
414414

415+
const U_TYPE = 1
416+
const P_TYPE = 2
417+
const OUT_TYPE = 3
418+
const AGG_TYPE = 4
419+
const OBS_TYPE = 5
415420
function SII.observed(nw::Network, snis)
416421
if (snis isa AbstractVector || snis isa Tuple) && any(sni -> sni isa ObservableExpression, snis)
417422
throw(ArgumentError("Cannot mix normal symbolic indices with @obsex currently!"))
@@ -422,38 +427,40 @@ function SII.observed(nw::Network, snis)
422427
isscalar = snis isa SymbolicIndex
423428
_snis = isscalar ? (snis,) : snis
424429

425-
# mapping i -> (:u, j) / (:p, j) / (:out, j) / (:agg, j)
426-
arraymapping = Dict{Int, Tuple{Symbol, Int}}()
427-
# mapping i -> f(fullstate, p, t) (component observables)
428-
obsfuns = Dict{Int, Function}()
430+
# mapping i -> (U_TYPE, j) / (P_TYPE, j) / (OUT_TYPE, j) / (AGG_TYPE, j) / (OBS_TYPE, j in obsfuns)
431+
arraymapping = Vector{Tuple{Int, Int}}(undef, length(_snis))
432+
# vector of obs functions
433+
obsfuns = Vector{Function}()
429434

430435
for (i, sni) in enumerate(_snis)
431436
if SII.is_variable(nw, sni)
432-
arraymapping[i] = (:u, SII.variable_index(nw, sni))
437+
arraymapping[i] = (U_TYPE, SII.variable_index(nw, sni))
433438
elseif SII.is_parameter(nw, sni)
434-
arraymapping[i] = (:p, SII.parameter_index(nw, sni))
439+
arraymapping[i] = (P_TYPE, SII.parameter_index(nw, sni))
435440
else
436441
cf = getcomp(nw, sni)
437442

438443
@argcheck sni.subidx isa Symbol "Observed musst be referenced by symbol, got $sni"
439444
if (idx=findfirst(isequal(sni.subidx), outsym_flat(cf))) != nothing # output
440445
_range = getcompoutrange(nw, sni)
441-
arraymapping[i] = (:out, _range[idx])
446+
arraymapping[i] = (OUT_TYPE, _range[idx])
442447
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
443448
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
444-
obsfuns[i] = let obsidx = idx # otherwise $idx is boxed everywhere in function
449+
_newobsfun = let obsidx = idx # otherwise $idx is boxed everywhere in function
445450
(u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[obsidx]
446451
end
452+
push!(obsfuns, _newobsfun)
453+
arraymapping[i] = (OBS_TYPE, length(obsfuns))
447454
elseif hasinsym(cf) && sni.subidx insym_all(cf) # found in input
448455
if sni isa SymbolicVertexIndex
449456
idx = findfirst(isequal(sni.subidx), insym_all(cf))
450-
arraymapping[i] = (:agg, nw.im.v_aggr[resolvecompidx(nw, sni)][idx])
457+
arraymapping[i] = (AGG_TYPE, nw.im.v_aggr[resolvecompidx(nw, sni)][idx])
451458
elseif sni isa SymbolicEdgeIndex
452459
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
453460
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
454-
arraymapping[i] = (:out, nw.im.v_out[edge.src][idx])
461+
arraymapping[i] = (OUT_TYPE, nw.im.v_out[edge.src][idx])
455462
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
456-
arraymapping[i] = (:out, nw.im.v_out[edge.dst][idx])
463+
arraymapping[i] = (OUT_TYPE, nw.im.v_out[edge.dst][idx])
457464
else
458465
error()
459466
end
@@ -465,46 +472,40 @@ function SII.observed(nw::Network, snis)
465472
end
466473
end
467474
end
468-
needsbuf = any(m -> m[1] (:out, :agg), arraymapping) || !isempty(obsfuns)
475+
needsbuf = any(m -> m[1] (OUT_TYPE, AGG_TYPE, OBS_TYPE), arraymapping)
476+
obsfunstup = Tuple(obsfuns) # make obsfuns concretely typed
469477

470478
if isscalar
471479
(u, p, t) -> begin
472480
if needsbuf
473481
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs=true)
474482
end
475-
if !isempty(arraymapping)
476-
type, idx = only(arraymapping)
477-
type == :u && return u[idx]
478-
type == :p && return p[idx]
479-
type == :out && return outbuf[idx]
480-
type == :agg && return aggbuf[idx]
481-
else
482-
obsf = only(obsfuns).second
483-
return obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
484-
end
483+
type, idx = only(arraymapping)
484+
type == U_TYPE && return u[idx]
485+
type == P_TYPE && return p[idx]
486+
type == OUT_TYPE && return outbuf[idx]
487+
type == AGG_TYPE && return aggbuf[idx]
488+
type == OBS_TYPE && return only(obsfunstup)(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
485489
end
486490
else
487-
# make tuple to have concretely typed obsf
488-
obsfunstup = zip(keys(obsfuns), values(obsfuns)) |> Tuple
489491
(u, p, t, out=similar(u, length(_snis))) -> begin
490492
if needsbuf
491493
outbuf, aggbuf, extbuf = get_buffers(nw, u, p, t; initbufs=true)
492494
end
493495

494-
for (i, (type, idx)) in arraymapping
495-
if type == :u
496+
for (i, (type, idx)) in pairs(arraymapping)
497+
if type == U_TYPE
496498
out[i] = u[idx]
497-
elseif type == :p
499+
elseif type == P_TYPE
498500
out[i] = p[idx]
499-
elseif type == :out
501+
elseif type == OUT_TYPE
500502
out[i] = outbuf[idx]
501-
elseif type == :agg
503+
elseif type == AGG_TYPE
502504
out[i] = aggbuf[idx]
505+
elseif type == OBS_TYPE
506+
out[i] = obsfunstup[idx](u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
503507
end
504508
end
505-
for (i, obsf) in obsfunstup
506-
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
507-
end
508509
return out
509510
end
510511
end

test/symbolicindexing_test.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ for idx in idxtypes
308308
println(idx, " => ", b.allocs, " allocations")
309309
end
310310
if VERSION v"1.11"
311-
@test b.allocs <= 13
311+
@test b.allocs <= 7
312312
end
313313
end
314314

@@ -319,7 +319,7 @@ for idx in idxtypes
319319
getter = SII.getu(s, _idx)
320320
b = @b $(SII.getu)($s, $_idx)
321321
if b.allocs != 0
322-
println(idx, "\t=> ", b.allocs, " allocations to generate getter")
322+
println(rpad(idx,21), "=> ", b.allocs, " allocations to generate getter")
323323
end
324324
b = @b $getter($s)
325325
v = getter(s)
@@ -525,22 +525,22 @@ end
525525
idxs1 = [VIndex(1,), VIndex(2, :Pdamping), EIndex(1,:P), VIndex(2,:P)]
526526
idxs2 = [VIndex(1,), VIndex(2,)]
527527
# full call
528-
# @b $s[$idxs1] # 134 106
529-
# @b $s[$idxs2] # 31 31
528+
# @b $s[$idxs1] # 134 106 94
529+
# @b $s[$idxs2] # 31 31 34
530530

531531
# scalar call
532-
# @b $s[$(VIndex(2,:Pdamping))]
532+
# @b $s[$(VIndex(2,:Pdamping))] # 28
533533

534534
# @b SII.observed($nw, $(VIndex(2,:Pdamping))) # 15
535-
# @b SII.observed($nw, $(VIndex(2,:θ))) # 7
535+
# @b SII.observed($nw, $(VIndex(2,:θ))) # 7 5
536536

537-
b = @b SII.observed($nw, $idxs1) # 69 36 42
537+
b = @b SII.observed($nw, $idxs1) # 69 36 42 30
538538
if VERSION v"1.11"
539-
@test b.allocs <= 42
539+
@test b.allocs <= 30
540540
end
541-
b = @b SII.observed($nw, $idxs2) # 12 7 10
541+
b = @b SII.observed($nw, $idxs2) # 12 7 10 5
542542
if VERSION v"1.11"
543-
@test b.allocs <= 10
543+
@test b.allocs <= 5
544544
end
545545

546546
obsf1 = SII.observed(nw, idxs1)

0 commit comments

Comments
 (0)