diff --git a/src/learning/optim/grad_desc.rs b/src/learning/optim/grad_desc.rs
index 1e114877..c1e010c9 100644
--- a/src/learning/optim/grad_desc.rs
+++ b/src/learning/optim/grad_desc.rs
@@ -107,6 +107,8 @@ pub struct StochasticGD {
     mu: f64,
     /// The number of passes through the data.
    iters: usize,
+    /// Whether to use Nesterov momentum.
+    nesterov_momentum: bool,
 }
 
 /// The default Stochastic GD algorithm.
@@ -116,12 +118,14 @@ pub struct StochasticGD {
 /// - alpha = 0.1
 /// - mu = 0.1
 /// - iters = 20
+/// - nesterov_momentum = false
 impl Default for StochasticGD {
     fn default() -> StochasticGD {
         StochasticGD {
             alpha: 0.1,
             mu: 0.1,
             iters: 20,
+            nesterov_momentum: false,
         }
     }
 }
@@ -131,8 +135,6 @@ impl StochasticGD {
     ///
     /// Requires the learning rate, momentum rate and iteration count
     /// to be specified.
-    ///
-    /// With Nesterov momentum by default.
     ///
     /// # Examples
     ///
@@ -149,8 +151,23 @@ impl StochasticGD {
             alpha: alpha,
             mu: mu,
             iters: iters,
+            nesterov_momentum: false,
         }
     }
+
+    /// Enable Nesterov momentum for the stochastic gradient descent algorithm.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use rusty_machine::learning::optim::grad_desc::StochasticGD;
+    ///
+    /// let sgd = StochasticGD::new(0.1, 0.3, 5).with_nesterov_momentum();
+    /// ```
+    pub fn with_nesterov_momentum(mut self) -> StochasticGD {
+        self.nesterov_momentum = true;
+        self
+    }
 }
 
 impl OptimAlgorithm for StochasticGD
@@ -184,13 +201,21 @@ impl OptimAlgorithm for StochasticGD
                                               &inputs.select_rows(&[*i]),
                                               &targets.select_rows(&[*i]));
 
-            // Backup previous velocity
-            let prev_w = delta_w.clone();
-            // Compute the difference in gradient using Nesterov momentum
-            delta_w = Vector::new(vec_data) * self.mu + &delta_w * self.alpha;
-            // Update the parameters
-            optimizing_val = &optimizing_val -
-                             (&prev_w * (-self.alpha) + &delta_w * (1. + self.alpha));
+            if self.nesterov_momentum {
+                // Backup previous velocity
+                let prev_w = delta_w.clone();
+                // Compute the difference in gradient using Nesterov momentum
+                delta_w = Vector::new(vec_data) * self.mu + &delta_w * self.alpha;
+                // Update the parameters
+                optimizing_val = &optimizing_val -
+                                 (&prev_w * (-self.alpha) + &delta_w * (1. + self.alpha));
+            } else {
+                // Compute the difference in gradient using standard momentum
+                delta_w = Vector::new(vec_data) * self.mu + &delta_w * self.alpha;
+                // Update the parameters
+                optimizing_val = &optimizing_val - &delta_w * self.mu;
+            }
+
             // Set the end cost (this is only used after the last iteration)
             end_cost += cost;
         }
diff --git a/tests/learning/optim/grad_desc.rs b/tests/learning/optim/grad_desc.rs
index 9dd1281a..31b0890a 100644
--- a/tests/learning/optim/grad_desc.rs
+++ b/tests/learning/optim/grad_desc.rs
@@ -58,7 +58,22 @@ fn convex_gd_training() {
 fn convex_stochastic_gd_training() {
     let x_sq = XSqModel { c: 20f64 };
 
-    let gd = StochasticGD::new(0.9f64, 0.1f64, 100);
+    let gd = StochasticGD::new(0.5f64, 1f64, 100);
     let test_data = vec![100f64];
     let params = gd.optimize(&x_sq,
                              &test_data[..],
                              &Matrix::zeros(100, 1),
                              &Matrix::zeros(100, 1));
 
     assert!(params[0] - 20f64 < 1e-10);
     assert!(x_sq.compute_grad(&params, &Matrix::zeros(1, 1), &Matrix::zeros(1, 1)).0 < 1e-10);
 }
 
+#[test]
+fn convex_stochastic_gd_nesterov_momentum_training() {
+    let x_sq = XSqModel { c: 20f64 };
+
+    let gd = StochasticGD::new(0.9f64, 0.1f64, 100).with_nesterov_momentum();
+    let test_data = vec![100f64];
+    let params = gd.optimize(&x_sq,
+                             &test_data[..],
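
For readers comparing the two branches introduced above, here is a standalone scalar sketch of the update rules the new `nesterov_momentum` flag switches between. It is not part of the patch: the function names (`momentum_step`, `nesterov_step`) and the use of plain `f64` instead of rusty-machine's `Vector` are illustrative assumptions, but the arithmetic mirrors how `alpha` (the velocity/momentum coefficient) and `mu` (the gradient scaling) are combined in `StochasticGD::optimize`.

```rust
/// Classic momentum: fold the new gradient into the velocity, then step.
/// Mirrors the `else` branch of the patch.
fn momentum_step(w: f64, v: f64, grad: f64, alpha: f64, mu: f64) -> (f64, f64) {
    let v_new = grad * mu + v * alpha;
    (w - v_new * mu, v_new)
}

/// Nesterov momentum: the velocity update is identical, but the parameter
/// step also uses the previous velocity, giving the usual look-ahead
/// correction. Mirrors the `if self.nesterov_momentum` branch of the patch.
fn nesterov_step(w: f64, v: f64, grad: f64, alpha: f64, mu: f64) -> (f64, f64) {
    let prev_v = v;
    let v_new = grad * mu + v * alpha;
    (w - (prev_v * (-alpha) + v_new * (1.0 + alpha)), v_new)
}

fn main() {
    // Minimise f(w) = (w - 20)^2, the same convex objective the tests use,
    // so grad = 2 * (w - 20).
    let (alpha, mu) = (0.9, 0.1);

    let (mut w, mut v) = (100.0, 0.0);
    for _ in 0..500 {
        let (nw, nv) = momentum_step(w, v, 2.0 * (w - 20.0), alpha, mu);
        w = nw;
        v = nv;
    }
    println!("classic momentum:  w = {:.6}", w); // approaches 20

    let (mut w, mut v) = (100.0, 0.0);
    for _ in 0..500 {
        let (nw, nv) = nesterov_step(w, v, 2.0 * (w - 20.0), alpha, mu);
        w = nw;
        v = nv;
    }
    println!("Nesterov momentum: w = {:.6}", w); // approaches 20
}
```

Written this way it is easy to see that both branches share the same velocity update (`delta_w = grad * mu + delta_w * alpha`); only the parameter step differs, which is why the patch can keep a single `delta_w` and simply branch on the flag.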