@@ -41,6 +41,7 @@ class Term
4141    VectorXd y_discretized;
4242    VectorXd errors_initial;
4343    double  error_initial;
44+     std::vector<size_t > observations_in_bins;
4445
4546    // methods
4647    void  calculate_error_where_given_terms_are_zero (const  VectorXd &y, const  VectorXd &sample_weight);
@@ -411,6 +412,13 @@ void Term::setup_bins()
411412        }
412413        bins_split_points_left.shrink_to_fit ();
413414        bins_split_points_right.shrink_to_fit ();
415+ 
416+         // number of observations in each bin (bins_end_index - bins_start_index + 1)
417+         observations_in_bins.reserve (bins_start_index.size ());
418+         for  (size_t  i = 0 ; i < bins_start_index.size (); ++i)
419+         {
420+             observations_in_bins.push_back (bins_end_index[i]-bins_start_index[i]+1 );
421+         }
414422    }
415423}
416424
@@ -421,22 +429,30 @@ void Term::discretize_data_by_bin()
421429        values_discretized.resize (bins_start_index.size ());
422430        for  (size_t  i = 0 ; i < bins_start_index.size (); ++i)
423431        {
424-             values_discretized[i]=sorted_vectors.values_sorted .block (bins_start_index[i],0 ,bins_end_index [i]-bins_start_index[i]+ 1 ,1 ).mean ();
432+             values_discretized[i]=sorted_vectors.values_sorted .block (bins_start_index[i],0 ,observations_in_bins [i],1 ).mean ();
425433        }
426434
427-         if (sorted_vectors.sample_weight_sorted .size ()>0 )
435+         sample_weight_discretized.resize (bins_start_index.size ());
436+         bool  sample_weights_were_provided_by_user{sorted_vectors.sample_weight_sorted .size ()>0 };
437+         if (sample_weights_were_provided_by_user)
428438        {
429-             sample_weight_discretized.resize (bins_start_index.size ());
430439            for  (size_t  i = 0 ; i < bins_start_index.size (); ++i)
431440            {
432-                 sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted .block (bins_start_index[i],0 ,bins_end_index[i]-bins_start_index[i]+1 ,1 ).mean ();
441+                 sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted .block (bins_start_index[i],0 ,observations_in_bins[i],1 ).sum ();
442+             }
443+         }
444+         else 
445+         {
446+             for  (size_t  i = 0 ; i < bins_start_index.size (); ++i)
447+             {
448+                 sample_weight_discretized[i]=static_cast <double >(observations_in_bins[i]);
433449            }
434450        }
435451    }
436452    y_discretized.resize (bins_start_index.size ());
437453    for  (size_t  i = 0 ; i < bins_start_index.size (); ++i)
438454    {
439-         y_discretized[i]=sorted_vectors.y_sorted .block (bins_start_index[i],0 ,bins_end_index [i]-bins_start_index[i]+ 1 ,1 ).mean ();
455+         y_discretized[i]=sorted_vectors.y_sorted .block (bins_start_index[i],0 ,observations_in_bins [i],1 ).mean ();
440456    }
441457
442458    max_index_discretized=calculate_max_index_in_vector (values_discretized);
@@ -516,16 +532,8 @@ void Term::calculate_coefficient_and_error_on_discretized_data(bool direction_ri
516532    double  xwy{0 };
517533    for  (size_t  i = index_start; i <= index_end; ++i)
518534    {
519-         if (sample_weight_discretized.size ()>0 )
520-         {
521-             xwx+=values_sorted[i]*values_sorted[i]*sample_weight_discretized[i];
522-             xwy+=values_sorted[i]*y_discretized[i]*sample_weight_discretized[i];
523-         }
524-         else 
525-         {
526-             xwx+=values_sorted[i]*values_sorted[i];
527-             xwy+=values_sorted[i]*y_discretized[i];
528-         }
535+         xwx+=values_sorted[i]*values_sorted[i]*sample_weight_discretized[i];
536+         xwy+=values_sorted[i]*y_discretized[i]*sample_weight_discretized[i];
529537    }
530538    if (xwx!=0 )
531539    {
@@ -595,6 +603,7 @@ void Term::cleanup_after_fit()
595603    bins_end_index.clear ();
596604    bins_split_points_left.clear ();
597605    bins_split_points_right.clear ();
606+     observations_in_bins.clear ();
598607    values_discretized.resize (0 );
599608    sample_weight_discretized.resize (0 );
600609}
0 commit comments