Skip to content

Commit 7e5eb5c

Browse files
authored
Merge pull request #306 from yfnaji/lasso
Lasso Regression
2 parents ec6097d + 48d9a0d commit 7e5eb5c

File tree

2 files changed

+294
-0
lines changed

2 files changed

+294
-0
lines changed

crates/RustQuant_ml/src/lasso.rs

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2+
// RustQuant: A Rust library for quantitative finance tools.
3+
// Copyright (C) 2023 https://github.com/avhz
4+
// Dual licensed under Apache 2.0 and MIT.
5+
// See:
6+
// - LICENSE-APACHE.md
7+
// - LICENSE-MIT.md
8+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9+
10+
//! Module for Lasso algorithms.
11+
12+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
13+
// IMPORTS
14+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
15+
16+
use nalgebra::{DMatrix, DVector};
17+
use RustQuant_error::RustQuantError;
18+
19+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
20+
// STRUCTS, ENUMS, AND TRAITS
21+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
22+
23+
/// Struct to hold the input data for a Lasso regression.
24+
#[allow(clippy::module_name_repetitions)]
25+
#[derive(Clone, Debug)]
26+
pub struct LassoInput<T> {
27+
/// The features matrix.
28+
pub x: DMatrix<T>,
29+
/// The output data vector, also known as the response vector.
30+
pub y: DVector<T>,
31+
/// The regularization parameter.
32+
pub lambda: T,
33+
/// Include the intercept.
34+
pub fit_intercept: bool,
35+
/// The maximum number of iterations for training.
36+
pub max_iter: usize,
37+
/// The tolerance for the convergence.
38+
pub tolerance: T,
39+
}
40+
41+
/// Struct to hold the output data for lasso.
42+
#[allow(clippy::module_name_repetitions)]
43+
#[derive(Clone, Debug)]
44+
pub struct LassoOutput<T> {
45+
/// The intercept of the lasso regression,
46+
pub intercept: T,
47+
/// The coefficients of the lasso regression,
48+
pub coefficients: DVector<T>,
49+
}
50+
51+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
52+
// IMPLEMENTATIONS
53+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
54+
55+
impl LassoInput<f64> {
56+
/// Create a new `LassoInput` struct.
57+
#[must_use]
58+
pub fn new(
59+
x: DMatrix<f64>,
60+
y: DVector<f64>,
61+
lambda: f64,
62+
fit_intercept: bool,
63+
max_iter: usize,
64+
tolerance: f64,
65+
) -> Self {
66+
Self { x, y, lambda, fit_intercept, max_iter, tolerance }
67+
}
68+
69+
/// Fits a Lasso regression to the input data.
70+
/// Returns the intercept and coefficients.
71+
/// The intercept is the first value of the coefficients.
72+
pub fn fit(&self) -> Result<LassoOutput<f64>, RustQuantError> {
73+
let n_cols = self.x.ncols();
74+
let n_rows = self.x.nrows() as f64;
75+
let mut features_matrix = self.x.clone();
76+
let mut residuals = self.y.clone();
77+
let feature_means = DVector::from_iterator(
78+
self.x.ncols(),
79+
(0..self.x.ncols()).map(|j| self.x.column(j).mean())
80+
);
81+
82+
if self.fit_intercept {
83+
84+
features_matrix = self.x.clone();
85+
for j in 0..self.x.ncols() {
86+
let mean = feature_means[j];
87+
for i in 0..self.x.nrows() {
88+
features_matrix[(i, j)] -= mean;
89+
}
90+
}
91+
residuals -= DVector::from_element(self.x.nrows(), self.y.mean());
92+
}
93+
94+
let mut coefficients = DVector::<f64>::zeros(n_cols);
95+
96+
for _ in 0..self.max_iter {
97+
let mut max_delta: f64 = 0.0;
98+
for j in 0..n_cols {
99+
100+
let feature_vals_col_j = features_matrix.column(j);
101+
let col_norm: f64 = feature_vals_col_j.dot(&feature_vals_col_j);
102+
let rho: f64 = (residuals.dot(&feature_vals_col_j) + coefficients[j] * col_norm) / n_rows;
103+
104+
let new_coefficient_j: f64 = if rho < -self.lambda {
105+
(rho + self.lambda) / (col_norm / n_rows)
106+
} else if rho > self.lambda {
107+
(rho - self.lambda) / (col_norm / n_rows)
108+
} else {
109+
0.0
110+
};
111+
112+
let delta: f64 = new_coefficient_j - coefficients[j];
113+
if delta.abs() > 0.0 {
114+
residuals -= &feature_vals_col_j * delta;
115+
}
116+
coefficients[j] = new_coefficient_j;
117+
max_delta = max_delta.max(delta.abs());
118+
}
119+
120+
if max_delta < self.tolerance {
121+
break;
122+
}
123+
}
124+
125+
let intercept: f64 = if self.fit_intercept {
126+
self.y.mean() - feature_means.dot(&coefficients)
127+
} else {
128+
0.0
129+
};
130+
coefficients = coefficients.insert_row(0, intercept);
131+
132+
Ok(LassoOutput {
133+
intercept,
134+
coefficients,
135+
})
136+
}
137+
}
138+
139+
impl LassoOutput<f64> {
140+
/// Predicts the output for the given input data.
141+
pub fn predict(&self, input: DMatrix<f64>) -> Result<DVector<f64>, RustQuantError> {
142+
let intercept = DVector::from_element(
143+
input.nrows(),
144+
self.intercept
145+
);
146+
let coefficients = self.coefficients.clone().remove_row(0);
147+
let predictions = input * coefficients + intercept;
148+
Ok(predictions)
149+
}
150+
}
151+
152+
153+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
154+
// UNIT TESTS
155+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
156+
157+
#[cfg(test)]
158+
mod tests_lasso_regression {
159+
use super::*;
160+
use RustQuant_utils::assert_approx_equal;
161+
162+
struct DataForTests {
163+
training_set: DMatrix<f64>,
164+
testing_set: DMatrix<f64>,
165+
response: DVector<f64>,
166+
}
167+
168+
fn setup_test() -> DataForTests {
169+
DataForTests {
170+
training_set: DMatrix::from_row_slice(
171+
4,
172+
3,
173+
&[
174+
-0.083_784_355, -0.633_485_70, -0.399_266_60,
175+
-0.982_943_745, 1.090_797_46, -0.468_123_05,
176+
-1.875_067_321, -0.913_727_27, 0.326_962_08,
177+
-0.186_144_661, 1.001_639_71, -0.412_746_90],
178+
),
179+
180+
testing_set: DMatrix::from_row_slice(
181+
4,
182+
3,
183+
&[
184+
0.562_036_47, 0.595_846_45, -0.411_653_01,
185+
0.663_358_26, 0.452_091_83, -0.294_327_15,
186+
-0.602_897_28, 0.896_743_96, 1.218_573_96,
187+
0.698_377_69, 0.572_216_51, 0.244_111_43],
188+
),
189+
190+
response: DVector::from_row_slice(
191+
&[
192+
-0.445_151_96,
193+
-1.847_803_64,
194+
-0.628_825_31,
195+
-0.861_080_69
196+
]
197+
),
198+
}
199+
}
200+
201+
#[test]
202+
fn test_lasso_without_intercept() -> Result<(), RustQuantError> {
203+
204+
let data: DataForTests = setup_test();
205+
206+
let input: LassoInput<f64> = LassoInput {
207+
x: data.training_set,
208+
y: data.response,
209+
lambda: 0.01,
210+
fit_intercept: false,
211+
max_iter: 1000,
212+
tolerance: 1e-4,
213+
};
214+
215+
let output: LassoOutput<f64> = input.fit()?;
216+
let predictions = output.predict(data.testing_set)?;
217+
218+
for (i, coefficient) in output.coefficients.iter().enumerate() {
219+
assert_approx_equal!(
220+
coefficient,
221+
&[
222+
0.0,
223+
0.743_965_706_491_596_7,
224+
-0.304_713_846_510_641_43,
225+
1.355_162_653_724_116_22,
226+
][i],
227+
f64::EPSILON
228+
);
229+
}
230+
231+
for (i, pred) in predictions.iter().enumerate() {
232+
assert_approx_equal!(
233+
pred,
234+
&[
235+
-0.321_283_589_676_737_6,
236+
-0.04310400559445471,
237+
0.9295807191488583,
238+
0.6760174510230131
239+
][i],
240+
f64::EPSILON
241+
);
242+
}
243+
Ok(())
244+
}
245+
246+
#[test]
247+
fn test_lasso_with_intercept() -> Result<(), RustQuantError> {
248+
249+
let data: DataForTests = setup_test();
250+
251+
let input: LassoInput<f64> = LassoInput {
252+
x: data.training_set,
253+
y: data.response,
254+
lambda: 0.01,
255+
fit_intercept: true,
256+
max_iter: 1000,
257+
tolerance: 1e-4,
258+
};
259+
260+
let output: LassoOutput<f64> = input.fit()?;
261+
let predictions = output.predict(data.testing_set)?;
262+
263+
for (i, coefficient) in output.coefficients.iter().enumerate() {
264+
assert_approx_equal!(
265+
coefficient,
266+
&[
267+
0.009_633_706_736_496_328,
268+
0.750_479_303_541_854_1,
269+
-0.301_997_087_876_784_5,
270+
1.373_605_833_196_545_3,
271+
][i],
272+
f64::EPSILON
273+
);
274+
}
275+
276+
for (i, pred) in predictions.iter().enumerate() {
277+
assert_approx_equal!(
278+
pred,
279+
&[
280+
-0.313_962_423_203_417_3,
281+
-0.033_349_554_520_968_38,
282+
0.960_198_011_081_136_2,
283+
0.696_256_873_679_798_4,
284+
][i],
285+
f64::EPSILON
286+
);
287+
}
288+
Ok(())
289+
}
290+
}

crates/RustQuant_ml/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,7 @@ pub use linear_regression::*;
3333
/// Logistic regression.
3434
pub mod logistic_regression;
3535
pub use logistic_regression::*;
36+
37+
/// lasso regression.
38+
pub mod lasso;
39+
pub use lasso::*;

0 commit comments

Comments
 (0)