43 changes: 34 additions & 9 deletions src/learning/optim/grad_desc.rs
@@ -107,6 +107,8 @@ pub struct StochasticGD {
mu: f64,
/// The number of passes through the data.
iters: usize,
/// Whether to use Nesterov momentum.
nesterov_momentum: bool,
}

/// The default Stochastic GD algorithm.
@@ -116,12 +118,14 @@ pub struct StochasticGD {
/// - alpha = 0.1
/// - mu = 0.1
/// - iters = 20
/// - nesterov_momentum = false
impl Default for StochasticGD {
fn default() -> StochasticGD {
StochasticGD {
alpha: 0.1,
mu: 0.1,
iters: 20,
nesterov_momentum: false,
}
}
}
@@ -131,8 +135,6 @@ impl StochasticGD {
///
/// Requires the learning rate, momentum rate and iteration count
/// to be specified.
///
/// # Examples
///
@@ -149,8 +151,23 @@ impl StochasticGD {
alpha: alpha,
mu: mu,
iters: iters,
nesterov_momentum: false,
}
}

/// Enable Nesterov momentum for the stochastic gradient descent algorithm.
///
/// # Examples
///
/// ```
/// use rusty_machine::learning::optim::grad_desc::StochasticGD;
///
/// let sgd = StochasticGD::new(0.1, 0.3, 5).with_nesterov_momentum();
/// ```
pub fn with_nesterov_momentum(mut self) -> StochasticGD {
self.nesterov_momentum = true;
self
}
}

impl<M> OptimAlgorithm<M> for StochasticGD
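The `nesterov_momentum` flag configured above is consumed in the training loop that follows: the optimizer keeps a single velocity vector `delta_w` and, for each sampled row, applies either the classical momentum update or the Nesterov-corrected one (a standalone sketch of the two rules follows this file's diff).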
@@ -184,13 +201,21 @@ impl<M> OptimAlgorithm<M> for StochasticGD
&inputs.select_rows(&[*i]),
&targets.select_rows(&[*i]));

if self.nesterov_momentum {
// Back up the previous velocity
let prev_w = delta_w.clone();
// Update the velocity
delta_w = Vector::new(vec_data) * self.mu + &delta_w * self.alpha;
// Update the parameters with the Nesterov correction term
optimizing_val = &optimizing_val -
(&prev_w * (-self.alpha) + &delta_w * (1. + self.alpha));
} else {
// Update the velocity using classical momentum
delta_w = Vector::new(vec_data) * self.mu + &delta_w * self.alpha;
// Update the parameters
optimizing_val = &optimizing_val - &delta_w * self.mu;
}

// Set the end cost (this is only used after the last iteration)
end_cost += cost;
}
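For reference, here is a minimal, self-contained sketch (not part of this PR) that transcribes the two update rules above onto a single scalar parameter. The names `alpha` and `mu` mirror the `StochasticGD` fields, and the quadratic cost `f(x) = (x - 20)^2` is a hypothetical stand-in for a real model:

```rust
// Scalar transcription of the two branches in grad_desc.rs above.
fn main() {
    let grad = |x: f64| 2.0 * (x - 20.0); // gradient of f(x) = (x - 20)^2
    let (alpha, mu) = (0.1_f64, 0.3_f64); // values from the doctest for `new`

    // Classical momentum (`else` branch):
    //   v <- mu * g + alpha * v;  x <- x - mu * v
    let (mut x, mut v) = (100.0_f64, 0.0_f64);
    for _ in 0..200 {
        v = mu * grad(x) + alpha * v;
        x -= mu * v;
    }
    println!("classical momentum: x = {:.4}", x);

    // Nesterov momentum (`if` branch):
    //   v_prev = v;  v <- mu * g + alpha * v;
    //   x <- x - ((1 + alpha) * v - alpha * v_prev)
    let (mut x, mut v) = (100.0_f64, 0.0_f64);
    for _ in 0..200 {
        let v_prev = v;
        v = mu * grad(x) + alpha * v;
        x -= (1.0 + alpha) * v - alpha * v_prev;
    }
    println!("nesterov momentum:  x = {:.4}", x);
}
```

Both runs drive `x` toward the minimum at 20; the Nesterov step `(1 + alpha) * v - alpha * v_prev` is the standard reformulation that applies the momentum look-ahead without evaluating the gradient at a shifted point.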
17 changes: 16 additions & 1 deletion tests/learning/optim/grad_desc.rs
@@ -58,7 +58,22 @@ fn convex_gd_training() {
fn convex_stochastic_gd_training() {
let x_sq = XSqModel { c: 20f64 };

let gd = StochasticGD::new(0.5f64, 1f64, 100);
let test_data = vec![100f64];
let params = gd.optimize(&x_sq,
&test_data[..],
&Matrix::zeros(100, 1),
&Matrix::zeros(100, 1));

assert!((params[0] - 20f64).abs() < 1e-10);
assert!(x_sq.compute_grad(&params, &Matrix::zeros(1, 1), &Matrix::zeros(1, 1)).0 < 1e-10);
}

#[test]
fn convex_stochastic_gd_nesterov_momentum_training() {
let x_sq = XSqModel { c: 20f64 };

let gd = StochasticGD::new(0.9f64, 0.1f64, 100).with_nesterov_momentum();
let test_data = vec![100f64];
let params = gd.optimize(&x_sq,
&test_data[..],