@@ -11,7 +11,7 @@ pub fn logistic_regression(
1111 return None ;
1212 }
1313
14- let num_features = data_points[ 0 ] . 0 . len ( ) ;
14+ let num_features = data_points[ 0 ] . 0 . len ( ) + 1 ;
1515 let mut params = vec ! [ 0.0 ; num_features] ;
1616
1717 let derivative_fn = |params : & [ f64 ] | derivative ( params, & data_points) ;
@@ -26,11 +26,12 @@ fn derivative(params: &[f64], data_points: &[(Vec<f64>, f64)]) -> Vec<f64> {
2626 let mut gradients = vec ! [ 0.0 ; num_features] ;
2727
2828 for ( features, y_i) in data_points {
29- let z = params. iter ( ) . zip ( features) . map ( |( p, x) | p * x) . sum :: < f64 > ( ) ;
29+ let z = params[ 0 ] + params [ 1 .. ] . iter ( ) . zip ( features) . map ( |( p, x) | p * x) . sum :: < f64 > ( ) ;
3030 let prediction = 1.0 / ( 1.0 + E . powf ( -z) ) ;
3131
32+ gradients[ 0 ] += prediction - y_i;
3233 for ( i, x_i) in features. iter ( ) . enumerate ( ) {
33- gradients[ i] += ( prediction - y_i) * x_i;
34+ gradients[ i + 1 ] += ( prediction - y_i) * x_i;
3435 }
3536 }
3637
@@ -42,21 +43,46 @@ mod test {
4243 use super :: * ;
4344
4445 #[ test]
45- fn test_logistic_regression ( ) {
46+ fn test_logistic_regression_simple ( ) {
4647 let data = vec ! [
47- ( vec![ 0.0 , 0.0 ] , 0.0 ) ,
48- ( vec![ 1.0 , 1.0 ] , 1.0 ) ,
49- ( vec![ 2.0 , 2.0 ] , 1.0 ) ,
48+ ( vec![ 0.0 ] , 0.0 ) ,
49+ ( vec![ 1.0 ] , 0.0 ) ,
50+ ( vec![ 2.0 ] , 0.0 ) ,
51+ ( vec![ 3.0 ] , 1.0 ) ,
52+ ( vec![ 4.0 ] , 1.0 ) ,
53+ ( vec![ 5.0 ] , 1.0 ) ,
5054 ] ;
51- let result = logistic_regression ( data, 10000 , 0.1 ) ;
55+
56+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
57+ assert ! ( result. is_some( ) ) ;
58+
59+ let params = result. unwrap ( ) ;
60+ assert ! ( ( params[ 0 ] + 17.65 ) . abs( ) < 1.0 ) ;
61+ assert ! ( ( params[ 1 ] - 7.13 ) . abs( ) < 1.0 ) ;
62+ }
63+
64+ #[ test]
65+ fn test_logistic_regression_extreme_data ( ) {
66+ let data = vec ! [
67+ ( vec![ -100.0 ] , 0.0 ) ,
68+ ( vec![ -10.0 ] , 0.0 ) ,
69+ ( vec![ 0.0 ] , 0.0 ) ,
70+ ( vec![ 10.0 ] , 1.0 ) ,
71+ ( vec![ 100.0 ] , 1.0 ) ,
72+ ] ;
73+
74+ let result = logistic_regression ( data, 10000 , 0.05 ) ;
5275 assert ! ( result. is_some( ) ) ;
76+
5377 let params = result. unwrap ( ) ;
54- assert ! ( ( params[ 0 ] - 6.902976808251308 ) . abs( ) < 1e-6 ) ;
55- assert ! ( ( params[ 1 ] - 2000.4659358334482 ) . abs( ) < 1e-6 ) ;
78+ assert ! ( ( params[ 0 ] + 6.20 ) . abs( ) < 1.0 ) ;
79+ assert ! ( ( params[ 1 ] - 5.5 ) . abs( ) < 1.0 ) ;
5680 }
5781
5882 #[ test]
59- fn test_empty_list_logistic_regression ( ) {
60- assert_eq ! ( logistic_regression( vec![ ] , 10000 , 0.1 ) , None ) ;
83+ fn test_logistic_regression_no_data ( ) {
84+ let result = logistic_regression ( vec ! [ ] , 5000 , 0.1 ) ;
85+ assert_eq ! ( result, None ) ;
6186 }
6287}
88+
0 commit comments