Skip to content

Commit da1b10d

Browse files
committed
Fixing the test cases and minor adjustment to the algorithm
1 parent 21850fe commit da1b10d

File tree

1 file changed

+42
-12
lines changed

1 file changed

+42
-12
lines changed

src/machine_learning/logistic_regression.rs

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub fn logistic_regression(
1111
return None;
1212
}
1313

14-
let num_features = data_points[0].0.len();
14+
let num_features = data_points[0].0.len() + 1;
1515
let mut params = vec![0.0; num_features];
1616

1717
let derivative_fn = |params: &[f64]| derivative(params, &data_points);
@@ -26,11 +26,17 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
2626
let mut gradients = vec![0.0; num_features];
2727

2828
for (features, y_i) in data_points {
29-
let z = params.iter().zip(features).map(|(p, x)| p * x).sum::<f64>();
29+
let z = params[0]
30+
+ params[1..]
31+
.iter()
32+
.zip(features)
33+
.map(|(p, x)| p * x)
34+
.sum::<f64>();
3035
let prediction = 1.0 / (1.0 + E.powf(-z));
3136

37+
gradients[0] += prediction - y_i;
3238
for (i, x_i) in features.iter().enumerate() {
33-
gradients[i] += (prediction - y_i) * x_i;
39+
gradients[i + 1] += (prediction - y_i) * x_i;
3440
}
3541
}
3642

@@ -42,21 +48,45 @@ mod test {
4248
use super::*;
4349

4450
#[test]
45-
fn test_logistic_regression() {
51+
fn test_logistic_regression_simple() {
4652
let data = vec![
47-
(vec![0.0, 0.0], 0.0),
48-
(vec![1.0, 1.0], 1.0),
49-
(vec![2.0, 2.0], 1.0),
53+
(vec![0.0], 0.0),
54+
(vec![1.0], 0.0),
55+
(vec![2.0], 0.0),
56+
(vec![3.0], 1.0),
57+
(vec![4.0], 1.0),
58+
(vec![5.0], 1.0),
5059
];
51-
let result = logistic_regression(data, 10000, 0.1);
60+
61+
let result = logistic_regression(data, 10000, 0.05);
5262
assert!(result.is_some());
63+
64+
let params = result.unwrap();
65+
assert!((params[0] + 17.65).abs() < 1.0);
66+
assert!((params[1] - 7.13).abs() < 1.0);
67+
}
68+
69+
#[test]
70+
fn test_logistic_regression_extreme_data() {
71+
let data = vec![
72+
(vec![-100.0], 0.0),
73+
(vec![-10.0], 0.0),
74+
(vec![0.0], 0.0),
75+
(vec![10.0], 1.0),
76+
(vec![100.0], 1.0),
77+
];
78+
79+
let result = logistic_regression(data, 10000, 0.05);
80+
assert!(result.is_some());
81+
5382
let params = result.unwrap();
54-
assert!((params[0] - 6.902976808251308).abs() < 1e-6);
55-
assert!((params[1] - 2000.4659358334482).abs() < 1e-6);
83+
assert!((params[0] + 6.20).abs() < 1.0);
84+
assert!((params[1] - 5.5).abs() < 1.0);
5685
}
5786

5887
#[test]
59-
fn test_empty_list_logistic_regression() {
60-
assert_eq!(logistic_regression(vec![], 10000, 0.1), None);
88+
fn test_logistic_regression_no_data() {
89+
let result = logistic_regression(vec![], 5000, 0.1);
90+
assert_eq!(result, None);
6191
}
6292
}

0 commit comments

Comments
 (0)