@@ -422,28 +422,23 @@ function SII.observed(nw::Network, snis)
422422 isscalar = snis isa SymbolicIndex
423423 _snis = isscalar ? (snis,) : snis
424424
425- # mapping i -> index in state
426- stateidx = Dict {Int, Int} ()
427- # mapping i -> index in p
428- paraidx = Dict {Int, Int} ()
429- # mapping i -> index in output
430- outidx = Dict {Int, Int} ()
431- # mapping i -> index in aggbuf
432- aggidx = Dict {Int, Int} ()
425+ # mapping i -> (:u, j) / (:p, j) / (:out, j) / (:agg, j)
426+ arraymapping = Dict {Int, Tuple{Symbol, Int}} ()
433427 # mapping i -> f(fullstate, p, t) (component observables)
434428 obsfuns = Dict {Int, Function} ()
429+
435430 for (i, sni) in enumerate (_snis)
436431 if SII. is_variable (nw, sni)
437- stateidx [i] = SII. variable_index (nw, sni)
432+ arraymapping [i] = ( :u , SII. variable_index (nw, sni) )
438433 elseif SII. is_parameter (nw, sni)
439- paraidx [i] = SII. parameter_index (nw, sni)
434+ arraymapping [i] = ( :p , SII. parameter_index (nw, sni) )
440435 else
441436 cf = getcomp (nw, sni)
442437
443438 @argcheck sni. subidx isa Symbol " Observed musst be referenced by symbol, got $sni "
444439 if (idx= findfirst (isequal (sni. subidx), outsym_flat (cf))) != nothing # output
445440 _range = getcompoutrange (nw, sni)
446- outidx [i] = _range[idx]
441+ arraymapping [i] = ( :out , _range[idx])
447442 elseif (idx= findfirst (isequal (sni. subidx), obssym (cf))) != nothing # found in observed
448443 _obsf = _get_observed_f (nw, cf, resolvecompidx (nw, sni))
449444 obsfuns[i] = let obsidx = idx # otherwise $idx is boxed everywhere in function
@@ -452,13 +447,13 @@ function SII.observed(nw::Network, snis)
452447 elseif hasinsym (cf) && sni. subidx ∈ insym_all (cf) # found in input
453448 if sni isa SymbolicVertexIndex
454449 idx = findfirst (isequal (sni. subidx), insym_all (cf))
455- aggidx [i] = nw. im. v_aggr[resolvecompidx (nw, sni)][idx]
450+ arraymapping [i] = ( :agg , nw. im. v_aggr[resolvecompidx (nw, sni)][idx])
456451 elseif sni isa SymbolicEdgeIndex
457452 edge = nw. im. edgevec[resolvecompidx (nw, sni)]
458453 if (idx = findfirst (isequal (sni. subidx), insym (cf). src)) != nothing
459- outidx [i] = nw. im. v_out[edge. src][idx]
454+ arraymapping [i] = ( :out , nw. im. v_out[edge. src][idx])
460455 elseif (idx = findfirst (isequal (sni. subidx), insym (cf). dst)) != nothing
461- outidx [i] = nw. im. v_out[edge. dst][idx]
456+ arraymapping [i] = ( :out , nw. im. v_out[edge. dst][idx])
462457 else
463458 error ()
464459 end
@@ -470,52 +465,47 @@ function SII.observed(nw::Network, snis)
470465 end
471466 end
472467 end
473- initbufs = ! isempty (outidx) || ! isempty (aggidx ) || ! isempty (obsfuns)
468+ needsbuf = any (m -> m[ 1 ] ∈ ( :out , :agg ), arraymapping ) || ! isempty (obsfuns)
474469
475470 if isscalar
476471 (u, p, t) -> begin
477- outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs)
478- if ! isempty (stateidx)
479- idx = only (stateidx). second
480- u[idx]
481- elseif ! isempty (paraidx)
482- idx = only (paraidx). second
483- p[idx]
484- elseif ! isempty (outidx)
485- idx = only (outidx). second
486- outbuf[idx]
487- elseif ! isempty (aggidx)
488- idx = only (aggidx). second
489- aggbuf[idx]
472+ if needsbuf
473+ outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
474+ end
475+ if ! isempty (arraymapping)
476+ type, idx = only (arraymapping)
477+ type == :u && return u[idx]
478+ type == :p && return p[idx]
479+ type == :out && return outbuf[idx]
480+ type == :agg && return aggbuf[idx]
490481 else
491482 obsf = only (obsfuns). second
492- obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
483+ return obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
493484 end
494485 end
495486 else
496487 # make tuple to have concretely typed obsf
497488 obsfunstup = zip (keys (obsfuns), values (obsfuns)) |> Tuple
498489 (u, p, t, out= similar (u, length (_snis))) -> begin
499- if any ( ! isempty, (outidx, aggidx, obsfuns))
500- outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs)
490+ if needsbuf
491+ outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
501492 end
502493
503- for (i, statei) in stateidx
504- out[i] = u[statei]
505- end
506- for (i, parai) in paraidx
507- out[i] = p[parai]
508- end
509- for (i, outi) in outidx
510- out[i] = outbuf[outi]
511- end
512- for (i, aggi) in aggidx
513- out[i] = aggbuf[aggi]
494+ for (i, (type, idx)) in arraymapping
495+ if type == :u
496+ out[i] = u[idx]
497+ elseif type == :p
498+ out[i] = p[idx]
499+ elseif type == :out
500+ out[i] = outbuf[idx]
501+ elseif type == :agg
502+ out[i] = aggbuf[idx]
503+ end
514504 end
515505 for (i, obsf) in obsfunstup
516506 out[i] = obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
517507 end
518- out
508+ return out
519509 end
520510 end
521511end
0 commit comments