Skip to content

Commit ea1c70d

Browse files
bugfix
1 parent b707277 commit ea1c70d

File tree

4 files changed

+26
-44
lines changed

4 files changed

+26
-44
lines changed

API_REFERENCE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Limits 1) the number of terms already in the model that can be considered as int
5353
Specifies the variance power for the "tweedie" ***family***.
5454

5555
#### validation_tuning_metric (default = "default")
56-
Specifies which tuning metric to use for validating the model. Available options are "default" (using the same methodology as when calculating the training error), "mse", "mae" and "rankability". The default is often a choice that fits well with respect to the ***family*** chosen. However, if you want to use ***family*** as a tuning parameter then the default is not suitable. "rankability" uses a methodology similar to the one described in https://towardsdatascience.com/how-to-calculate-roc-auc-score-for-regression-models-c0be4fdf76bb
56+
Specifies which metric to use for validating the model and tuning ***m***. Available options are "default" (using the same methodology as when calculating the training error), "mse", "mae" and "rankability". The default is often a choice that fits well with respect to the ***family*** chosen. However, if you want to use ***family*** or ***tweedie_power*** as tuning parameters then the default is not suitable. "rankability" uses a methodology similar to the one described in https://towardsdatascience.com/how-to-calculate-roc-auc-score-for-regression-models-c0be4fdf76bb
5757

5858
## Method: fit(X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[], prioritized_predictors_indexes:List[int]=[], monotonic_constraints:List[int]=[])
5959

cpp/APLRRegressor.h

Lines changed: 21 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class APLRRegressor
4646
double scaling_factor_for_log_link_function;
4747
std::vector<size_t> predictor_indexes;
4848
std::vector<size_t> prioritized_predictors_indexes;
49-
std::vector<int> monotonic_constraints; //Make this VectorXi and validate for nan/inf input
49+
std::vector<int> monotonic_constraints;
5050

5151
//Methods
5252
void validate_input_to_fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names,
@@ -78,6 +78,7 @@ class APLRRegressor
7878
void update_gradient_and_errors();
7979
void add_new_term(size_t boosting_step);
8080
void calculate_and_validate_validation_error(size_t boosting_step);
81+
void calculate_validation_error(size_t boosting_step, const VectorXd &predictions);
8182
void update_term_eligibility();
8283
void print_summary_after_boosting_step(size_t boosting_step);
8384
void update_coefficients_for_all_steps();
@@ -962,50 +963,14 @@ void APLRRegressor::add_new_term(size_t boosting_step)
962963

963964
void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step)
964965
{
965-
VectorXd rescaled_predictions_current_validation(0);
966-
bool link_function_is_log{link_function=="log"};
967-
if(link_function_is_log)
968-
{
969-
rescaled_predictions_current_validation = predictions_current_validation / scaling_factor_for_log_link_function;
970-
}
971-
972-
bool using_default{validation_tuning_metric=="default"};
973-
bool using_mse{validation_tuning_metric=="mse"};
974-
bool using_mae{validation_tuning_metric=="mae"};
975-
bool using_rankability{validation_tuning_metric=="rankability"};
976-
if(using_default)
977-
{
978-
if(link_function_is_log)
979-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
980-
else
981-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
982-
}
983-
else if(using_mse)
984-
{
985-
if(link_function_is_log)
986-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
987-
else
988-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
989-
}
990-
else if(using_mae)
991-
{
992-
if(link_function_is_log)
993-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
994-
else
995-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
996-
}
997-
else if(using_rankability)
966+
if(link_function=="log")
998967
{
999-
if(link_function_is_log)
1000-
validation_error_steps[boosting_step]=-calculate_rankability(y_validation,rescaled_predictions_current_validation,sample_weight_validation,random_state);
1001-
else
1002-
validation_error_steps[boosting_step]=-calculate_rankability(y_validation,predictions_current_validation,sample_weight_validation,random_state);
968+
VectorXd rescaled_predictions_current_validation{predictions_current_validation / scaling_factor_for_log_link_function};
969+
calculate_validation_error(boosting_step, rescaled_predictions_current_validation);
1003970
}
1004971
else
1005-
{
1006-
throw std::runtime_error(validation_tuning_metric + " is an invalid validation_tuning_metric.");
1007-
}
1008-
972+
calculate_validation_error(boosting_step, predictions_current_validation);
973+
1009974
bool validation_error_is_invalid{std::isinf(validation_error_steps[boosting_step])};
1010975
if(validation_error_is_invalid)
1011976
{
@@ -1015,6 +980,20 @@ void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step
1015980
}
1016981
}
1017982

983+
void APLRRegressor::calculate_validation_error(size_t boosting_step, const VectorXd &predictions)
984+
{
985+
if(validation_tuning_metric=="default")
986+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,family,tweedie_power),sample_weight_validation);
987+
else if(validation_tuning_metric=="mse")
988+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation),sample_weight_validation);
989+
else if(validation_tuning_metric=="mae")
990+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,predictions,sample_weight_validation),sample_weight_validation);
991+
else if(validation_tuning_metric=="rankability")
992+
validation_error_steps[boosting_step]=-calculate_rankability(y_validation,predictions,sample_weight_validation,random_state);
993+
else
994+
throw std::runtime_error(validation_tuning_metric + " is an invalid validation_tuning_metric.");
995+
}
996+
1018997
void APLRRegressor::update_term_eligibility()
1019998
{
1020999
number_of_eligible_terms=terms_eligible_current.size();

cpp/functions.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,9 @@ double calculate_rankability(const VectorXd &y_true, const VectorXd &y_pred, con
353353
}
354354
}
355355
double rankability{num_ranked_correctly/num_pairs};
356+
bool rankability_is_invalid{!std::isfinite(rankability)};
357+
if(rankability_is_invalid)
358+
rankability=0.5;
356359

357360
return rankability;
358361
}

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setuptools.setup(
1717
name='aplr',
18-
version='2.0.0',
18+
version='2.0.1',
1919
description='Automatic Piecewise Linear Regression',
2020
ext_modules=[sfc_module],
2121
author="Mathias von Ottenbreit",

0 commit comments

Comments
 (0)