Skip to content

Commit be2c14b

Browse files
authored
Merge pull request #192 from JuliaDynamics/hw/obsexp
add ObservableExpressions to create derived quantities on the go
2 parents 357778f + 696c378 commit be2c14b

File tree

9 files changed

+191
-3
lines changed

9 files changed

+191
-3
lines changed

Project.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,19 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2929
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
3030
StyledStrings = "f489334b-da3d-4c2e-b8f0-e476e12c162b"
3131
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
32-
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3332
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
3433

3534
[weakdeps]
3635
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3736
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
37+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
3838
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
39+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
3940

4041
[extensions]
4142
CUDAExt = ["CUDA", "Adapt"]
42-
MTKExt = "ModelingToolkit"
43+
MTKExt = ["ModelingToolkit", "Symbolics"]
44+
SymbolicsExt = ["Symbolics", "MacroTools"]
4345

4446
[compat]
4547
Adapt = "4.0.4"
@@ -53,6 +55,7 @@ Graphs = "1"
5355
InteractiveUtils = "1"
5456
KernelAbstractions = "0.9.18"
5557
LinearAlgebra = "1"
58+
MacroTools = "0.5.15"
5659
Mixers = "0.1.2"
5760
ModelingToolkit = "9"
5861
NNlib = "0.9.13"

docs/src/API.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ VIndex
8282
EIndex
8383
VPIndex
8484
EPIndex
85+
@obsex
8586
```
8687

8788
### Index generators

docs/src/symbolic_indexing.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,20 @@ Observables can be accessed like any other state, for example, the flows in the
114114
```@example si
115115
plot(sol; idxs=eidxs(nw, :, :flow))
116116
```
117+
118+
## Derived `ObservableExpressions` using `@obsex`
119+
120+
Sometimes it is usefull to plot or observe some simple derived quantity.For that,
121+
one can used the [`@obsex`](@ref) macro, to define simple derived quantities.
122+
123+
For example, we can directly plot the storage difference with respect to storage of node 1.
124+
125+
```@example si
126+
plot(sol; idxs=@obsex(vidxs(nw,:,:storage) .- VIndex(1,:storage)))
127+
```
128+
129+
Other examples are the calculation of magnitude and argument of complex values which are modeld in real and imaginary part.
130+
```
131+
@obsex mag = sqrt(VIndex(1, :u_r)^2 + VIndex(2, :u_i)^2)
132+
```
133+

ext/MTKExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: Symbolic, iscall, operation, arguments, build_function
44
using ModelingToolkit: ModelingToolkit, Equation, ODESystem, Differential
55
using ModelingToolkit: full_equations, get_variables, structural_simplify, getname, unwrap
66
using ModelingToolkit: full_parameters, unknowns, independent_variables, observed, defaults
7-
using ModelingToolkit.Symbolics: Symbolics, fixpoint_sub, substitute
7+
using Symbolics: Symbolics, fixpoint_sub, substitute
88
using RecursiveArrayTools: RecursiveArrayTools
99
using ArgCheck: @argcheck
1010
using LinearAlgebra: Diagonal, I

ext/SymbolicsExt.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
module SymbolicsExt
2+
3+
using Symbolics: Symbolics, @variables, Num
4+
using MacroTools: postwalk, @capture
5+
using NetworkDynamics: NetworkDynamics, ObservableExpression, SymbolicVertexIndex, SymbolicEdgeIndex, SII
6+
7+
function NetworkDynamics.generate_observable_expression(ex::Expr)
8+
if @capture(ex, capname_ = capcontent_)
9+
name = capname
10+
content = capcontent
11+
else
12+
name = nothing
13+
content = ex
14+
end
15+
# manually hygen on "mapping", otherwise we cannot escape x
16+
mapping_sym = gensym("mapping")
17+
symbolic_expr = postwalk(content) do x
18+
if x (:+, :-, :^, :/, :.+, :.-, :.^, :./, Symbol(":")) || x isa QuoteNode
19+
x
20+
else
21+
:(NetworkDynamics.collect_symbol!($mapping_sym, $x))
22+
end
23+
end
24+
quote
25+
$(esc(mapping_sym)) = Dict()
26+
expr = $(esc(symbolic_expr))
27+
NetworkDynamics.ObservableExpression($(esc(mapping_sym)), expr, $(Meta.quot(name)))
28+
end
29+
end
30+
31+
function NetworkDynamics.ObservableExpression(mapping, expr::Vector, names)
32+
if names isa Symbol
33+
names = [Symbol(names, NetworkDynamics.subscript(i)) for i in 1:length(expr)]
34+
elseif isnothing(names)
35+
names = Iterators.repeated(nothing, length(expr))
36+
end
37+
[NetworkDynamics.ObservableExpression(mapping, e, n) for (e, n) in zip(expr, names)]
38+
end
39+
40+
function NetworkDynamics.ObservableExpression(mapping, expr, name)
41+
nwidx = collect(keys(mapping))
42+
syms = collect(values(mapping))
43+
f = Symbolics.build_function(expr, syms; expression=Val(false))
44+
NetworkDynamics.ObservableExpression(nwidx, f, expr, name)
45+
end
46+
47+
NetworkDynamics.collect_symbol!(_, x) = x
48+
function NetworkDynamics.collect_symbol!(mapping, x::Union{AbstractVector, Tuple})
49+
map(el -> NetworkDynamics.collect_symbol!(mapping, el), x)
50+
end
51+
function NetworkDynamics.collect_symbol!(mapping, nwindex::Union{SymbolicVertexIndex,SymbolicEdgeIndex})
52+
if haskey(mapping, nwindex)
53+
return mapping[nwindex]
54+
else
55+
name = SII.getname(nwindex)
56+
sym = only(@variables $name)
57+
mapping[nwindex] = sym
58+
return sym
59+
end
60+
end
61+
62+
end # module

src/NetworkDynamics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ include("external_inputs.jl")
6161
export VIndex, EIndex, VPIndex, EPIndex, NWState, NWParameter, uflat, pflat
6262
export vidxs, eidxs, vpidxs, epidxs
6363
export save_parameters!
64+
export @obsex
6465
include("symbolicindexing.jl")
6566

6667
export has_metadata, get_metadata, set_metadata!

src/symbolicindexing.jl

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,10 @@ function observed_symbols(nw::Network)
402402
end
403403

404404
function SII.observed(nw::Network, snis)
405+
if (snis isa AbstractVector || snis isa Tuple) && any(sni -> sni isa ObservableExpression, snis)
406+
throw(ArgumentError("Cannot mix normal symbolic indices with @obsex currently!"))
407+
end
408+
405409
_snis = _expand_and_collect(nw, snis)
406410
isscalar = _snis isa SymbolicIndex
407411
if isscalar
@@ -1079,3 +1083,80 @@ extract_nw(p::IndexingProxy) = extract_nw(p.s)
10791083
function extract_nw(::Nothing)
10801084
throw(ArgumentError("Needs system context to generate matching indices. Pass Network, sol, prob, ..."))
10811085
end
1086+
1087+
####
1088+
#### Observable Expressions
1089+
####
1090+
struct ObservableExpression{VT,F,Ex,N}
1091+
inputs::Vector{VT}
1092+
f::F
1093+
ex::Ex
1094+
name::N
1095+
end
1096+
function Base.show(io::IO, mime::MIME"text/plain", obsex::ObservableExpression)
1097+
print(io, "ObservableExpression(")
1098+
isnothing(obsex.name) || print(io, obsex.name, " = ")
1099+
show(io, mime, obsex.ex)
1100+
print(io, ")")
1101+
end
1102+
1103+
"""
1104+
@obsex([name =] expression)
1105+
1106+
Define observable expressions, which are simple combinations of knonw
1107+
states/parameters/observables. `@obsex(...)` returns an `ObservableExpression`
1108+
which can be used as an symbolic index. This is mainly intended for quick
1109+
plotting or export of common "derived" variables, such as the argument of a
1110+
2-component complex state. For example:
1111+
1112+
sol(t; idxs=@obsex(arg = atan(VIndex(1,:u_i), VIndex(1,:u_r))]
1113+
sol(t; idxs=@obsex(δrel = VIndex(1,:δ) - VIndex(2,:δ)))
1114+
1115+
"""
1116+
macro obsex(ex)
1117+
generate_observable_expression(ex)
1118+
end
1119+
1120+
function generate_observable_expression(::Any)
1121+
error("@obsex can only be used when Symbolics.jl is loaded.")
1122+
end
1123+
1124+
# define function stub to overload in SymbolicsExt
1125+
function collect_symbol! end
1126+
1127+
function SII.is_observed(nw::Network, obsex::ObservableExpression)
1128+
true
1129+
end
1130+
1131+
SII.symbolic_type(::Type{<:ObservableExpression}) = SII.ScalarSymbolic()
1132+
SII.hasname(::ObservableExpression) = true
1133+
function SII.getname(obsex::ObservableExpression)
1134+
if obsex.name isa Symbol
1135+
return obsex.name
1136+
else
1137+
io = IOBuffer()
1138+
show(io, MIME"text/plain"(), obsex.ex)
1139+
str = String(take!(io))
1140+
return Symbol(replace(str, " "=>""))
1141+
end
1142+
end
1143+
1144+
function SII.observed(nw::Network, obsex::ObservableExpression)
1145+
inputf = SII.observed(nw, obsex.inputs)
1146+
(u, p, t) -> begin
1147+
input = inputf(u, p, t)
1148+
obsex.f(input)
1149+
end
1150+
end
1151+
function SII.observed(nw::Network, obsexs::AbstractVector{<:ObservableExpression})
1152+
inputfs = map(obsex -> SII.observed(nw, obsex.inputs), obsexs)
1153+
(u, p, t) -> begin
1154+
map(obsexs, inputfs) do obsex, inputf
1155+
input = inputf(u, p, t)
1156+
obsex.f(input)
1157+
end
1158+
end
1159+
end
1160+
1161+
Base.getindex(s::NWState, idx::ObservableExpression) = SII.getu(s, idx)(s)
1162+
Base.getindex(s::NWParameter, idx::ObservableExpression) = SII.getp(s, idx)(s)

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
4141
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4242
SteadyStateDiffEq = "9672c7b4-1e72-59bd-8a11-6ac3964bc41f"
4343
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
44+
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4445
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4546
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
4647

test/symbolicindexing_test.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ using Graphs
33
using OrdinaryDiffEqTsit5
44
using Chairmarks
55
using Test
6+
using Symbolics
67
import SymbolicIndexingInterface as SII
78
using NetworkDynamics: VIndex, EIndex, VPIndex, EPIndex, _resolve_colon
89

@@ -488,3 +489,24 @@ end
488489
@test s[EIndex(:e2, :src₊δin)] == s[VIndex(:v2, )]
489490
@test s[EIndex(:e2, :dst₊δin)] == s[VIndex(:v3, )]
490491
end
492+
493+
@testset "test observed expressions" begin
494+
v1 = Lib.kuramoto_second(name=:v1, vidx=1, insym=[:Pin])
495+
v2 = Lib.kuramoto_second(name=:v2, vidx=2, insym=[:Pin])
496+
v3 = Lib.kuramoto_second(name=:v3, vidx=3, insym=[:Pin])
497+
e1 = Lib.kuramoto_edge(name=:e1, src=1, dst=2, insym=[:δin])
498+
e2 = Lib.kuramoto_edge(name=:e2, src=2, dst=3, insym=[:δin])
499+
nw = Network([v1,v2,v3], [e1,e2])
500+
s = NWState(nw, rand(dim(nw)), rand(pdim(nw)))
501+
502+
obsex = @obsex(VIndex(1,) + VIndex(2,))
503+
@test s[obsex] == s[VIndex(1,)] + s[VIndex(2,)]
504+
@test s[@obsex VIndex(1,) - EIndex(:e1, :src₊δin)] == 0
505+
506+
@test SII.getname(@obsex(VIndex(1,) + VIndex(2,))) == Symbol("v1₊δ+v2₊δ")
507+
@test SII.getname(@obsex(δ²=VIndex(1,)^2)) == :δ²
508+
509+
@test s[@obsex(vidxs(s, :, ) .- VIndex(1, ))] == [0, s[VIndex(2,)] - s[VIndex(1,)], s[VIndex(3,)] - s[VIndex(1,)]]
510+
511+
obsex = @obsex(δ_rel = vidxs(s, :, ) .- VIndex(1, ))
512+
end

0 commit comments

Comments
 (0)