Skip to content

Commit ec96855

Browse files
committed
allow observing of inputs
1 parent b4f95c7 commit ec96855

File tree

6 files changed

+65
-13
lines changed

6 files changed

+65
-13
lines changed

src/component_functions.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,9 @@ edges it returns a named tuple `(; src, dst)` with two symbol vectors.
487487
insym(c::VertexModel)::Vector{Symbol} = c.insym
488488
insym(c::EdgeModel)::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} = c.insym
489489

490+
insym_all(c::VertexModel) = c.insym
491+
insym_all(c::EdgeModel) = Iterators.flatten(values(c.insym))
492+
490493
"""
491494
indim(c::VertexModel)::Int
492495
indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int}
@@ -950,16 +953,16 @@ function _fill_defaults(T, @nospecialize(kwargs))
950953
####
951954
#### Cached outsymflat/outsymall
952955
####
953-
_outsym_flat = if T <: VertexModel
954-
outsym
955-
elseif T <: EdgeModel
956-
vcat(outsym.src, outsym.dst)
957-
else
958-
error()
959-
end
956+
_outsym_flat = flatten_sym(outsym)
960957
dict[:_outsym_flat] = _outsym_flat
958+
961959
dict[:_obssym_all] = setdiff(_outsym_flat, sym) obssym
962960

961+
if !isnothing(insym)
962+
insym_flat = flatten_sym(insym)
963+
dict[:_obssym_all] = dict[:_obssym_all] insym_flat
964+
end
965+
963966
####
964967
#### External Inputs
965968
####

src/symbolicindexing.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ function SII.observed(nw::Network, snis)
412412
stateidx = Dict{Int, Int}()
413413
# mapping i -> index in output
414414
outidx = Dict{Int, Int}()
415+
# mapping i -> index in aggbuf
416+
aggidx = Dict{Int, Int}()
415417
# mapping i -> f(fullstate, p, t) (component observables)
416418
obsfuns = Dict{Int, Function}()
417419
for (i, sni) in enumerate(_snis)
@@ -427,12 +429,28 @@ function SII.observed(nw::Network, snis)
427429
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
428430
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
429431
obsfuns[i] = (u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[idx]
432+
elseif hasinsym(cf) && sni.subidx insym_all(cf) # found in input
433+
if sni isa SymbolicVertexIndex
434+
idx = findfirst(isequal(sni.subidx), insym_all(cf))
435+
aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx]
436+
elseif sni isa SymbolicEdgeIndex
437+
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
438+
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
439+
outidx[i] = nw.im.v_out[edge.src][idx]
440+
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
441+
outidx[i] = nw.im.v_out[edge.dst][idx]
442+
else
443+
error()
444+
end
445+
else
446+
error()
447+
end
430448
else
431449
throw(ArgumentError("Cannot resolve observable $sni"))
432450
end
433451
end
434452
end
435-
initbufs = !isempty(outidx) || !isempty(obsfuns)
453+
initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns)
436454

437455
if isscalar
438456
@closure (u, p, t) -> begin
@@ -443,6 +461,9 @@ function SII.observed(nw::Network, snis)
443461
elseif !isempty(outidx)
444462
idx = only(outidx).second
445463
outbuf[idx]
464+
elseif !isempty(aggidx)
465+
idx = only(aggidx).second
466+
aggbuf[idx]
446467
else
447468
obsf = only(obsfuns).second
448469
obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
@@ -459,6 +480,9 @@ function SII.observed(nw::Network, snis)
459480
for (i, outi) in outidx
460481
out[i] = outbuf[outi]
461482
end
483+
for (i, aggi) in aggidx
484+
out[i] = aggbuf[aggi]
485+
end
462486
for (i, obsf) in obsfuns
463487
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
464488
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,6 @@ rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf)
134134
abstract type SymbolicIndex{C,S} end
135135
abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end
136136
abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end
137+
138+
flatten_sym(v::NamedTuple) = reduce(vcat, values(v))
139+
flatten_sym(v::AbstractVector{Symbol}) = v

test/ComponentLibrary.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,19 @@ diffusion_vertex() = VertexModel(f=diffusionvertex!, dim=1, g=1:1)
4747
Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t)
4848
e .= K .* sin(θ_s[1] - θ_d[1])
4949
end
50-
function kuramoto_edge(; name=:kuramoto_edge)
50+
function kuramoto_edge(; name=:kuramoto_edge, kwargs...)
5151
EdgeModel(;g=AntiSymmetric(kuramoto_edge!),
52-
outsym=[:P], psym=[:K], name)
52+
outsym=[:P], psym=[:K], name, kwargs...)
5353
end
5454

5555
Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t)
5656
M, D, Pm = p
5757
dv[1] = v[2]
5858
dv[2] = 1 / M * (Pm - D * v[2] + acc[1])
5959
end
60-
function kuramoto_second(; name=:kuramoto_second)
60+
function kuramoto_second(; name=:kuramoto_second, kwargs...)
6161
VertexModel(; f=kuramoto_inertia!, sym=[=>0, =>0],
62-
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name)
62+
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name, kwargs...)
6363
end
6464

6565
Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t)

test/symbolicindexing_test.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ for idx in idxtypes
306306
println(idx, " => ", b.allocs, " allocations")
307307
end
308308
if VERSION v"1.11"
309-
@test b.allocs <= 12
309+
@test b.allocs <= 13
310310
end
311311
end
312312

@@ -470,3 +470,21 @@ nw = Network(g, [n1, n2, n3], [e1, e2])
470470
@test s.p.e[:e2, 1] == s[EPIndex(2,1)]
471471
@test s.p.e[:e3, 1] == s[EPIndex(3,1)]
472472
end
473+
474+
# test observed for inputs
475+
@testset "test observing of model input" begin
476+
v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin])
477+
v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin])
478+
v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin])
479+
e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin])
480+
e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin])
481+
nw = Network([v1,v2,v3], [e1,e2])
482+
s = NWState(nw, rand(dim(nw)), rand(pdim(nw)))
483+
@test s[VIndex(:v1, :Pin)] == s[EIndex(:e1, :₋P)]
484+
@test s[VIndex(:v2, :Pin)] == s[EIndex(:e1, :P)] + s[EIndex(:e2, :₋P)]
485+
@test s[VIndex(:v3, :Pin)] == s[EIndex(:e2, :P)]
486+
@test s[EIndex(:e1, :src₊δin)] == s[VIndex(:v1, )]
487+
@test s[EIndex(:e1, :dst₊δin)] == s[VIndex(:v2, )]
488+
@test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, )]
489+
@test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, )]
490+
end

test/testutils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using CUDA
2+
using Adapt
3+
using NetworkDynamics: iscudacompatible, NaiveAggregator
4+
15
"""
26
Test utility, which rebuilds the Network with all different execution styles and compares the
37
results of the coreloop.

0 commit comments

Comments
 (0)