@@ -2,10 +2,10 @@ abstract type AbstractReservoirCollectionLayer <: AbstractLuxLayer end
 abstract type AbstractReservoirRecurrentCell <: AbstractLuxLayer end
 abstract type AbstractReservoirTrainableLayer <: AbstractLuxLayer end
 
-# ## Readout
+# ## LinearReadout
 # adapted from lux layers/basic Dense
 @doc raw"""
-    Readout(in_dims => out_dims, [activation];
+    LinearReadout(in_dims => out_dims, [activation];
         use_bias=false, include_collect=true)
 
 Linear readout layer with optional bias and elementwise activation. Intended as
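
For orientation on the renamed API, here is a minimal usage sketch of the constructor signature documented above; the dimensions and the `tanh` activation are illustrative, and the `using` line assumes the package that exports `LinearReadout`:

```julia
using ReservoirComputing  # assumption: the package exporting LinearReadout

# 300 collected features mapped to 3 outputs; identity activation, no bias by default
ro = LinearReadout(300 => 3)

# elementwise activation plus a trainable bias, per the keyword defaults above
ro_tanh = LinearReadout(300 => 3, tanh; use_bias=true)
```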
@@ -48,7 +48,7 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).
     Otherwise training may operate on the post-readout signal,
     which is usually unintended.
 """
-@concrete struct Readout <: AbstractReservoirTrainableLayer
+@concrete struct LinearReadout <: AbstractReservoirTrainableLayer
     activation
     in_dims <: IntegerType
     out_dims <: IntegerType
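
The `include_collect` field carried by the struct drives the behavior warned about in the note above. A hedged sketch of the two placements (the `Collect` layer follows the docstring example further down in this file):

```julia
# default: include_collect=true logically inserts Collect() right before the
# readout, so training sees the pre-readout features
ro_implicit = LinearReadout(300 => 3)

# explicit: place Collect() in the chain yourself and switch the implicit one off
ro_explicit = LinearReadout(300 => 3; include_collect=false)
```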
@@ -58,32 +58,32 @@ before this layer (logically inserting a [`Collect()`](@ref) right before it).
     include_collect <: StaticBool
 end
 
-function Readout(mapping::Pair{<:IntegerType,<:IntegerType}, activation=identity; kwargs...)
-    return Readout(first(mapping), last(mapping), activation; kwargs...)
+function LinearReadout(mapping::Pair{<:IntegerType,<:IntegerType}, activation=identity; kwargs...)
+    return LinearReadout(first(mapping), last(mapping), activation; kwargs...)
 end
 
-function Readout(in_dims::IntegerType, out_dims::IntegerType, activation=identity;
+function LinearReadout(in_dims::IntegerType, out_dims::IntegerType, activation=identity;
         init_weight=rand32, init_bias=rand32, include_collect::BoolType=True(),
         use_bias::BoolType=False())
-    return Readout(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias), static(include_collect))
+    return LinearReadout(activation, in_dims, out_dims, init_weight, init_bias, static(use_bias), static(include_collect))
 end
 
-function initialparameters(rng::AbstractRNG, ro::Readout)
+function initialparameters(rng::AbstractRNG, ro::LinearReadout)
     weight = ro.init_weight(rng, ro.out_dims, ro.in_dims)
 
     if has_bias(ro)
-        return (; weight, bias=ro.init_bias(rng, Float32, ro.out_dims))
+        return (; weight, bias=ro.init_bias(rng, ro.out_dims))
     else
         return (; weight)
     end
 end
 
-parameterlength(ro::Readout) = ro.out_dims * ro.in_dims + has_bias(ro) * ro.out_dims
-statelength(ro::Readout) = 0
+parameterlength(ro::LinearReadout) = ro.out_dims * ro.in_dims + has_bias(ro) * ro.out_dims
+statelength(ro::LinearReadout) = 0
 
-outputsize(ro::Readout, _, ::AbstractRNG) = (ro.out_dims,)
+outputsize(ro::LinearReadout, _, ::AbstractRNG) = (ro.out_dims,)
 
-function (ro::Readout)(inp::AbstractArray, ps, st::NamedTuple)
+function (ro::LinearReadout)(inp::AbstractArray, ps, st::NamedTuple)
     out_tmp = ps.weight * inp
     if has_bias(ro)
         out_tmp += ps.bias
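
A sketch of what this hunk implies for parameter shapes, assuming the layer participates in the usual `Lux.setup` flow for `AbstractLuxLayer` subtypes (note the bias is now built as `init_bias(rng, out_dims)`, without the previously hard-coded `Float32`):

```julia
using Lux, Random

rng = Random.default_rng()
ro = LinearReadout(300 => 3; use_bias=true)
ps, st = Lux.setup(rng, ro)

size(ps.weight)           # (3, 300): init_weight(rng, out_dims, in_dims)
size(ps.bias)             # (3,):     init_bias(rng, out_dims)
Lux.parameterlength(ro)   # 3 * 300 + 3 == 903
```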
@@ -92,8 +92,8 @@ function (ro::Readout)(inp::AbstractArray, ps, st::NamedTuple)
     return output, st
 end
 
-function Base.show(io::IO, ro::Readout)
-    print(io, "Readout($(ro.in_dims) => $(ro.out_dims)")
+function Base.show(io::IO, ro::LinearReadout)
+    print(io, "LinearReadout($(ro.in_dims) => $(ro.out_dims)")
     (ro.activation == identity) || print(io, ", $(ro.activation)")
     has_bias(ro) || print(io, ", use_bias=false")
     ic = known(getproperty(ro, Val(:include_collect)))
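
And the forward pass from the previous hunk in use, applied to a single collected feature vector; shapes are illustrative, and `ro`/`ps`/`st` come from the setup sketch above:

```julia
x = rand(Float32, 300)   # one collected reservoir state
y, _ = ro(x, ps, st)     # affine map ps.weight * x (+ bias), then the activation
size(y)                  # (3,)
```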
@@ -136,7 +136,7 @@ vectors are concatenated with `vcat` in order of appearance.
 
 ## Notes
 
-- When used with a single `Collect()` before a [`Readout`](@ref), training uses exactly
+- When used with a single `Collect()` before a [`LinearReadout`](@ref), training uses exactly
   the tensor right before the readout (e.g., the reservoir state).
 - With **multiple** `Collect()` layers (e.g., after different submodules), the
   per-step features are `vcat`-ed in chain order to form one feature vector.
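
The `vcat` ordering in the second note can be pictured directly; a toy sketch with two hypothetical per-step features collected by separate `Collect()` layers:

```julia
h1 = Float32[1.0, 2.0]         # collected after the first submodule
h2 = Float32[3.0, 4.0, 5.0]    # collected after the second submodule
vcat(h1, h2)                   # 5-element feature vector, in chain order
```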
@@ -150,7 +150,7 @@ vectors are concatenated with `vcat` in order of appearance.
     StatefulLayer(ESNCell(3 => 300)),
     NLAT2(),
     Collect(), # <-- collect the 300-dim reservoir after NLAT2
-    Readout(300 => 3; include_collect=false) # <-- toggle off the default Collect()
+    LinearReadout(300 => 3; include_collect=false) # <-- toggle off the default Collect()
 )
 ```
 """
@@ -173,7 +173,7 @@ in a step, the feature defaults to the final vector exiting the chain for
 that time step.
 
 !!! note
-    If your [`Readout`](@ref) layer was created with `include_collect=true`
+    If your [`LinearReadout`](@ref) layer was created with `include_collect=true`
     (default behaviour), a collection point is placed immediately before the readout,
     so the collected features are the inputs to the readout.
 
@@ -209,7 +209,12 @@ function collectstates(rc::AbstractLuxLayer, data::AbstractMatrix, ps, st::Named
         end
         push!(collected, state_vec === nothing ? copy(inp_tmp) : state_vec)
     end
-    states = eltype(data).(reduce(hcat, collected))
+    @assert !isempty(collected)
+    firstcol = collected[1]
+    Tcol = eltype(firstcol)
+    empty_mat = zeros(Tcol, length(firstcol), 0)
+    states_raw = reduce(hcat, collected; init=empty_mat)
+    states = eltype(data).(states_raw)
     return states, newst
 end
 
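Why the `init=empty_mat` keyword matters in the new `collectstates` body: `reduce(hcat, xs)` on a one-element collection returns the element itself (a `Vector`, not a `Matrix`), while seeding the reduction with an empty `length(firstcol) × 0` matrix forces a matrix result of the right element type in every case. A standalone sketch of the difference:

```julia
collected = [Float32[1.0, 2.0, 3.0]]        # a single collected state vector

reduce(hcat, collected)                     # 3-element Vector{Float32}

empty_mat = zeros(Float32, 3, 0)
reduce(hcat, collected; init=empty_mat)     # 3×1 Matrix{Float32}
```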