@doc raw"""
    StandardRidge([Type], [reg])

Returns a training method for `train` based on ridge regression.
The equation for ridge regression is as follows:

```math
\mathbf{w} = (\mathbf{X}^\top \mathbf{X} +
\lambda \mathbf{I})^{-1} \mathbf{X}^\top \mathbf{y}
```

# Arguments
- `Type`: type of the regularization coefficient. The default is inferred
  internally, so there is usually no need to set it explicitly.
- `reg`: regularization coefficient. Defaults to 0.0, which corresponds to
  ordinary linear regression.

# Examples
```jldoctest
julia> ridge_reg = StandardRidge()
StandardRidge(0.0)

julia> ol = train(ridge_reg, rand(Float32, 10, 10), rand(Float32, 10, 10))
OutputLayer successfully trained with output size: 10

julia> ol.output_matrix # visualize output matrix
10×10 Matrix{Float32}:
  0.456574   -0.0407612   0.121963   …   0.859327   -0.127494    0.0572494
  0.133216   -0.0337922   0.0185378      0.24077     0.0297829   0.31512
  0.379672   -1.24541    -0.444314       1.02269    -0.0446086   0.482282
  1.18455    -0.517971   -0.133498       0.84473     0.31575     0.205857
 -0.119345    0.563294    0.747992       0.0102919   1.509      -0.328005
 -0.0716812   0.0976365   0.628654   …  -0.516041    2.4309     -0.113402
  0.0153872  -0.52334     0.0526867      0.729326    2.98958     1.32703
  0.154027    0.6013      1.05548       -0.0840203   0.991182   -0.328555
  1.11007    -0.0371736  -0.0529418      0.186796   -1.21815     0.204838
  0.282996   -0.263799    0.132079       0.875417    0.497951    0.273423

julia> ridge_reg = StandardRidge(0.001) # passing a value
StandardRidge(0.001)

julia> ol = train(ridge_reg, rand(Float16, 10, 10), rand(Float16, 10, 10))
OutputLayer successfully trained with output size: 10

julia> ol.output_matrix
10×10 Matrix{Float16}:
 -1.251     3.074    -1.566     -0.10297  …   0.3823   1.341    -1.77     -0.445
  0.11017  -2.027     0.8975     0.872       -0.643    0.02615   1.083     0.615
  0.2634    3.514    -1.168     -1.532        1.486    0.1255   -1.795    -0.06555
  0.964     0.9463   -0.006855  -0.519        0.0743  -0.181    -0.433     0.06793
 -0.389     1.887    -0.702     -0.8906       0.221    1.303    -1.318     0.2634
 -0.1337   -0.4453   -0.06866    0.557    …  -0.322    0.247     0.2554    0.5933
 -0.6724    0.906    -0.547      0.697       -0.2664   0.809    -0.6836    0.2358
  0.8843   -3.664     1.615      1.417       -0.6094  -0.59      1.975     0.4785
  1.266    -0.933     0.0664    -0.4497      -0.0759  -0.03897   1.117     0.3152
  0.6353    1.327    -0.6978    -1.053        0.8037   0.6577   -0.7246    0.07336

```
"""
struct StandardRidge
    reg::Number
end

function StandardRidge()
    return StandardRidge(0.0)
end

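# For reference, a minimal standalone sketch (illustrative only, not part of the
# package API) of the closed-form solve that `train` below performs. `X`, `Y`,
# `λ`, and the sizes are assumptions for the example; states are laid out as
# features × samples, targets as outputs × samples:
#
#     using LinearAlgebra
#
#     X = rand(Float32, 10, 100)          # states: 10 features, 100 samples
#     Y = rand(Float32, 3, 100)           # targets: 3 outputs, 100 samples
#     λ = 0.001f0                         # regularization coefficient
#     W = ((X * X' + λ * I) \ (X * Y'))'  # 3×10 readout, same shape as `output_matrix`
#
# Solving with `\` avoids forming an explicit matrix inverse.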
function train(sr::StandardRidge, states::AbstractArray, target_data::AbstractArray)
    # A = states * states' + sr.reg * I
    # b = states * target_data'
    # output_layer = (A \ b)'

    if size(states, 2) != size(target_data, 2)
        throw(DimensionMismatch("\n" *
            "\n" *
            "  - Number of columns in `states`: $(size(states, 2))\n" *
            "  - Number of columns in `target_data`: $(size(target_data, 2))\n" *
            "  The dimensions of `states` and `target_data` must align for training." *
            "\n"
        ))
    end

    T = eltype(states)
    output_layer = Matrix(((states * states' + T(sr.reg) * I) \
                           (states * target_data'))')
    return OutputLayer(sr, output_layer, size(target_data, 1), target_data[:, end])
end
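# Usage sketch with hypothetical sizes: columns count samples, so the two random
# matrices below deliberately disagree (50 vs. 49 columns) and the guard above
# should raise a DimensionMismatch before the linear solve is attempted.
#
#     try
#         train(StandardRidge(0.001), rand(Float32, 10, 50), rand(Float32, 3, 49))
#     catch err
#         err isa DimensionMismatch && println("caught expected mismatch")
#     end
#
# Note that the coefficient is applied as `T(sr.reg)`, i.e. converted to the
# element type of `states`, so a `Float64` value such as the default 0.0 does not
# promote `Float32` or `Float16` states; the `Float16` doctest above accordingly
# returns a `Matrix{Float16}`.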