Skip to content

Commit 2d54185

Browse files
ordering states.jl
1 parent 891998f commit 2d54185

File tree

1 file changed

+42
-44
lines changed

1 file changed

+42
-44
lines changed

src/states.jl

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ julia> new_mat = states(test_mat)
5959
"""
6060
struct StandardStates <: AbstractStates end
6161

62+
function (::StandardStates)(nla_type::NonLinearAlgorithm,
63+
state, inp)
64+
return nla(nla_type, state)
65+
end
66+
67+
(::StandardStates)(state) = state
6268
"""
6369
ExtendedStates()
6470
@@ -113,12 +119,21 @@ julia> new_mat = states(test_mat, fill(3.0f0, 3))
113119
"""
114120
struct ExtendedStates <: AbstractStates end
115121

116-
struct PaddedStates{T} <: AbstractPaddedStates
117-
padding::T
122+
function (states_type::ExtendedStates)(mat::AbstractMatrix, inp::AbstractVector)
123+
results = Vector{Vector{eltype(mat)}}(undef, size(mat, 2))
124+
for (idx, col) in enumerate(eachcol(mat))
125+
results[idx] = states_type(col, inp)
126+
end
127+
return hcat(results...)
118128
end
119129

120-
struct PaddedExtendedStates{T} <: AbstractPaddedStates
121-
padding::T
130+
function (::ExtendedStates)(vect::AbstractVector, inp::AbstractVector)
131+
return x_tmp = vcat(vect, inp)
132+
end
133+
134+
function (states_type::ExtendedStates)(nla_type::NonLinearAlgorithm,
135+
state::AbstractVecOrMat, inp::AbstractVecOrMat)
136+
return nla(nla_type, states_type(state, inp))
122137
end
123138

124139
"""
@@ -170,10 +185,29 @@ julia> new_mat = states(test_mat)
170185
1.0 1.0 1.0 1.0 1.0
171186
```
172187
"""
188+
struct PaddedStates{T} <: AbstractPaddedStates
189+
padding::T
190+
end
191+
173192
function PaddedStates(; padding=1.0)
174193
return PaddedStates(padding)
175194
end
176195

196+
function (states_type::PaddedStates)(mat::AbstractMatrix)
197+
results = states_type.(eachcol(mat))
198+
return hcat(results...)
199+
end
200+
201+
function (states_type::PaddedStates)(vect::AbstractVector)
202+
tt = eltype(vect)
203+
return vcat(vect, tt(states_type.padding))
204+
end
205+
206+
function (states_type::PaddedStates)(nla_type::NonLinearAlgorithm,
207+
state::AbstractVecOrMat, inp::AbstractVecOrMat)
208+
return nla(nla_type, states_type(state))
209+
end
210+
177211
"""
178212
PaddedExtendedStates(padding)
179213
PaddedExtendedStates(;padding=1.0)
@@ -229,26 +263,12 @@ julia> new_mat = states(test_mat, fill(3.0f0, 3))
229263
3.0 3.0 3.0 3.0 3.0
230264
```
231265
"""
232-
function PaddedExtendedStates(; padding=1.0)
233-
return PaddedExtendedStates(padding)
234-
end
235-
236-
#functions of the states to apply modifications
237-
function (::StandardStates)(nla_type::NonLinearAlgorithm,
238-
state, inp)
239-
return nla(nla_type, state)
240-
end
241-
242-
(::StandardStates)(state) = state
243-
244-
function (states_type::ExtendedStates)(nla_type::NonLinearAlgorithm,
245-
state::AbstractVecOrMat, inp::AbstractVecOrMat)
246-
return nla(nla_type, states_type(state, inp))
266+
struct PaddedExtendedStates{T} <: AbstractPaddedStates
267+
padding::T
247268
end
248269

249-
function (states_type::PaddedStates)(nla_type::NonLinearAlgorithm,
250-
state::AbstractVecOrMat, inp::AbstractVecOrMat)
251-
return nla(nla_type, states_type(state))
270+
function PaddedExtendedStates(; padding=1.0)
271+
return PaddedExtendedStates(padding)
252272
end
253273

254274
function (states_type::PaddedExtendedStates)(nla_type::NonLinearAlgorithm,
@@ -263,28 +283,6 @@ function (states_type::PaddedExtendedStates)(state::AbstractVecOrMat,
263283
return x_ext
264284
end
265285

266-
function (states_type::ExtendedStates)(mat::AbstractMatrix, inp::AbstractVector)
267-
results = Vector{Vector{eltype(mat)}}(undef, size(mat, 2))
268-
for (idx, col) in enumerate(eachcol(mat))
269-
results[idx] = states_type(col, inp)
270-
end
271-
return hcat(results...)
272-
end
273-
274-
function (::ExtendedStates)(vect::AbstractVector, inp::AbstractVector)
275-
return x_tmp = vcat(vect, inp)
276-
end
277-
278-
function (states_type::PaddedStates)(mat::AbstractMatrix)
279-
results = states_type.(eachcol(mat))
280-
return hcat(results...)
281-
end
282-
283-
function (states_type::PaddedStates)(vect::AbstractVector)
284-
tt = eltype(vect)
285-
return vcat(vect, tt(states_type.padding))
286-
end
287-
288286
#### non linear algorithms ###
289287
## to conform to current (0.10.5) approach
290288
nla(nlat::NonLinearAlgorithm, x_old::AbstractVecOrMat) = nlat(x_old)

0 commit comments

Comments
 (0)