Skip to content

Commit 54eae0e

Browse files
Merge pull request #5 from ottenbreit-data-science/bins3
Improved binning methodology
2 parents 98603f8 + bf8a395 commit 54eae0e

File tree

4 files changed: 28 additions (+28) and 19 deletions (-19)

cpp/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -46,10 +46,10 @@ int main()
4646
//Saving results
4747
save_data("output.csv",predictions);
4848
std::cout<<"min validation_error "<<model.validation_error_steps.minCoeff()<<"\n\n";
49-
std::cout<<check_if_approximately_equal(model.validation_error_steps.minCoeff(),6.95696,0.00001)<<"\n";
49+
std::cout<<check_if_approximately_equal(model.validation_error_steps.minCoeff(),7.02559,0.00001)<<"\n";
5050

5151
std::cout<<"mean prediction "<<predictions.mean()<<"\n\n";
52-
std::cout<<check_if_approximately_equal(predictions.mean(),23.8927,0.0001)<<"\n";
52+
std::cout<<check_if_approximately_equal(predictions.mean(),23.9213,0.0001)<<"\n";
5353

5454
std::cout<<"best_m: "<<model.m<<"\n";
5555

cpp/term.h

Lines changed: 24 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ class Term
4141
VectorXd y_discretized;
4242
VectorXd errors_initial;
4343
double error_initial;
44+
std::vector<size_t> observations_in_bins;
4445

4546
//methods
4647
void calculate_error_where_given_terms_are_zero(const VectorXd &y, const VectorXd &sample_weight);
@@ -411,6 +412,13 @@ void Term::setup_bins()
411412
}
412413
bins_split_points_left.shrink_to_fit();
413414
bins_split_points_right.shrink_to_fit();
415+
416+
//observations in bins
417+
observations_in_bins.reserve(bins_start_index.size());
418+
for (size_t i = 0; i < bins_start_index.size(); ++i)
419+
{
420+
observations_in_bins.push_back(bins_end_index[i]-bins_start_index[i]+1);
421+
}
414422
}
415423
}
416424

@@ -421,22 +429,30 @@ void Term::discretize_data_by_bin()
421429
values_discretized.resize(bins_start_index.size());
422430
for (size_t i = 0; i < bins_start_index.size(); ++i)
423431
{
424-
values_discretized[i]=sorted_vectors.values_sorted.block(bins_start_index[i],0,bins_end_index[i]-bins_start_index[i]+1,1).mean();
432+
values_discretized[i]=sorted_vectors.values_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).mean();
425433
}
426434

427-
if(sorted_vectors.sample_weight_sorted.size()>0)
435+
sample_weight_discretized.resize(bins_start_index.size());
436+
bool sample_weights_were_provided_by_user{sorted_vectors.sample_weight_sorted.size()>0};
437+
if(sample_weights_were_provided_by_user)
428438
{
429-
sample_weight_discretized.resize(bins_start_index.size());
430439
for (size_t i = 0; i < bins_start_index.size(); ++i)
431440
{
432-
sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted.block(bins_start_index[i],0,bins_end_index[i]-bins_start_index[i]+1,1).mean();
441+
sample_weight_discretized[i]=sorted_vectors.sample_weight_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).sum();
442+
}
443+
}
444+
else
445+
{
446+
for (size_t i = 0; i < bins_start_index.size(); ++i)
447+
{
448+
sample_weight_discretized[i]=static_cast<double>(observations_in_bins[i]);
433449
}
434450
}
435451
}
436452
y_discretized.resize(bins_start_index.size());
437453
for (size_t i = 0; i < bins_start_index.size(); ++i)
438454
{
439-
y_discretized[i]=sorted_vectors.y_sorted.block(bins_start_index[i],0,bins_end_index[i]-bins_start_index[i]+1,1).mean();
455+
y_discretized[i]=sorted_vectors.y_sorted.block(bins_start_index[i],0,observations_in_bins[i],1).mean();
440456
}
441457

442458
max_index_discretized=calculate_max_index_in_vector(values_discretized);
@@ -516,16 +532,8 @@ void Term::calculate_coefficient_and_error_on_discretized_data(bool direction_ri
516532
double xwy{0};
517533
for (size_t i = index_start; i <= index_end; ++i)
518534
{
519-
if(sample_weight_discretized.size()>0)
520-
{
521-
xwx+=values_sorted[i]*values_sorted[i]*sample_weight_discretized[i];
522-
xwy+=values_sorted[i]*y_discretized[i]*sample_weight_discretized[i];
523-
}
524-
else
525-
{
526-
xwx+=values_sorted[i]*values_sorted[i];
527-
xwy+=values_sorted[i]*y_discretized[i];
528-
}
535+
xwx+=values_sorted[i]*values_sorted[i]*sample_weight_discretized[i];
536+
xwy+=values_sorted[i]*y_discretized[i]*sample_weight_discretized[i];
529537
}
530538
if(xwx!=0)
531539
{
@@ -595,6 +603,7 @@ void Term::cleanup_after_fit()
595603
bins_end_index.clear();
596604
bins_split_points_left.clear();
597605
bins_split_points_right.clear();
606+
observations_in_bins.clear();
598607
values_discretized.resize(0);
599608
sample_weight_discretized.resize(0);
600609
}

cpp/test ALRRegressor.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ int main()
1616
model.m=100;
1717
model.v=1.0;
1818
model.bins=10;
19-
model.n_jobs=0;
19+
model.n_jobs=1;
2020
model.loss_function_mse=true;
2121
model.verbosity=3;
2222
model.max_interaction_level=100;

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
setuptools.setup(
1717
name='aplr',
18-
version='1.0.6',
18+
version='1.0.7',
1919
description='Automatic Piecewise Linear Regression',
2020
ext_modules=[sfc_module],
2121
author="Mathias von Ottenbreit",

0 commit comments

Comments
 (0)