@@ -11,7 +11,7 @@ pub fn logistic_regression(
1111 return None ;
1212 }
1313
14- let num_features = data_points[ 0 ] . 0 . len ( ) ;
14+ let num_features = data_points[ 0 ] . 0 . len ( ) + 1 ;
1515 let mut params = vec ! [ 0.0 ; num_features] ;
1616
1717 let derivative_fn = |params : & [ f64 ] | derivative ( params, & data_points) ;
@@ -26,11 +26,17 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
2626 let mut gradients = vec ! [ 0.0 ; num_features] ;
2727
2828 for ( features, y_i) in data_points {
29- let z = params. iter ( ) . zip ( features) . map ( |( p, x) | p * x) . sum :: < f64 > ( ) ;
29+ let z = params[ 0 ]
30+ + params[ 1 ..]
31+ . iter ( )
32+ . zip ( features)
33+ . map ( |( p, x) | p * x)
34+ . sum :: < f64 > ( ) ;
3035 let prediction = 1.0 / ( 1.0 + E . powf ( -z) ) ;
3136
37+ gradients[ 0 ] += prediction - y_i;
3238 for ( i, x_i) in features. iter ( ) . enumerate ( ) {
33- gradients[ i] += ( prediction - y_i) * x_i;
39+ gradients[ i + 1 ] += ( prediction - y_i) * x_i;
3440 }
3541 }
3642
@@ -42,21 +48,45 @@ mod test {
4248 use super :: * ;
4349
4450 #[ test]
45- fn test_logistic_regression ( ) {
51+ fn test_logistic_regression_simple ( ) {
4652 let data = vec ! [
47- ( vec![ 0.0 , 0.0 ] , 0.0 ) ,
48- ( vec![ 1.0 , 1.0 ] , 1.0 ) ,
49- ( vec![ 2.0 , 2.0 ] , 1.0 ) ,
53+ ( vec![ 0.0 ] , 0.0 ) ,
54+ ( vec![ 1.0 ] , 0.0 ) ,
55+ ( vec![ 2.0 ] , 0.0 ) ,
56+ ( vec![ 3.0 ] , 1.0 ) ,
57+ ( vec![ 4.0 ] , 1.0 ) ,
58+ ( vec![ 5.0 ] , 1.0 ) ,
5059 ] ;
51- let result = logistic_regression ( data, 10000 , 0.1 ) ;
60+
61+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
5262 assert ! ( result. is_some( ) ) ;
63+
64+ let params = result. unwrap ( ) ;
65+ assert ! ( ( params[ 0 ] + 17.65 ) . abs( ) < 1.0 ) ;
66+ assert ! ( ( params[ 1 ] - 7.13 ) . abs( ) < 1.0 ) ;
67+ }
68+
69+ #[ test]
70+ fn test_logistic_regression_extreme_data ( ) {
71+ let data = vec ! [
72+ ( vec![ -100.0 ] , 0.0 ) ,
73+ ( vec![ -10.0 ] , 0.0 ) ,
74+ ( vec![ 0.0 ] , 0.0 ) ,
75+ ( vec![ 10.0 ] , 1.0 ) ,
76+ ( vec![ 100.0 ] , 1.0 ) ,
77+ ] ;
78+
79+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
80+ assert ! ( result. is_some( ) ) ;
81+
5382 let params = result. unwrap ( ) ;
54- assert ! ( ( params[ 0 ] - 6.902976808251308 ) . abs( ) < 1e-6 ) ;
55- assert ! ( ( params[ 1 ] - 2000.4659358334482 ) . abs( ) < 1e-6 ) ;
83+ assert ! ( ( params[ 0 ] + 6.20 ) . abs( ) < 1.0 ) ;
84+ assert ! ( ( params[ 1 ] - 5.5 ) . abs( ) < 1.0 ) ;
5685 }
5786
5887 #[ test]
59- fn test_empty_list_logistic_regression ( ) {
60- assert_eq ! ( logistic_regression( vec![ ] , 10000 , 0.1 ) , None ) ;
88+ fn test_logistic_regression_no_data ( ) {
89+ let result = logistic_regression ( vec ! [ ] , 5000 , 0.1 ) ;
90+ assert_eq ! ( result, None ) ;
6191 }
6292}
0 commit comments