11use rand:: Rng ;
22use serde:: { Serialize , Deserialize } ;
3- use crate :: optimizer:: { SGD , Adam } ;
3+ use crate :: optimizer:: { Adam , Sgd } ;
44
55/// Weight initialization strategy
66/// - Xavier (Glorot): For layers followed by Sigmoid/Softmax. std = sqrt(2 / (n_in + n_out))
@@ -16,14 +16,14 @@ pub enum InitStrategy {
/// Numerically stable softmax used by the output layer.
struct Softmax;

impl Softmax {
    /// Converts a row of raw logits into a probability distribution.
    ///
    /// The row maximum is subtracted before exponentiating so large logits
    /// cannot overflow `exp`; the shift cancels out after normalization.
    fn forward(logits: &[f32]) -> Vec<f32> {
        let max_val = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let shifted: Vec<f32> = logits.iter().map(|&l| (l - max_val).exp()).collect();
        let total: f32 = shifted.iter().sum();
        shifted.iter().map(|&e| e / total).collect()
    }

    /// Combined softmax + cross-entropy gradient, which reduces to
    /// `preds - targets` element-wise over the batch.
    fn backward(preds: &[Vec<f32>], y: &[Vec<f32>]) -> Vec<Vec<f32>> {
        preds
            .iter()
            .zip(y)
            .map(|(p, t)| p.iter().zip(t).map(|(a, b)| a - b).collect())
            .collect()
    }
}
@@ -33,13 +33,13 @@ impl Softmax {
/// Unified trait for all model layers
///
/// Inputs and gradients are batches: one `Vec<f32>` row per sample.
/// The optimizer hooks default to no-ops so parameter-free layers
/// (activations, softmax) only implement `forward`/`backward`.
pub trait Layer {
    /// Forward pass through the layer
    fn forward(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>>;

    /// Backward pass through the layer (backpropagation)
    fn backward(&mut self, grad_output: &[Vec<f32>]) -> Vec<Vec<f32>>;

    /// Update layer parameters using Sgd
    fn update_sgd(&mut self, _optimizer: &Sgd) {}

    /// Update layer parameters using Adam
    fn update_adam(&mut self, _optimizer: &mut Adam) {}
}
@@ -112,8 +112,8 @@ impl Linear {
112112 ( -2.0 * u1. ln ( ) ) . sqrt ( ) * ( 2.0 * std:: f32:: consts:: PI * u2) . cos ( )
113113 }
114114
115- pub fn forward ( & mut self , input : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
116- self . cached_input = input. clone ( ) ;
115+ pub fn forward ( & mut self , input : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
116+ self . cached_input = input. to_owned ( ) ;
117117
118118 input. iter ( ) . map ( |x| {
119119 self . weights . iter ( ) . enumerate ( ) . map ( |( i, w) | {
@@ -122,26 +122,34 @@ impl Linear {
122122 } ) . collect ( )
123123 }
124124
125- pub fn backward ( & mut self , grad_output : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
125+ pub fn backward ( & mut self , grad_output : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
126126 let batch_size = self . cached_input . len ( ) ;
127127 let input_size = self . cached_input [ 0 ] . len ( ) ;
128128
129129 let mut grad_input = vec ! [ vec![ 0.0 ; input_size] ; batch_size] ;
130130
131- for i in 0 ..batch_size {
132- for j in 0 ..self . weights . len ( ) {
133- self . grad_bias [ j] += grad_output[ i] [ j] ;
134- for k in 0 ..input_size {
135- self . grad_weights [ j] [ k] += grad_output[ i] [ j] * self . cached_input [ i] [ k] ;
136- grad_input[ i] [ k] += grad_output[ i] [ j] * self . weights [ j] [ k] ;
131+ for ( i, grad_in_row) in grad_input. iter_mut ( ) . enumerate ( ) {
132+ let go_row = & grad_output[ i] ;
133+ let cached_row = & self . cached_input [ i] ;
134+ for ( j, ( grad_w_row, weight_row) ) in self
135+ . grad_weights
136+ . iter_mut ( )
137+ . zip ( self . weights . iter ( ) )
138+ . enumerate ( )
139+ {
140+ let g = go_row[ j] ;
141+ self . grad_bias [ j] += g;
142+ for ( k, ( gw, w) ) in grad_w_row. iter_mut ( ) . zip ( weight_row. iter ( ) ) . enumerate ( ) {
143+ * gw += g * cached_row[ k] ;
144+ grad_in_row[ k] += g * w;
137145 }
138146 }
139147 }
140148 grad_input
141149 }
142150
143- /// Update weights using SGD
144- pub fn update_sgd ( & mut self , optimizer : & SGD ) {
151+ /// Update weights using Sgd
152+ pub fn update_sgd ( & mut self , optimizer : & Sgd ) {
145153 for i in 0 ..self . weights . len ( ) {
146154 for j in 0 ..self . weights [ 0 ] . len ( ) {
147155 // L2 regularization: add lambda * weight to gradient
@@ -172,15 +180,15 @@ impl Linear {
172180}
173181
174182impl Layer for Linear {
    /// Delegates to the inherent `Linear::forward`; Rust resolves inherent
    /// methods before trait methods, so this does not recurse.
    fn forward(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        self.forward(input)
    }
178186
    /// Delegates to the inherent `Linear::backward` (inherent methods shadow
    /// trait methods, so no recursion).
    fn backward(&mut self, grad_output: &[Vec<f32>]) -> Vec<Vec<f32>> {
        self.backward(grad_output)
    }
182190
    /// Applies one Sgd step via the inherent `Linear::update_sgd`.
    fn update_sgd(&mut self, optimizer: &Sgd) {
        self.update_sgd(optimizer)
    }
186194
@@ -196,11 +204,11 @@ impl Layer for Linear {
/// Rectified Linear Unit activation: `f(x) = max(0, x)`.
pub struct ReLU;

impl ReLU {
    /// Element-wise ReLU over a batch of rows.
    pub fn forward(input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|row| row.iter().map(|&v| v.max(0.0)).collect())
            .collect()
    }

    /// Backward pass: the gradient flows through only where the original
    /// forward *input* was strictly positive.
    pub fn backward(grad: &[Vec<f32>], input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        grad.iter()
            .zip(input)
            .map(|(g_row, in_row)| {
                g_row
                    .iter()
                    .zip(in_row)
                    .map(|(&g, &v)| if v > 0.0 { g } else { 0.0 })
                    .collect()
            })
            .collect()
    }
}
/// Leaky ReLU: like ReLU, but lets a small slope (`ALPHA`) through for
/// non-positive inputs so units cannot go permanently "dead".
pub struct LeakyReLU;

impl LeakyReLU {
    /// Slope applied to non-positive inputs.
    const ALPHA: f32 = 0.01;

    /// Element-wise Leaky ReLU over a batch of rows.
    pub fn forward(input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|row| {
                row.iter()
                    .map(|&v| if v > 0.0 { v } else { Self::ALPHA * v })
                    .collect()
            })
            .collect()
    }

    /// Backward pass: the gradient is scaled by `ALPHA` wherever the original
    /// forward *input* was not strictly positive.
    pub fn backward(grad: &[Vec<f32>], input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        grad.iter()
            .zip(input)
            .map(|(g_row, in_row)| {
                g_row
                    .iter()
                    .zip(in_row)
                    .map(|(&g, &v)| if v > 0.0 { g } else { Self::ALPHA * g })
                    .collect()
            })
            .collect()
    }
}
@@ -239,13 +247,13 @@ impl LeakyReLU {
/// Logistic sigmoid activation: `f(x) = 1 / (1 + e^-x)`.
pub struct Sigmoid;

impl Sigmoid {
    /// Element-wise sigmoid over a batch of rows.
    pub fn forward(input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        input
            .iter()
            .map(|row| row.iter().map(|&v| 1.0 / (1.0 + (-v).exp())).collect())
            .collect()
    }

    /// Backward pass. Note: `out` must be the *output* of `forward`, since
    /// sigmoid's derivative is `o * (1 - o)`.
    pub fn backward(grad: &[Vec<f32>], out: &[Vec<f32>]) -> Vec<Vec<f32>> {
        grad.iter()
            .zip(out)
            .map(|(g_row, o_row)| {
                g_row
                    .iter()
                    .zip(o_row)
                    .map(|(&g, &o)| g * o * (1.0 - o))
                    .collect()
            })
            .collect()
    }
}
@@ -263,15 +271,15 @@ pub enum Activation {
263271}
264272
265273impl Activation {
266- pub fn forward ( & self , input : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
274+ pub fn forward ( & self , input : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
267275 match self {
268276 Self :: ReLU => ReLU :: forward ( input) ,
269277 Self :: LeakyReLU => LeakyReLU :: forward ( input) ,
270278 Self :: Sigmoid => Sigmoid :: forward ( input) ,
271279 }
272280 }
273281
274- pub fn backward ( & self , grad : & Vec < Vec < f32 > > , cache : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
282+ pub fn backward ( & self , grad : & [ Vec < f32 > ] , cache : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
275283 match self {
276284 Self :: ReLU => ReLU :: backward ( grad, cache) ,
277285 Self :: LeakyReLU => LeakyReLU :: backward ( grad, cache) ,
@@ -307,17 +315,17 @@ impl ActivationLayer {
307315}
308316
309317impl Layer for ActivationLayer {
310- fn forward ( & mut self , input : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
318+ fn forward ( & mut self , input : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
311319 let output = self . activation . forward ( input) ;
312320 // Sigmoid backward needs output, ReLU/LeakyReLU need input
313321 match self . activation {
314322 Activation :: Sigmoid => self . cached_data = output. clone ( ) ,
315- _ => self . cached_data = input. clone ( ) ,
323+ _ => self . cached_data = input. to_owned ( ) ,
316324 }
317325 output
318326 }
319327
320- fn backward ( & mut self , grad_output : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
328+ fn backward ( & mut self , grad_output : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
321329 self . activation . backward ( grad_output, & self . cached_data )
322330 }
323331}
@@ -335,13 +343,13 @@ impl SoftmaxLayer {
335343}
336344
337345impl Layer for SoftmaxLayer {
338- fn forward ( & mut self , input : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
346+ fn forward ( & mut self , input : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
339347 let output = input. iter ( ) . map ( |l| Softmax :: forward ( l) ) . collect ( ) ;
340348 self . cached_output = output;
341349 self . cached_output . clone ( )
342350 }
343351
344- fn backward ( & mut self , grad_output : & Vec < Vec < f32 > > ) -> Vec < Vec < f32 > > {
352+ fn backward ( & mut self , grad_output : & [ Vec < f32 > ] ) -> Vec < Vec < f32 > > {
345353 // Note: For the output layer, grad_output is expected to be the targets (Y)
346354 // due to the specific implementation of Softmax::backward returning (preds - targets)
347355 Softmax :: backward ( & self . cached_output , grad_output)
@@ -358,23 +366,23 @@ pub enum LayerWrapper {
358366}
359367
360368impl Layer for LayerWrapper {
    /// Forwards the batch to whichever concrete layer this wrapper holds.
    fn forward(&mut self, input: &[Vec<f32>]) -> Vec<Vec<f32>> {
        match self {
            Self::Linear(l) => l.forward(input),
            Self::Activation(a) => a.forward(input),
            Self::Softmax(s) => s.forward(input),
        }
    }
368376
    /// Routes the gradient back through the wrapped layer's backward pass.
    fn backward(&mut self, grad_output: &[Vec<f32>]) -> Vec<Vec<f32>> {
        match self {
            Self::Linear(l) => l.backward(grad_output),
            Self::Activation(a) => a.backward(grad_output),
            Self::Softmax(s) => s.backward(grad_output),
        }
    }
376384
    /// Only the `Linear` variant has parameters; the other variants are
    /// no-ops for an Sgd step.
    fn update_sgd(&mut self, optimizer: &Sgd) {
        if let Self::Linear(l) = self {
            l.update_sgd(optimizer);
        }
    }
0 commit comments