Skip to content

Commit 891998f

Browse files
fixed abstractstates and added examples to docstring
1 parent 29af05f commit 891998f

File tree

1 file changed

+229
-54
lines changed

1 file changed

+229
-54
lines changed

src/states.jl

Lines changed: 229 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,100 @@ end
1616
"""
1717
StandardStates()
1818
19-
When this struct is employed, the states of the reservoir are not modified. It represents the default behavior
20-
in scenarios where no specific state modification is required. This approach is ideal for applications
21-
where the inherent dynamics of the reservoir are sufficient, and no external manipulation of the states
22-
is necessary. It maintains the original state representation, ensuring that the reservoir's natural properties
23-
are preserved and utilized in computations.
19+
When this struct is employed, the states of the reservoir are not modified.
20+
21+
# Example
22+
23+
```jldoctest
24+
julia> states = StandardStates()
25+
StandardStates()
26+
27+
julia> test_vec = zeros(Float32, 5)
28+
5-element Vector{Float32}:
29+
0.0
30+
0.0
31+
0.0
32+
0.0
33+
0.0
34+
35+
julia> new_vec = states(test_vec)
36+
5-element Vector{Float32}:
37+
0.0
38+
0.0
39+
0.0
40+
0.0
41+
0.0
42+
43+
julia> test_mat = zeros(Float32, 5, 5)
44+
5×5 Matrix{Float32}:
45+
0.0 0.0 0.0 0.0 0.0
46+
0.0 0.0 0.0 0.0 0.0
47+
0.0 0.0 0.0 0.0 0.0
48+
0.0 0.0 0.0 0.0 0.0
49+
0.0 0.0 0.0 0.0 0.0
50+
51+
julia> new_mat = states(test_mat)
52+
5×5 Matrix{Float32}:
53+
0.0 0.0 0.0 0.0 0.0
54+
0.0 0.0 0.0 0.0 0.0
55+
0.0 0.0 0.0 0.0 0.0
56+
0.0 0.0 0.0 0.0 0.0
57+
0.0 0.0 0.0 0.0 0.0
58+
```
2459
"""
2560
struct StandardStates <: AbstractStates end
2661

2762
"""
2863
ExtendedStates()
2964
30-
The `ExtendedStates` struct is used to extend the reservoir states by
31-
vertically concatenating the input data (during training) and the prediction data (during the prediction phase).
32-
This method enriches the state representation by integrating external data, enhancing the model's capability
33-
to capture and utilize complex patterns in both training and prediction stages.
65+
The `ExtendedStates` struct is used to extend the reservoir
66+
states by vertically concatenating the input data (during training)
67+
and the prediction data (during the prediction phase).
68+
69+
# Example
70+
71+
```jldoctest
72+
julia> states = ExtendedStates()
73+
ExtendedStates()
74+
75+
julia> test_vec = zeros(Float32, 5)
76+
5-element Vector{Float32}:
77+
0.0
78+
0.0
79+
0.0
80+
0.0
81+
0.0
82+
83+
julia> new_vec = states(test_vec, fill(3.0f0, 3))
84+
8-element Vector{Float32}:
85+
0.0
86+
0.0
87+
0.0
88+
0.0
89+
0.0
90+
3.0
91+
3.0
92+
3.0
93+
94+
julia> test_mat = zeros(Float32, 5, 5)
95+
5×5 Matrix{Float32}:
96+
0.0 0.0 0.0 0.0 0.0
97+
0.0 0.0 0.0 0.0 0.0
98+
0.0 0.0 0.0 0.0 0.0
99+
0.0 0.0 0.0 0.0 0.0
100+
0.0 0.0 0.0 0.0 0.0
101+
102+
julia> new_mat = states(test_mat, fill(3.0f0, 3))
103+
8×5 Matrix{Float32}:
104+
0.0 0.0 0.0 0.0 0.0
105+
0.0 0.0 0.0 0.0 0.0
106+
0.0 0.0 0.0 0.0 0.0
107+
0.0 0.0 0.0 0.0 0.0
108+
0.0 0.0 0.0 0.0 0.0
109+
3.0 3.0 3.0 3.0 3.0
110+
3.0 3.0 3.0 3.0 3.0
111+
3.0 3.0 3.0 3.0 3.0
112+
```
34113
"""
35114
struct ExtendedStates <: AbstractStates end
36115

@@ -46,12 +125,50 @@ end
46125
PaddedStates(padding)
47126
PaddedStates(;padding=1.0)
48127
49-
Creates an instance of the `PaddedStates` struct with specified padding value.
50-
This padding is typically set to 1.0 by default but can be customized.
51-
The states of the reservoir are padded by vertically concatenating this padding value,
52-
enhancing the dimensionality and potentially improving the performance of the reservoir computing model.
53-
This function is particularly useful in scenarios where adding a constant baseline to the states is necessary
54-
for the desired computational task.
128+
Creates an instance of the `PaddedStates` struct with specified
129+
padding value (default 1.0). The states of the reservoir are padded
130+
by vertically concatenating the padding value.
131+
132+
# Example
133+
134+
```jldoctest
135+
julia> states = PaddedStates(1.0)
136+
PaddedStates{Float64}(1.0)
137+
138+
julia> test_vec = zeros(Float32, 5)
139+
5-element Vector{Float32}:
140+
0.0
141+
0.0
142+
0.0
143+
0.0
144+
0.0
145+
146+
julia> new_vec = states(test_vec)
147+
6-element Vector{Float32}:
148+
0.0
149+
0.0
150+
0.0
151+
0.0
152+
0.0
153+
1.0
154+
155+
julia> test_mat = zeros(Float32, 5, 5)
156+
5×5 Matrix{Float32}:
157+
0.0 0.0 0.0 0.0 0.0
158+
0.0 0.0 0.0 0.0 0.0
159+
0.0 0.0 0.0 0.0 0.0
160+
0.0 0.0 0.0 0.0 0.0
161+
0.0 0.0 0.0 0.0 0.0
162+
163+
julia> new_mat = states(test_mat)
164+
6×5 Matrix{Float32}:
165+
0.0 0.0 0.0 0.0 0.0
166+
0.0 0.0 0.0 0.0 0.0
167+
0.0 0.0 0.0 0.0 0.0
168+
0.0 0.0 0.0 0.0 0.0
169+
0.0 0.0 0.0 0.0 0.0
170+
1.0 1.0 1.0 1.0 1.0
171+
```
55172
"""
56173
function PaddedStates(; padding=1.0)
57174
return PaddedStates(padding)
@@ -61,49 +178,118 @@ end
61178
PaddedExtendedStates(padding)
62179
PaddedExtendedStates(;padding=1.0)
63180
64-
Constructs a `PaddedExtendedStates` struct, which first extends the reservoir states with training or prediction data,
65-
then pads them with a specified value (defaulting to 1.0). This process is achieved through vertical concatenation,
66-
combining the padding value, data, and states.
67-
This function is particularly useful for enhancing the reservoir's state representation in more complex scenarios,
68-
where both extended contextual information and consistent baseline padding are crucial for the computational
69-
effectiveness of the reservoir computing model.
181+
Constructs a `PaddedExtendedStates` struct, which first extends
182+
the reservoir states with training or prediction data,then pads them
183+
with a specified value (defaulting to 1.0).
184+
185+
# Example
186+
187+
```jldoctest
188+
julia> states = PaddedExtendedStates(1.0)
189+
PaddedExtendedStates{Float64}(1.0)
190+
191+
julia> test_vec = zeros(Float32, 5)
192+
5-element Vector{Float32}:
193+
0.0
194+
0.0
195+
0.0
196+
0.0
197+
0.0
198+
199+
julia> new_vec = states(test_vec, fill(3.0f0, 3))
200+
9-element Vector{Float32}:
201+
0.0
202+
0.0
203+
0.0
204+
0.0
205+
0.0
206+
1.0
207+
3.0
208+
3.0
209+
3.0
210+
211+
julia> test_mat = zeros(Float32, 5, 5)
212+
5×5 Matrix{Float32}:
213+
0.0 0.0 0.0 0.0 0.0
214+
0.0 0.0 0.0 0.0 0.0
215+
0.0 0.0 0.0 0.0 0.0
216+
0.0 0.0 0.0 0.0 0.0
217+
0.0 0.0 0.0 0.0 0.0
218+
219+
julia> new_mat = states(test_mat, fill(3.0f0, 3))
220+
9×5 Matrix{Float32}:
221+
0.0 0.0 0.0 0.0 0.0
222+
0.0 0.0 0.0 0.0 0.0
223+
0.0 0.0 0.0 0.0 0.0
224+
0.0 0.0 0.0 0.0 0.0
225+
0.0 0.0 0.0 0.0 0.0
226+
1.0 1.0 1.0 1.0 1.0
227+
3.0 3.0 3.0 3.0 3.0
228+
3.0 3.0 3.0 3.0 3.0
229+
3.0 3.0 3.0 3.0 3.0
230+
```
70231
"""
71232
function PaddedExtendedStates(; padding=1.0)
72233
return PaddedExtendedStates(padding)
73234
end
74235

75236
#functions of the states to apply modifications
76-
function (::StandardStates)(nla_type, x, y)
77-
return nla(nla_type, x)
237+
function (::StandardStates)(nla_type::NonLinearAlgorithm,
238+
state, inp)
239+
return nla(nla_type, state)
78240
end
79241

80-
function (::ExtendedStates)(nla_type, x, y)
81-
x_tmp = vcat(y, x)
82-
return nla(nla_type, x_tmp)
242+
(::StandardStates)(state) = state
243+
244+
function (states_type::ExtendedStates)(nla_type::NonLinearAlgorithm,
245+
state::AbstractVecOrMat, inp::AbstractVecOrMat)
246+
return nla(nla_type, states_type(state, inp))
83247
end
84248

85-
#check matrix/vector
86-
function (states_type::PaddedStates)(nla_type, x, y)
87-
tt = typeof(first(x))
88-
x_tmp = vcat(fill(tt(states_type.padding), (1, size(x, 2))), x)
89-
#x_tmp = reduce(vcat, x_tmp)
90-
return nla(nla_type, x_tmp)
249+
function (states_type::PaddedStates)(nla_type::NonLinearAlgorithm,
250+
state::AbstractVecOrMat, inp::AbstractVecOrMat)
251+
return nla(nla_type, states_type(state))
91252
end
92253

93-
#check matrix/vector
94-
function (states_type::PaddedExtendedStates)(nla_type, x, y)
95-
tt = typeof(first(x))
96-
x_tmp = vcat(y, x)
97-
x_tmp = vcat(fill(tt(states_type.padding), (1, size(x, 2))), x_tmp)
98-
#x_tmp = reduce(vcat, x_tmp)
99-
return nla(nla_type, x_tmp)
254+
function (states_type::PaddedExtendedStates)(nla_type::NonLinearAlgorithm,
255+
state::AbstractVecOrMat, inp::AbstractVecOrMat)
256+
return nla(nla_type, states_type(state, inp))
100257
end
101258

102-
#non linear algorithms
103-
## conform to current (0.10.5) approach
104-
nla(nlat::NonLinearAlgorithm, x_old) = nlat(x_old)
259+
function (states_type::PaddedExtendedStates)(state::AbstractVecOrMat,
260+
inp::AbstractVecOrMat)
261+
x_pad = PaddedStates(states_type.padding)(state)
262+
x_ext = ExtendedStates()(x_pad, inp)
263+
return x_ext
264+
end
265+
266+
function (states_type::ExtendedStates)(mat::AbstractMatrix, inp::AbstractVector)
267+
results = Vector{Vector{eltype(mat)}}(undef, size(mat, 2))
268+
for (idx, col) in enumerate(eachcol(mat))
269+
results[idx] = states_type(col, inp)
270+
end
271+
return hcat(results...)
272+
end
273+
274+
function (::ExtendedStates)(vect::AbstractVector, inp::AbstractVector)
275+
return x_tmp = vcat(vect, inp)
276+
end
277+
278+
function (states_type::PaddedStates)(mat::AbstractMatrix)
279+
results = states_type.(eachcol(mat))
280+
return hcat(results...)
281+
end
105282

106-
## dispatch over matrices for all nonlin algorithms
283+
function (states_type::PaddedStates)(vect::AbstractVector)
284+
tt = eltype(vect)
285+
return vcat(vect, tt(states_type.padding))
286+
end
287+
288+
#### non linear algorithms ###
289+
## to conform to current (0.10.5) approach
290+
nla(nlat::NonLinearAlgorithm, x_old::AbstractVecOrMat) = nlat(x_old)
291+
292+
# dispatch over matrices for all nonlin algorithms
107293
function (nlat::NonLinearAlgorithm)(x_old::AbstractMatrix)
108294
results = nlat.(eachcol(x_old))
109295
return hcat(results...)
@@ -278,17 +464,6 @@ function (::NLAT1)(x_old::AbstractVector)
278464
return x_new
279465
end
280466

281-
function nla(::NLAT1, x_old)
282-
x_new = copy(x_old)
283-
for i in 1:size(x_new, 1)
284-
if mod(i, 2) != 0
285-
x_new[i, :] = copy(x_old[i, :] .* x_old[i, :])
286-
end
287-
end
288-
289-
return x_new
290-
end
291-
292467
@doc raw"""
293468
NLAT2()
294469

0 commit comments

Comments
 (0)