Skip to content

Commit 5c0e97b

Browse files
authored
Merge pull request #189 from JuliaDynamics/hw/observeinputs
observe inputs
2 parents b6b7533 + fe84b34 commit 5c0e97b

File tree

8 files changed

+79
-22
lines changed

8 files changed

+79
-22
lines changed

ext/MTKExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ For a given system and name, extract all the relevant meta we want to keep for t
189189
function _get_metadata(sys, name)
190190
nt = (;)
191191
sym = try
192-
getproperty_symbolic(sys, name)
192+
getproperty_symbolic(sys, name; might_contain_toplevel_ns=false)
193193
catch e
194194
if !endswith(string(name), "ˍt") # known for "internal" derivatives
195195
@warn "Could not extract metadata for $name $(e.msg)"

ext/MTKUtils.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,18 +93,24 @@ function _collect_differentials!(found, ex)
9393
end
9494

9595
"""
96-
getproperty_symbolic(sys, var)
96+
getproperty_symbolic(sys, var; might_contain_toplevel_ns=true)
9797
9898
Like `getproperty` but works on a greater varaity of "var"
9999
- var can be Num or Symbolic (resolved using genname)
100-
- strip namespace of sys if present
100+
- strip namespace of sys if present (don't strip if `might_contain_top_level_ns=false`)
101101
- for nested variables (foo₊bar₊baz) resolve them one by one
102102
"""
103-
function getproperty_symbolic(sys, var)
103+
function getproperty_symbolic(sys, var; might_contain_toplevel_ns=true)
104104
ns = string(getname(sys))
105105
varname = string(getname(var))
106-
varname_nons = replace(varname, r"^"*ns*"" => "")
107-
parts = split(varname_nons, "")
106+
# split of the toplevel namespace if necessary
107+
if might_contain_toplevel_ns && startswith(varname, ns*"")
108+
if getname(sys) getname.(ModelingToolkit.get_systems(sys))
109+
@warn "Namespace :$ns appears multiple times, this might lead to unexpected, since it is not clear whether the namespace should be stripped or not."
110+
end
111+
varname = replace(varname, r"^"*ns*"" => "")
112+
end
113+
parts = split(varname, "")
108114
r = getproperty(sys, Symbol(parts[1]); namespace=false)
109115
for part in parts[2:end]
110116
r = getproperty(r, Symbol(part); namespace=true)

src/component_functions.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,9 @@ edges it returns a named tuple `(; src, dst)` with two symbol vectors.
487487
insym(c::VertexModel)::Vector{Symbol} = c.insym
488488
insym(c::EdgeModel)::@NamedTuple{src::Vector{Symbol},dst::Vector{Symbol}} = c.insym
489489

490+
insym_all(c::VertexModel) = c.insym
491+
insym_all(c::EdgeModel) = Iterators.flatten(values(c.insym))
492+
490493
"""
491494
indim(c::VertexModel)::Int
492495
indim(c::EdgeModel)::@NamedTuple{src::Int,dst::Int}
@@ -950,16 +953,17 @@ function _fill_defaults(T, @nospecialize(kwargs))
950953
####
951954
#### Cached outsymflat/outsymall
952955
####
953-
_outsym_flat = if T <: VertexModel
954-
outsym
955-
elseif T <: EdgeModel
956-
vcat(outsym.src, outsym.dst)
957-
else
958-
error()
959-
end
956+
_outsym_flat = flatten_sym(outsym)
960957
dict[:_outsym_flat] = _outsym_flat
958+
961959
dict[:_obssym_all] = setdiff(_outsym_flat, sym) obssym
962960

961+
insym = dict[:insym]
962+
if !isnothing(insym)
963+
insym_flat = flatten_sym(insym)
964+
dict[:_obssym_all] = dict[:_obssym_all] insym_flat
965+
end
966+
963967
####
964968
#### External Inputs
965969
####
@@ -991,10 +995,8 @@ function _fill_defaults(T, @nospecialize(kwargs))
991995

992996
_is = if isnothing(__is)
993997
Symbol[]
994-
elseif __is isa NamedTuple
995-
vcat(__is.src, __is.dst)
996998
else
997-
__is
999+
flatten_sym(insym)
9981000
end
9991001
if !allunique(vcat(_s, _ps, _obss, _is, _os))
10001002
throw(ArgumentError("Symbol names must be unique. There are clashes in sym, psym, outsym, obssym and insym."))

src/symbolicindexing.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ function SII.observed(nw::Network, snis)
412412
stateidx = Dict{Int, Int}()
413413
# mapping i -> index in output
414414
outidx = Dict{Int, Int}()
415+
# mapping i -> index in aggbuf
416+
aggidx = Dict{Int, Int}()
415417
# mapping i -> f(fullstate, p, t) (component observables)
416418
obsfuns = Dict{Int, Function}()
417419
for (i, sni) in enumerate(_snis)
@@ -427,12 +429,28 @@ function SII.observed(nw::Network, snis)
427429
elseif (idx=findfirst(isequal(sni.subidx), obssym(cf))) != nothing #found in observed
428430
_obsf = _get_observed_f(nw, cf, resolvecompidx(nw, sni))
429431
obsfuns[i] = (u, outbuf, aggbuf, extbuf, p, t) -> _obsf(u, outbuf, aggbuf, extbuf, p, t)[idx]
432+
elseif hasinsym(cf) && sni.subidx insym_all(cf) # found in input
433+
if sni isa SymbolicVertexIndex
434+
idx = findfirst(isequal(sni.subidx), insym_all(cf))
435+
aggidx[i] = nw.im.v_aggr[resolvecompidx(nw, sni)][idx]
436+
elseif sni isa SymbolicEdgeIndex
437+
edge = nw.im.edgevec[resolvecompidx(nw, sni)]
438+
if (idx = findfirst(isequal(sni.subidx), insym(cf).src)) != nothing
439+
outidx[i] = nw.im.v_out[edge.src][idx]
440+
elseif (idx = findfirst(isequal(sni.subidx), insym(cf).dst)) != nothing
441+
outidx[i] = nw.im.v_out[edge.dst][idx]
442+
else
443+
error()
444+
end
445+
else
446+
error()
447+
end
430448
else
431449
throw(ArgumentError("Cannot resolve observable $sni"))
432450
end
433451
end
434452
end
435-
initbufs = !isempty(outidx) || !isempty(obsfuns)
453+
initbufs = !isempty(outidx) || !isempty(aggidx) || !isempty(obsfuns)
436454

437455
if isscalar
438456
@closure (u, p, t) -> begin
@@ -443,6 +461,9 @@ function SII.observed(nw::Network, snis)
443461
elseif !isempty(outidx)
444462
idx = only(outidx).second
445463
outbuf[idx]
464+
elseif !isempty(aggidx)
465+
idx = only(aggidx).second
466+
aggbuf[idx]
446467
else
447468
obsf = only(obsfuns).second
448469
obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
@@ -459,6 +480,9 @@ function SII.observed(nw::Network, snis)
459480
for (i, outi) in outidx
460481
out[i] = outbuf[outi]
461482
end
483+
for (i, aggi) in aggidx
484+
out[i] = aggbuf[aggi]
485+
end
462486
for (i, obsf) in obsfuns
463487
out[i] = obsf(u, outbuf, aggbuf, extbuf, p, t)::eltype(u)
464488
end

src/utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,6 @@ rand_inputs_fg(cf) = rand_inputs_fg(Random.default_rng(), cf)
134134
abstract type SymbolicIndex{C,S} end
135135
abstract type SymbolicStateIndex{C,S} <: SymbolicIndex{C,S} end
136136
abstract type SymbolicParameterIndex{C,S} <: SymbolicIndex{C,S} end
137+
138+
flatten_sym(v::NamedTuple) = reduce(vcat, values(v))
139+
flatten_sym(v::AbstractVector{Symbol}) = v

test/ComponentLibrary.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,19 +47,19 @@ diffusion_vertex() = VertexModel(f=diffusionvertex!, dim=1, g=1:1)
4747
Base.@propagate_inbounds function kuramoto_edge!(e, θ_s, θ_d, (K,), t)
4848
e .= K .* sin(θ_s[1] - θ_d[1])
4949
end
50-
function kuramoto_edge(; name=:kuramoto_edge)
50+
function kuramoto_edge(; name=:kuramoto_edge, kwargs...)
5151
EdgeModel(;g=AntiSymmetric(kuramoto_edge!),
52-
outsym=[:P], psym=[:K], name)
52+
outsym=[:P], psym=[:K], name, kwargs...)
5353
end
5454

5555
Base.@propagate_inbounds function kuramoto_inertia!(dv, v, acc, p, t)
5656
M, D, Pm = p
5757
dv[1] = v[2]
5858
dv[2] = 1 / M * (Pm - D * v[2] + acc[1])
5959
end
60-
function kuramoto_second(; name=:kuramoto_second)
60+
function kuramoto_second(; name=:kuramoto_second, kwargs...)
6161
VertexModel(; f=kuramoto_inertia!, sym=[=>0, =>0],
62-
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name)
62+
psym=[:M=>1, :D=>0.1, :Pm=>1], g=StateMask(1), name, kwargs...)
6363
end
6464

6565
Base.@propagate_inbounds function kuramoto_vertex!(dθ, θ, esum, (ω,), t)

test/symbolicindexing_test.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ for idx in idxtypes
306306
println(idx, " => ", b.allocs, " allocations")
307307
end
308308
if VERSION v"1.11"
309-
@test b.allocs <= 12
309+
@test b.allocs <= 13
310310
end
311311
end
312312

@@ -470,3 +470,21 @@ nw = Network(g, [n1, n2, n3], [e1, e2])
470470
@test s.p.e[:e2, 1] == s[EPIndex(2,1)]
471471
@test s.p.e[:e3, 1] == s[EPIndex(3,1)]
472472
end
473+
474+
# test observed for inputs
475+
@testset "test observing of model input" begin
476+
v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin])
477+
v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin])
478+
v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin])
479+
e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin])
480+
e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin])
481+
nw = Network([v1,v2,v3], [e1,e2])
482+
s = NWState(nw, rand(dim(nw)), rand(pdim(nw)))
483+
@test s[VIndex(:v1, :Pin)] == s[EIndex(:e1, :₋P)]
484+
@test s[VIndex(:v2, :Pin)] == s[EIndex(:e1, :P)] + s[EIndex(:e2, :₋P)]
485+
@test s[VIndex(:v3, :Pin)] == s[EIndex(:e2, :P)]
486+
@test s[EIndex(:e1, :src₊δin)] == s[VIndex(:v1, )]
487+
@test s[EIndex(:e1, :dst₊δin)] == s[VIndex(:v2, )]
488+
@test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, )]
489+
@test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, )]
490+
end

test/testutils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
using CUDA
2+
using Adapt
3+
using NetworkDynamics: iscudacompatible, NaiveAggregator
4+
15
"""
26
Test utility, which rebuilds the Network with all different execution styles and compares the
37
results of the coreloop.

0 commit comments

Comments
 (0)