Skip to content

Commit ae96036

Browse files
qr for ridge regression
1 parent 7c986f3 commit ae96036

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

src/ReservoirComputing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ReservoirComputing
33
using Adapt: adapt
44
using CellularAutomata: CellularAutomaton
55
using Compat: @compat
6-
using LinearAlgebra: eigvals, mul!, I
6+
using LinearAlgebra: eigvals, mul!, I, qr
77
using NNlib: fast_act, sigmoid
88
using Random: Random, AbstractRNG
99
using Reexport: Reexport, @reexport

src/train/linear_regression.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,11 @@ function StandardRidge()
3030
end
3131

3232
function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
33-
#A = states * states' + sr.reg * I
34-
#b = states * target_data
35-
#output_layer = (A \ b)'
36-
37-
if size(states, 2) != size(target_data, 2)
38-
throw(DimensionMismatch("\n" *
39-
"\n" *
40-
" - Number of columns in `states`: $(size(states, 2))\n" *
41-
" - Number of columns in `target_data`: $(size(target_data, 2))\n" *
42-
"The dimensions of `states` and `target_data` must align for training." *
43-
"\n"
44-
))
45-
end
46-
47-
output_layer = Matrix(((states * states' + sr.reg * I) \
48-
(states * target_data'))')
33+
n_states = size(states, 1)
34+
A = [states'; sqrt(sr.reg) * I(n_states)]
35+
b = [target_data'; zeros(n_states, size(target_data, 1))]
36+
F = qr(A)
37+
Wt = F \ b
38+
output_layer = Matrix(Wt')
4939
return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
5040
end

0 commit comments

Comments
 (0)