@@ -412,6 +412,11 @@ function observed_symbols(nw::Network)
412
412
return syms
413
413
end
414
414
415
+ const U_TYPE = 1
416
+ const P_TYPE = 2
417
+ const OUT_TYPE = 3
418
+ const AGG_TYPE = 4
419
+ const OBS_TYPE = 5
415
420
function SII. observed (nw:: Network , snis)
416
421
if (snis isa AbstractVector || snis isa Tuple) && any (sni -> sni isa ObservableExpression, snis)
417
422
throw (ArgumentError (" Cannot mix normal symbolic indices with @obsex currently!" ))
@@ -422,38 +427,40 @@ function SII.observed(nw::Network, snis)
422
427
isscalar = snis isa SymbolicIndex
423
428
_snis = isscalar ? (snis,) : snis
424
429
425
- # mapping i -> (:u , j) / (:p , j) / (:out , j) / (:agg , j)
426
- arraymapping = Dict {Int, Tuple{Symbol , Int}} ()
427
- # mapping i -> f(fullstate, p, t) (component observables)
428
- obsfuns = Dict {Int, Function} ()
430
+ # mapping i -> (U_TYPE , j) / (P_TYPE , j) / (OUT_TYPE , j) / (AGG_TYPE , j) / (OBS_TYPE, j in obsfuns )
431
+ arraymapping = Vector { Tuple{Int , Int}} (undef, length (_snis) )
432
+ # vector of obs functions
433
+ obsfuns = Vector { Function} ()
429
434
430
435
for (i, sni) in enumerate (_snis)
431
436
if SII. is_variable (nw, sni)
432
- arraymapping[i] = (:u , SII. variable_index (nw, sni))
437
+ arraymapping[i] = (U_TYPE , SII. variable_index (nw, sni))
433
438
elseif SII. is_parameter (nw, sni)
434
- arraymapping[i] = (:p , SII. parameter_index (nw, sni))
439
+ arraymapping[i] = (P_TYPE , SII. parameter_index (nw, sni))
435
440
else
436
441
cf = getcomp (nw, sni)
437
442
438
443
@argcheck sni. subidx isa Symbol " Observed musst be referenced by symbol, got $sni "
439
444
if (idx= findfirst (isequal (sni. subidx), outsym_flat (cf))) != nothing # output
440
445
_range = getcompoutrange (nw, sni)
441
- arraymapping[i] = (:out , _range[idx])
446
+ arraymapping[i] = (OUT_TYPE , _range[idx])
442
447
elseif (idx= findfirst (isequal (sni. subidx), obssym (cf))) != nothing # found in observed
443
448
_obsf = _get_observed_f (nw, cf, resolvecompidx (nw, sni))
444
- obsfuns[i] = let obsidx = idx # otherwise $idx is boxed everywhere in function
449
+ _newobsfun = let obsidx = idx # otherwise $idx is boxed everywhere in function
445
450
(u, outbuf, aggbuf, extbuf, p, t) -> _obsf (u, outbuf, aggbuf, extbuf, p, t)[obsidx]
446
451
end
452
+ push! (obsfuns, _newobsfun)
453
+ arraymapping[i] = (OBS_TYPE, length (obsfuns))
447
454
elseif hasinsym (cf) && sni. subidx ∈ insym_all (cf) # found in input
448
455
if sni isa SymbolicVertexIndex
449
456
idx = findfirst (isequal (sni. subidx), insym_all (cf))
450
- arraymapping[i] = (:agg , nw. im. v_aggr[resolvecompidx (nw, sni)][idx])
457
+ arraymapping[i] = (AGG_TYPE , nw. im. v_aggr[resolvecompidx (nw, sni)][idx])
451
458
elseif sni isa SymbolicEdgeIndex
452
459
edge = nw. im. edgevec[resolvecompidx (nw, sni)]
453
460
if (idx = findfirst (isequal (sni. subidx), insym (cf). src)) != nothing
454
- arraymapping[i] = (:out , nw. im. v_out[edge. src][idx])
461
+ arraymapping[i] = (OUT_TYPE , nw. im. v_out[edge. src][idx])
455
462
elseif (idx = findfirst (isequal (sni. subidx), insym (cf). dst)) != nothing
456
- arraymapping[i] = (:out , nw. im. v_out[edge. dst][idx])
463
+ arraymapping[i] = (OUT_TYPE , nw. im. v_out[edge. dst][idx])
457
464
else
458
465
error ()
459
466
end
@@ -465,46 +472,40 @@ function SII.observed(nw::Network, snis)
465
472
end
466
473
end
467
474
end
468
- needsbuf = any (m -> m[1 ] ∈ (:out , :agg ), arraymapping) || ! isempty (obsfuns)
475
+ needsbuf = any (m -> m[1 ] ∈ (OUT_TYPE, AGG_TYPE, OBS_TYPE), arraymapping)
476
+ obsfunstup = Tuple (obsfuns) # make obsfuns concretely typed
469
477
470
478
if isscalar
471
479
(u, p, t) -> begin
472
480
if needsbuf
473
481
outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
474
482
end
475
- if ! isempty (arraymapping)
476
- type, idx = only (arraymapping)
477
- type == :u && return u[idx]
478
- type == :p && return p[idx]
479
- type == :out && return outbuf[idx]
480
- type == :agg && return aggbuf[idx]
481
- else
482
- obsf = only (obsfuns). second
483
- return obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
484
- end
483
+ type, idx = only (arraymapping)
484
+ type == U_TYPE && return u[idx]
485
+ type == P_TYPE && return p[idx]
486
+ type == OUT_TYPE && return outbuf[idx]
487
+ type == AGG_TYPE && return aggbuf[idx]
488
+ type == OBS_TYPE && return only (obsfunstup)(u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
485
489
end
486
490
else
487
- # make tuple to have concretely typed obsf
488
- obsfunstup = zip (keys (obsfuns), values (obsfuns)) |> Tuple
489
491
(u, p, t, out= similar (u, length (_snis))) -> begin
490
492
if needsbuf
491
493
outbuf, aggbuf, extbuf = get_buffers (nw, u, p, t; initbufs= true )
492
494
end
493
495
494
- for (i, (type, idx)) in arraymapping
495
- if type == :u
496
+ for (i, (type, idx)) in pairs ( arraymapping)
497
+ if type == U_TYPE
496
498
out[i] = u[idx]
497
- elseif type == :p
499
+ elseif type == P_TYPE
498
500
out[i] = p[idx]
499
- elseif type == :out
501
+ elseif type == OUT_TYPE
500
502
out[i] = outbuf[idx]
501
- elseif type == :agg
503
+ elseif type == AGG_TYPE
502
504
out[i] = aggbuf[idx]
505
+ elseif type == OBS_TYPE
506
+ out[i] = obsfunstup[idx](u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
503
507
end
504
508
end
505
- for (i, obsf) in obsfunstup
506
- out[i] = obsf (u, outbuf, aggbuf, extbuf, p, t):: eltype (u)
507
- end
508
509
return out
509
510
end
510
511
end
0 commit comments