@@ -422,28 +422,23 @@ function SII.observed(nw::Network, snis)
422
422
isscalar = snis isa SymbolicIndex
423
423
_snis = isscalar ? (snis,) : snis
424
424
425
- # mapping i -> index in state
426
- stateidx = Dict {Int, Int} ()
427
- # mapping i -> index in p
428
- paraidx = Dict {Int, Int} ()
429
- # mapping i -> index in output
430
- outidx = Dict {Int, Int} ()
431
- # mapping i -> index in aggbuf
432
- aggidx = Dict {Int, Int} ()
425
+ # mapping i -> (:u, j) / (:p, j) / (:out, j) / (:agg, j)
426
+ arraymapping = Dict {Int, Tuple{Symbol, Int}} ()
433
427
# mapping i -> f(fullstate, p, t) (component observables)
434
428
obsfuns = Dict {Int, Function} ()
429
+
435
430
for (i, sni) in enumerate (_snis)
436
431
if SII. is_variable (nw, sni)
437
- stateidx [i] = SII. variable_index (nw, sni)
432
+ arraymapping [i] = ( :u , SII. variable_index (nw, sni) )
438
433
elseif SII. is_parameter (nw, sni)
439
- paraidx [i] = SII. parameter_index (nw, sni)
434
+ arraymapping [i] = ( :p , SII. parameter_index (nw, sni) )
440
435
else
441
436
cf = getcomp (nw, sni)
442
437
443
438
@argcheck sni. subidx isa Symbol " Observed musst be referenced by symbol, got $sni "
444
439
if (idx= findfirst (isequal (sni. subidx), outsym_flat (cf))) != nothing # output
445
440
_range = getcompoutrange (nw, sni)
446
- outidx [i] = _range[idx]
441
+ arraymapping [i] = ( :out , _range[idx])
447
442
elseif (idx= findfirst (isequal (sni. subidx), obssym (cf))) != nothing # found in observed
448
443
_obsf = _get_observed_f (nw, cf, resolvecompidx (nw, sni))
449
444
obsfuns[i] = let obsidx = idx # otherwise $idx is boxed everywhere in function
@@ -452,13 +447,13 @@ function SII.observed(nw::Network, snis)
452
447
elseif hasinsym (cf) && sni. subidx ∈ insym_all (cf) # found in input
453
448
if sni isa SymbolicVertexIndex
454
449
idx = findfirst (isequal (sni. subidx), insym_all (cf))
455
- aggidx [i] = nw. im. v_aggr[resolvecompidx (nw, sni)][idx]
450
+ arraymapping [i] = ( :agg , nw. im. v_aggr[resolvecompidx (nw, sni)][idx])
456
451
elseif sni isa SymbolicEdgeIndex
457
452
edge = nw. im. edgevec[resolvecompidx (nw, sni)]
458
453
if (idx = findfirst (isequal (sni. subidx), insym (cf). src)) != nothing
459
- outidx [i] = nw. im. v_out[edge. src][idx]
454
+ arraymapping [i] = ( :out , nw. im. v_out[edge. src][idx])
460
455
elseif (idx = findfirst (isequal (sni. subidx), insym (cf). dst)) != nothing
461
- outidx [i] = nw. im. v_out[edge. dst][idx]
456
+ arraymapping [i] = ( :out , nw. im. v_out[edge. dst][idx])
462
457
else
463
458
error ()
464
459
end
@@ -470,52 +465,47 @@ function SII.observed(nw::Network, snis)
470
465
end
471
466
end
472
467
end
473
- initbufs = ! isempty (outidx) || ! isempty (aggidx ) || ! isempty (obsfuns)
468
+ needsbuf = any (m -> m[ 1 ] ∈ ( :out , :agg ), arraymapping ) || ! isempty (obsfuns)
474
469
475
470
if isscalar
476
471
(u, p, t) -> begin
477
- outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs)
478
- if ! isempty (stateidx)
479
- idx = only (stateidx). second
480
- u[idx]
481
- elseif ! isempty (paraidx)
482
- idx = only (paraidx). second
483
- p[idx]
484
- elseif ! isempty (outidx)
485
- idx = only (outidx). second
486
- outbuf[idx]
487
- elseif ! isempty (aggidx)
488
- idx = only (aggidx). second
489
- aggbuf[idx]
472
+ if needsbuf
473
+ outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
474
+ end
475
+ if ! isempty (arraymapping)
476
+ type, idx = only (arraymapping)
477
+ type == :u && return u[idx]
478
+ type == :p && return p[idx]
479
+ type == :out && return outbuf[idx]
480
+ type == :agg && return aggbuf[idx]
490
481
else
491
482
obsf = only (obsfuns). second
492
- obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
483
+ return obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
493
484
end
494
485
end
495
486
else
496
487
# make tuple to have concretely typed obsf
497
488
obsfunstup = zip (keys (obsfuns), values (obsfuns)) |> Tuple
498
489
(u, p, t, out= similar (u, length (_snis))) -> begin
499
- if any ( ! isempty, (outidx, aggidx, obsfuns))
500
- outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs)
490
+ if needsbuf
491
+ outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
501
492
end
502
493
503
- for (i, statei) in stateidx
504
- out[i] = u[statei]
505
- end
506
- for (i, parai) in paraidx
507
- out[i] = p[parai]
508
- end
509
- for (i, outi) in outidx
510
- out[i] = outbuf[outi]
511
- end
512
- for (i, aggi) in aggidx
513
- out[i] = aggbuf[aggi]
494
+ for (i, (type, idx)) in arraymapping
495
+ if type == :u
496
+ out[i] = u[idx]
497
+ elseif type == :p
498
+ out[i] = p[idx]
499
+ elseif type == :out
500
+ out[i] = outbuf[idx]
501
+ elseif type == :agg
502
+ out[i] = aggbuf[idx]
503
+ end
514
504
end
515
505
for (i, obsf) in obsfunstup
516
506
out[i] = obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
517
507
end
518
- out
508
+ return out
519
509
end
520
510
end
521
511
end
0 commit comments