@@ -22,8 +22,8 @@ specified reservoir driver.
2222 update.
2323"""
2424function create_states (reservoir_driver:: AbstractReservoirDriver ,
25- train_data:: AbstractArray , washout:: Int , reservoir_matrix:: AbstractMatrix ,
26- input_matrix:: AbstractMatrix , bias_vector:: AbstractArray )
25+ train_data:: AbstractArray{T,2} , washout:: Int , reservoir_matrix:: AbstractMatrix ,
26+ input_matrix:: AbstractMatrix , bias_vector:: AbstractArray ) where {T <: Number }
2727 train_len = size (train_data, 2 ) - washout
2828 res_size = size (reservoir_matrix, 1 )
2929 states = adapt (typeof (train_data), zeros (res_size, train_len))
@@ -32,6 +32,7 @@ function create_states(reservoir_driver::AbstractReservoirDriver,
3232
3333 for i in 1 : washout
3434 yv = @view train_data[:, i]
35+ @show typeof (yv)
3536 _state = next_state! (_state, reservoir_driver, _state, yv, reservoir_matrix,
3637 input_matrix, bias_vector, tmp_array)
3738 end
@@ -47,8 +48,8 @@ function create_states(reservoir_driver::AbstractReservoirDriver,
4748end
4849
4950function create_states (reservoir_driver:: AbstractReservoirDriver ,
50- train_data:: AbstractArray , washout:: Int , reservoir_matrix:: Vector ,
51- input_matrix:: AbstractArray , bias_vector:: AbstractArray )
51+ train_data:: AbstractArray{T,2} , washout:: Int , reservoir_matrix:: Vector ,
52+ input_matrix:: AbstractArray , bias_vector:: AbstractArray ) where {T <: Number }
5253 train_len = size (train_data, 2 ) - washout
5354 res_size = sum ([size (reservoir_matrix[i], 1 ) for i in 1 : length (reservoir_matrix)])
5455 states = adapt (typeof (train_data), zeros (res_size, train_len))
@@ -357,14 +358,19 @@ function obtain_gru_state!(out, variant::FullyGated, gru, x, y, W, W_in, b, tmp_
357358end
358359
359360# minimal
361+ #=
360362function obtain_gru_state!(out, variant::Minimal, gru, x, y, W, W_in, b, tmp_array)
361363 mul!(tmp_array[1], gru.Wz_in, y)
362364 mul!(tmp_array[2], gru.Wz, x)
363365 @. tmp_array[3] = gru.activation_function[1](tmp_array[1] + tmp_array[2] + gru.bz)
364366
367+ mul!(tmp_array[4], gru.Wr_in, y)
368+ mul!(tmp_array[5], gru.Wr, x)
369+ @. tmp_array[6] = gru.activation_function[2](tmp_array[4] + tmp_array[5] + gru.br)
370+
365371 mul!(tmp_array[7], W_in, y)
366372 mul!(tmp_array[8], W, tmp_array[6] .* x)
367- @. tmp_array[9 ] = gru. activation_function[2 ](tmp_array[7 ] + tmp_array[8 ] + b)
368-
369- return @. out = (1 - tmp_array[3 ]) * x + tmp_array[3 ] * tmp_array[6 ]
373+ @. tmp_array[9] = gru.activation_function[3](tmp_array[7] + tmp_array[8] + b)
374+ return @. out = (1 - tmp_array[3]) * x + tmp_array[3] * tmp_array[9]
370375end
376+ =#
0 commit comments