Skip to content

Commit 88375fc

Browse files
authored
refactor internals of ExtendedStateSpace to allow for named signals (#134)
* refactor internals of ExtendedStateSpace to allow for named signals * type preserving mapping functions * type preserving - * test ESS with named ss * handle empty arrays
1 parent 67f3081 commit 88375fc

File tree

6 files changed

+264
-110
lines changed

6 files changed

+264
-110
lines changed

src/ExtendedStateSpace.jl

Lines changed: 186 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
## Data Type Declarations ##
33
#####################################################################
44
"""
5-
ExtendedStateSpace{TE, T} <: AbstractStateSpace{TE}
5+
ExtendedStateSpace{TE, T, S, I} <: AbstractStateSpace{TE}
66
77
A type that represents the two-input, two-output system
88
```
@@ -18,7 +18,7 @@ where
1818
- `w` denotes external inputs, such as disturbances or references
1919
- `u` denotes control inputs
2020
21-
The call `lft(P, K)` forms the (lower) linear fractional transform
21+
The call `lft(P, K)` forms the (lower) linear fractional transform
2222
```
2323
z ┌─────┐ w
2424
◄──┤ │◄──
@@ -49,52 +49,88 @@ and the following design functions expect `ExtendedStateSpace` as inputs
4949
- [`LQGProblem`](@ref) (also accepts other types)
5050
5151
A video tutorial on how to use this type is available [here](https://youtu.be/huYRrn--AKc).
52+
53+
## Internal Representation
54+
Internally, this type stores an `AbstractStateSpace` along with index vectors that partition the inputs and outputs:
55+
- `sys`: The underlying state-space system
56+
- `w`: Indices for disturbance inputs (corresponds to B1)
57+
- `u`: Indices for control inputs (corresponds to B2)
58+
- `z`: Indices for performance outputs (corresponds to C1)
59+
- `y`: Indices for measured outputs (corresponds to C2)
60+
61+
The type parameter `I` allows index vectors to be `Vector{Int}`, `UnitRange{Int}`, or `Vector{Symbol}` (for NamedStateSpace).
5262
"""
53-
struct ExtendedStateSpace{TE,T} <: AbstractStateSpace{TE}
54-
A::Matrix{T}
55-
B1::Matrix{T}
56-
B2::Matrix{T}
57-
C1::Matrix{T}
58-
C2::Matrix{T}
59-
D11::Matrix{T}
60-
D12::Matrix{T}
61-
D21::Matrix{T}
62-
D22::Matrix{T}
63-
timeevol::TE
64-
function ExtendedStateSpace{TE,T}(
65-
A,
66-
B1,
67-
B2,
68-
C1,
69-
C2,
70-
D11,
71-
D12,
72-
D21,
73-
D22,
74-
timeevol::TE,
75-
) where {TE,T}
76-
nx = size(A, 1)
77-
nw = size(B1, 2)
78-
nu = size(B2, 2)
79-
nz = size(C1, 1)
80-
ny = size(C2, 1)
81-
82-
size(A, 2) != nx && nx != 0 && error("A must be square")
83-
size(B1, 1) == nx || error("B1 must have the same row size as A")
84-
size(B2, 1) == nx || error("B2 must have the same row size as A")
85-
size(C1, 2) == nx || error("C1 must have the same column size as A")
86-
size(C2, 2) == nx || error("C2 must have the same column size as A")
87-
size(D11, 2) == nw || error("D11 must have the same column size as B1")
88-
size(D21, 2) == nw || error("D21 must have the same column size as B1")
89-
size(D12, 2) == nu || error("D12 must have the same column size as B2")
90-
size(D22, 2) == nu || error("D22 must have the same column size as B2")
91-
size(D11, 1) == nz || error("D11 must have the same row size as C1")
92-
size(D12, 1) == nz || error("D12 must have the same row size as C1")
93-
size(D21, 1) == ny || error("D21 must have the same row size as C2")
94-
size(D22, 1) == ny || error("D22 must have the same row size as C2")
95-
96-
new{TE,T}(A, B1, B2, C1, C2, D11, D12, D21, D22, timeevol)
97-
end
63+
struct ExtendedStateSpace{TE, T, S<:AbstractStateSpace{TE}, I} <: AbstractStateSpace{TE}
64+
sys::S # The underlying StateSpace
65+
w::I # Disturbance input indices (corresponds to B1)
66+
u::I # Control input indices (corresponds to B2)
67+
z::I # Performance output indices (corresponds to C1)
68+
y::I # Measured output indices (corresponds to C2)
69+
end
70+
71+
# Inner constructor taking individual matrices (for backward compatibility)
72+
function ExtendedStateSpace{TE,T}(
73+
A,
74+
B1,
75+
B2,
76+
C1,
77+
C2,
78+
D11,
79+
D12,
80+
D21,
81+
D22,
82+
timeevol::TE,
83+
) where {TE,T}
84+
nx = size(A, 1)
85+
nw = size(B1, 2)
86+
nu = size(B2, 2)
87+
nz = size(C1, 1)
88+
ny = size(C2, 1)
89+
90+
size(A, 2) != nx && nx != 0 && error("A must be square")
91+
size(B1, 1) == nx || error("B1 must have the same row size as A")
92+
size(B2, 1) == nx || error("B2 must have the same row size as A")
93+
size(C1, 2) == nx || error("C1 must have the same column size as A")
94+
size(C2, 2) == nx || error("C2 must have the same column size as A")
95+
size(D11, 2) == nw || error("D11 must have the same column size as B1")
96+
size(D21, 2) == nw || error("D21 must have the same column size as B1")
97+
size(D12, 2) == nu || error("D12 must have the same column size as B2")
98+
size(D22, 2) == nu || error("D22 must have the same column size as B2")
99+
size(D11, 1) == nz || error("D11 must have the same row size as C1")
100+
size(D12, 1) == nz || error("D12 must have the same row size as C1")
101+
size(D21, 1) == ny || error("D21 must have the same row size as C2")
102+
size(D22, 1) == ny || error("D22 must have the same row size as C2")
103+
104+
# Build combined matrices
105+
B = [B1 B2]
106+
C = [C1; C2]
107+
D = [D11 D12; D21 D22]
108+
sys = StateSpace{TE, T}(A, B, C, D, timeevol)
109+
110+
# Compute index vectors
111+
w_inds = 1:nw
112+
u_inds = (nw+1):(nw+nu)
113+
z_inds = 1:nz
114+
y_inds = (nz+1):(nz+ny)
115+
116+
ExtendedStateSpace{TE, T, typeof(sys), typeof(w_inds)}(sys, w_inds, u_inds, z_inds, y_inds)
117+
end
118+
119+
# Constructor with all 4 type parameters (for type-preserving operations like negation)
120+
function ExtendedStateSpace{TE, T, S, I}(
121+
A,
122+
B1,
123+
B2,
124+
C1,
125+
C2,
126+
D11,
127+
D12,
128+
D21,
129+
D22,
130+
timeevol::TE,
131+
) where {TE, T, S, I}
132+
# Delegate to the simpler constructor
133+
ExtendedStateSpace{TE, T}(A, B1, B2, C1, C2, D11, D12, D21, D22, timeevol)
98134
end
99135

100136
function ExtendedStateSpace(
@@ -133,6 +169,26 @@ function ExtendedStateSpace(
133169
)
134170
end
135171

172+
"""
173+
ExtendedStateSpace(sys::AbstractStateSpace, w, u, z, y)
174+
175+
Create an [`ExtendedStateSpace`](@ref) from an existing state-space system with specified index vectors.
176+
177+
This constructor preserves the type of `sys` (e.g., `NamedStateSpace`).
178+
179+
# Arguments
180+
- `sys`: The underlying state-space system
181+
- `w`: Disturbance input indices (corresponds to B1)
182+
- `u`: Control input indices (corresponds to B2)
183+
- `z`: Performance output indices (corresponds to C1)
184+
- `y`: Measured output indices (corresponds to C2)
185+
"""
186+
function ExtendedStateSpace(sys::S, w::I, u::I, z::I, y::I) where {S<:AbstractStateSpace, I}
187+
TE = typeof(sys.timeevol)
188+
T = ControlSystemsBase.numeric_type(sys)
189+
ExtendedStateSpace{TE, T, S, I}(sys, w, u, z, y)
190+
end
191+
136192
"""
137193
ss(A, B1, B2, C1, C2, D11, D12, D21, D22 [, Ts])
138194
@@ -181,53 +237,92 @@ function ss(
181237
return ExtendedStateSpace(A, B1, B2, C1, C2, D11, D12, D21, D22, Ts)
182238
end
183239

184-
function Base.promote_rule(::Type{StateSpace{TE, F1}}, ::Type{ExtendedStateSpace{TE, F2}}) where {TE, F1, F2}
240+
function Base.promote_rule(::Type{StateSpace{TE, F1}}, ::Type{<:ExtendedStateSpace{TE, F2}}) where {TE, F1, F2}
185241
ExtendedStateSpace{TE, promote_type(F1, F2)}
186242
end
187243

188-
function Base.convert(::Type{ExtendedStateSpace{TE, F2}}, s::StateSpace{TE, F1})where {TE, F1, F2}
244+
function Base.convert(::Type{<:ExtendedStateSpace{TE}}, s::StateSpace{TE, F1}) where {TE, F1}
189245
partition(s, 0, 0)
190246
end
191247

192-
function Base.getproperty(sys::ExtendedStateSpace, s::Symbol)
193-
if s === :Ts
194-
# if !isdiscrete(sys) # NOTE this line seems to be breaking inference of isdiscrete (is there a test for this?)
195-
if isdiscrete(sys)
196-
return timeevol(sys).Ts
248+
function Base.getproperty(esys::ExtendedStateSpace, s::Symbol)
249+
# Access to underlying system and index vectors
250+
if s === :sys || s === :w || s === :u || s === :z || s === :y
251+
return getfield(esys, s)
252+
end
253+
254+
# Get the underlying system for matrix extraction
255+
sys = getfield(esys, :sys)
256+
if sys isa NamedStateSpace
257+
w = names2indices(getfield(esys, :w), sys.u)
258+
u = names2indices(getfield(esys, :u), sys.u)
259+
z = names2indices(getfield(esys, :z), sys.y)
260+
y = names2indices(getfield(esys, :y), sys.y)
261+
else
262+
w = getfield(esys, :w)
263+
u = getfield(esys, :u)
264+
z = getfield(esys, :z)
265+
y = getfield(esys, :y)
266+
end
267+
268+
# Extract matrices via indexing
269+
if s === :A
270+
return sys.A
271+
elseif s === :B1
272+
return sys.B[:, w]
273+
elseif s === :B2
274+
return sys.B[:, u]
275+
elseif s === :C1
276+
return sys.C[z, :]
277+
elseif s === :C2
278+
return sys.C[y, :]
279+
elseif s === :D11
280+
return sys.D[z, w]
281+
elseif s === :D12
282+
return sys.D[z, u]
283+
elseif s === :D21
284+
return sys.D[y, w]
285+
elseif s === :D22
286+
return sys.D[y, u]
287+
elseif s === :timeevol
288+
return sys.timeevol
289+
elseif s === :Ts
290+
if isdiscrete(esys)
291+
return timeevol(esys).Ts
197292
else
198293
@warn "Getting time 0.0 for non-discrete systems is deprecated. Check `isdiscrete` before trying to access time."
199294
return 0.0
200295
end
201296
elseif s === :nx
202-
return nstates(sys)
297+
return nstates(esys)
203298
elseif s === :nu
204-
return size(sys.B2, 2)
205-
elseif s === :ny # TODO: now size(sys.C, 1) is not always the same as sys.ny
206-
return size(sys.C2, 1)
299+
return length(u)
300+
elseif s === :ny
301+
return length(y)
207302
elseif s === :nw
208-
return size(sys.B1, 2)
303+
return length(w)
209304
elseif s === :nz
210-
return size(sys.C1, 1)
305+
return length(z)
211306
elseif s === :B
212-
[sys.B1 sys.B2]
307+
return sys.B
213308
elseif s === :C
214-
[sys.C1; sys.C2]
309+
return sys.C
215310
elseif s === :D
216-
[sys.D11 sys.D12; sys.D21 sys.D22]
311+
return sys.D
217312
elseif s === :zinds
218-
return 1:size(sys.C1, 1)
313+
return z
219314
elseif s === :yinds
220-
return size(sys.C1, 1) .+ (1:size(sys.C2, 1))
315+
return y
221316
elseif s === :winds
222-
return 1:size(sys.B1, 2)
317+
return w
223318
elseif s === :uinds
224-
return size(sys.B1, 2) .+ (1:size(sys.B2, 2))
319+
return u
225320
else
226-
return getfield(sys, s)
321+
error("type ExtendedStateSpace has no field $s")
227322
end
228323
end
229324

230-
Base.propertynames(sys::ExtendedStateSpace) = (:A, :B, :C, :D, :B1, :B2, :C1, :C2, :D11, :D12, :D21, :D22, :Ts, :timeevol, :nx, :ny, :nu, :nw, :nz, :zinds, :yinds, :winds, :uinds)
325+
Base.propertynames(::ExtendedStateSpace) = (:A, :B, :C, :D, :B1, :B2, :C1, :C2, :D11, :D12, :D21, :D22, :Ts, :timeevol, :nx, :ny, :nu, :nw, :nz, :zinds, :yinds, :winds, :uinds, :sys, :w, :u, :z, :y)
231326

232327
ControlSystemsBase.StateSpace(s::ExtendedStateSpace) = ss(ssdata(s)..., s.timeevol)
233328

@@ -315,14 +410,29 @@ end
315410

316411
function Base.:*(s1::ExtendedStateSpace, s2::Number)
317412
A, B1, B2, C1, C2, D11, D12, D21, D22 = ssdata_e(s1)
413+
# The reason for only scaling one channel is the use in UncertainSS
318414
ss(A, s2*B1, B2, C1, C2, s2*D11, D12, s2*D21, D22, s1.timeevol)
415+
# ExtendedStateSpace(s1.sys*s2, s1.w, s1.u, s1.z, s1.y)
319416
end
320417

321418
function Base.:*(s2::Number, s1::ExtendedStateSpace)
322419
A, B1, B2, C1, C2, D11, D12, D21, D22 = ssdata_e(s1)
323420
ss(A, B1, B2, s2*C1, C2, s2*D11, s2*D12, D21, D22, s1.timeevol)
421+
# ExtendedStateSpace(s2*s1.sys, s1.w, s1.u, s1.z, s1.y)
324422
end
325423

424+
# function invert_mappings(s::ExtendedStateSpace)
425+
# # Reorder system: inputs [u; w] and outputs [y; z]
426+
# # This swaps the w↔u and z↔y mappings
427+
# (; w,u,z,y) = s
428+
# new_i = [u; w]
429+
# new_o = [y; z]
430+
# ExtendedStateSpace(s.sys, s.u, s.w, s.y, s.z)
431+
# # ExtendedStateSpace(s.sys[new_o, new_i], s.u, s.w, s.y, s.z)
432+
# # ExtendedStateSpace(s.sys[new_o, new_i], s.w, s.u, s.z, s.y)
433+
# end
434+
435+
326436
function invert_mappings(s::ExtendedStateSpace)
327437
A, B1, B2, C1, C2, D11, D12, D21, D22 = ssdata_e(s)
328438
ss(A, B2, B1, C2, C1, D22, D21, D12, D11, s.timeevol)
@@ -341,9 +451,9 @@ end
341451
# end
342452

343453
## NEGATION ##
344-
function Base.:-(sys::ST) where ST <: ExtendedStateSpace
345-
A, B1, B2, C1, C2, D11, D12, D21, D22 = ssdata_e(sys)
346-
ST(A, B1, B2, -C1, -C2, -D11, -D12, -D21, -D22, sys.timeevol)
454+
function Base.:-(sys::ExtendedStateSpace)
455+
# Negating a StateSpace negates C and D matrices, preserving the internal sys type
456+
ExtendedStateSpace(-sys.sys, sys.w, sys.u, sys.z, sys.y)
347457
end
348458

349459
#####################################################################
@@ -356,17 +466,7 @@ Base.eltype(::Type{S}) where {S<:ExtendedStateSpace} = S
356466
ControlSystemsBase.numeric_type(sys::ExtendedStateSpace) = eltype(sys.A)
357467

358468
function Base.getindex(sys::ExtendedStateSpace, inds...)
359-
if size(inds, 1) != 2
360-
error("Must specify 2 indices to index statespace model")
361-
end
362-
rows, cols = ControlSystemsBase.index2range(inds...) # FIXME: ControlSystemsBase.index2range(inds...)
363-
return ss(
364-
copy(sys.A),
365-
sys.B[:, cols],
366-
sys.C[rows, :],
367-
sys.D[rows, cols],
368-
sys.timeevol,
369-
)
469+
getindex(sys.sys, inds...)
370470
end
371471

372472
#####################################################################
@@ -725,7 +825,7 @@ Return the system from u -> y
725825
See also [`performance_mapping`](@ref), [`system_mapping`](@ref), [`noise_mapping`](@ref)
726826
"""
727827
function system_mapping(P::ExtendedStateSpace, sminreal=sminreal)
728-
sminreal(ss(P.A, P.B2, P.C2, P.D22, P.timeevol))
828+
sminreal(P.sys[P.y, P.u])
729829
end
730830

731831
"""
@@ -735,7 +835,7 @@ Return the system from w -> z
735835
See also [`performance_mapping`](@ref), [`system_mapping`](@ref), [`noise_mapping`](@ref)
736836
"""
737837
function performance_mapping(P::ExtendedStateSpace, sminreal=sminreal)
738-
sminreal(ss(P.A, P.B1, P.C1, P.D11, P.timeevol))
838+
sminreal(P.sys[P.z, P.w])
739839
end
740840

741841
"""
@@ -745,5 +845,5 @@ Return the system from w -> y
745845
See also [`performance_mapping`](@ref), [`system_mapping`](@ref), [`noise_mapping`](@ref)
746846
"""
747847
function noise_mapping(P::ExtendedStateSpace, sminreal=sminreal)
748-
sminreal(ss(P.A, P.B1, P.C2, P.D21, P.timeevol))
848+
sminreal(P.sys[P.y, P.w])
749849
end

0 commit comments

Comments
 (0)