Skip to content

Commit b707277

Browse files
validation metrics
1 parent 9d07ab3 commit b707277

13 files changed

+387
-180
lines changed

API_REFERENCE.md

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# APLRRegressor
22

3-
## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, group_size_for_validation_group_mse:int=100)
3+
## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, validation_tuning_metric:str="default")
44

55
### Constructor parameters
66

@@ -50,11 +50,10 @@ Limits 1) the number of terms already in the model that can be considered as int
5050
***0*** does not print progress reports during fitting. ***1*** prints a summary after running the ***fit*** method. ***2*** prints a summary after each boosting step.
5151

5252
#### tweedie_power (default = 1.5)
53-
Species the variance power for the "tweedie" ***family***.
54-
55-
#### group_size_for_validation_group_mse (default = 100)
56-
APLR calculates a tuning metric, mean squared error for groups of observations in the validation set. This metric is provided by the method ***get_validation_group_mse()***. The metric may be useful for tuning ***tweedie_power*** and to some extent ***family*** or ***link_function***. The reasoning behind this is that mean squared error (MSE) is often appropriate for evaluating goodness of fit on approximately normally distributed data. The mean of a group of observations is approximately normally distributed according to the Central Limit Theorem (CLT) if there are enough observations in the group, regardless of how individual observations are distributed. Ideally, ***group_size_for_validation_group_mse*** should be large enough so that the Central Limit Theorem holds (at least 30, but the default of 100 is a safer choice). Also, the number of observations in the validation set should be substantially higher than ***group_size_for_validation_group_mse***.
53+
Specifies the variance power for the "tweedie" ***family***.
5754

55+
#### validation_tuning_metric (default = "default")
56+
Specifies which tuning metric to use for validating the model. Available options are "default" (using the same methodology as when calculating the training error), "mse", "mae" and "rankability". The default is often a choice that fits well with respect to the ***family*** chosen. However, if you want to use ***family*** as a tuning parameter then the default is not suitable. "rankability" uses a methodology similar to the one described in https://towardsdatascience.com/how-to-calculate-roc-auc-score-for-regression-models-c0be4fdf76bb
5857

5958
## Method: fit(X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[], prioritized_predictors_indexes:List[int]=[], monotonic_constraints:List[int]=[])
6059

@@ -182,10 +181,9 @@ The index of the term selected. So ***0*** is the first term, ***1*** is the sec
182181
***Returns the number of boosting steps in the model (the value that minimized validation error).***
183182

184183

185-
## Method: get_validation_group_mse()
186-
187-
***Returns mean squared error on grouped data in the validation set.*** See ***group_size_for_validation_group_mse*** for more information.
184+
## Method: get_validation_tuning_metric()
188185

186+
***Returns the validation_tuning_metric used.***
189187

190188
## Method: get_validation_indexes()
191189

aplr/aplr.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class APLRRegressor():
8-
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, group_size_for_validation_group_mse:int=100):
8+
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, validation_tuning_metric:str="default"):
99
self.m=m
1010
self.v=v
1111
self.random_state=random_state
@@ -22,7 +22,7 @@ def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaus
2222
self.max_eligible_terms=max_eligible_terms
2323
self.verbosity=verbosity
2424
self.tweedie_power=tweedie_power
25-
self.group_size_for_validation_group_mse=group_size_for_validation_group_mse
25+
self.validation_tuning_metric=validation_tuning_metric
2626

2727
#Creating aplr_cpp and setting parameters
2828
self.APLRRegressor=aplr_cpp.APLRRegressor()
@@ -46,7 +46,7 @@ def __set_params_cpp(self):
4646
self.APLRRegressor.max_eligible_terms=self.max_eligible_terms
4747
self.APLRRegressor.verbosity=self.verbosity
4848
self.APLRRegressor.tweedie_power=self.tweedie_power
49-
self.APLRRegressor.group_size_for_validation_group_mse=self.group_size_for_validation_group_mse
49+
self.APLRRegressor.validation_tuning_metric=self.validation_tuning_metric
5050

5151
def fit(self, X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[], prioritized_predictors_indexes:List[int]=[], monotonic_constraints:List[int]=[]):
5252
self.__set_params_cpp()
@@ -90,16 +90,34 @@ def get_intercept_steps(self)->npt.ArrayLike:
9090

9191
def get_m(self)->int:
9292
return self.APLRRegressor.get_m()
93-
94-
def get_validation_group_mse(self)->float:
95-
return self.APLRRegressor.get_validation_group_mse()
9693

94+
def get_validation_tuning_metric(self)->str:
95+
return self.APLRRegressor.get_validation_tuning_metric()
96+
9797
def get_validation_indexes(self)->List[int]:
9898
return self.APLRRegressor.get_validation_indexes()
9999

100100
#For sklearn
101101
def get_params(self, deep=True):
102-
return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"link_function":self.link_function,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms,"tweedie_power":self.tweedie_power,"group_size_for_validation_group_mse":self.group_size_for_validation_group_mse}
102+
return {
103+
"m": self.m,
104+
"v": self.v,
105+
"random_state":self.random_state,
106+
"family":self.family,
107+
"link_function":self.link_function,
108+
"n_jobs":self.n_jobs,
109+
"validation_ratio":self.validation_ratio,
110+
"intercept":self.intercept,
111+
"bins":self.bins,
112+
"max_interaction_level":self.max_interaction_level,
113+
"max_interactions":self.max_interactions,
114+
"verbosity":self.verbosity,
115+
"min_observations_in_split":self.min_observations_in_split,
116+
"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,
117+
"max_eligible_terms":self.max_eligible_terms,
118+
"tweedie_power":self.tweedie_power,
119+
"validation_tuning_metric":self.validation_tuning_metric
120+
}
103121

104122
#For sklearn
105123
def set_params(self, **parameters):

cpp/APLRRegressor.h

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ class APLRRegressor
8686
void name_terms(const MatrixXd &X, const std::vector<std::string> &X_names);
8787
void calculate_feature_importance_on_validation_set();
8888
void find_min_and_max_training_predictions_or_responses();
89-
void calculate_validation_group_mse();
9089
void cleanup_after_fit();
9190
void validate_that_model_can_be_used(const MatrixXd &X);
9291
void throw_error_if_family_does_not_exist();
@@ -134,16 +133,15 @@ class APLRRegressor
134133
double tweedie_power;
135134
double min_training_prediction_or_response;
136135
double max_training_prediction_or_response;
137-
double validation_group_mse;
138-
size_t group_size_for_validation_group_mse;
139136
std::vector<size_t> validation_indexes;
137+
std::string validation_tuning_metric;
140138

141139
//Methods
142140
APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string family="gaussian",
143141
std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,double intercept=NAN_DOUBLE,
144142
size_t reserved_terms_times_num_x=100, size_t bins=300,size_t verbosity=0,size_t max_interaction_level=1,size_t max_interactions=100000,
145143
size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5,double tweedie_power=1.5,
146-
size_t group_size_for_validation_group_mse=100);
144+
std::string validation_tuning_metric="default");
147145
APLRRegressor(const APLRRegressor &other);
148146
~APLRRegressor();
149147
void fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0),const std::vector<std::string> &X_names={},const std::vector<size_t> &validation_set_indexes={},
@@ -161,22 +159,22 @@ class APLRRegressor
161159
double get_intercept();
162160
VectorXd get_intercept_steps();
163161
size_t get_m();
164-
double get_validation_group_mse();
162+
std::string get_validation_tuning_metric();
165163
std::vector<size_t> get_validation_indexes();
166164
};
167165

168166
//Regular constructor
169167
APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::string family,std::string link_function,size_t n_jobs,
170168
double validation_ratio,double intercept,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
171169
size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double tweedie_power,
172-
size_t group_size_for_validation_group_mse):
170+
std::string validation_tuning_metric):
173171
reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
174172
family{family},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
175173
bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
176174
intercept_steps{VectorXd(0)},max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
177175
min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
178176
max_eligible_terms{max_eligible_terms},number_of_base_terms{0},tweedie_power{tweedie_power},min_training_prediction_or_response{NAN_DOUBLE},
179-
max_training_prediction_or_response{NAN_DOUBLE},validation_group_mse{NAN_DOUBLE},group_size_for_validation_group_mse{group_size_for_validation_group_mse},
177+
max_training_prediction_or_response{NAN_DOUBLE}, validation_tuning_metric{validation_tuning_metric},
180178
validation_indexes{std::vector<size_t>(0)}
181179
{
182180
}
@@ -192,8 +190,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
192190
min_observations_in_split{other.min_observations_in_split},ineligible_boosting_steps_added{other.ineligible_boosting_steps_added},
193191
max_eligible_terms{other.max_eligible_terms},number_of_base_terms{other.number_of_base_terms},
194192
feature_importance{other.feature_importance},tweedie_power{other.tweedie_power},min_training_prediction_or_response{other.min_training_prediction_or_response},
195-
max_training_prediction_or_response{other.max_training_prediction_or_response},validation_group_mse{other.validation_group_mse},
196-
group_size_for_validation_group_mse{other.group_size_for_validation_group_mse},validation_indexes{other.validation_indexes}
193+
max_training_prediction_or_response{other.max_training_prediction_or_response},validation_tuning_metric{other.validation_tuning_metric},
194+
validation_indexes{other.validation_indexes}
197195
{
198196
}
199197

@@ -225,7 +223,6 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
225223
name_terms(X, X_names);
226224
calculate_feature_importance_on_validation_set();
227225
find_min_and_max_training_predictions_or_responses();
228-
calculate_validation_group_mse();
229226
cleanup_after_fit();
230227
}
231228

@@ -447,7 +444,6 @@ void APLRRegressor::scale_training_observations_if_using_log_link_function()
447444
{
448445
scaling_factor_for_log_link_function=1/inverse_scaling_factor;
449446
y_train*=scaling_factor_for_log_link_function;
450-
y_validation*=scaling_factor_for_log_link_function;
451447
}
452448
else
453449
scaling_factor_for_log_link_function=1.0;
@@ -966,7 +962,50 @@ void APLRRegressor::add_new_term(size_t boosting_step)
966962

967963
void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step)
968964
{
969-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
965+
VectorXd rescaled_predictions_current_validation(0);
966+
bool link_function_is_log{link_function=="log"};
967+
if(link_function_is_log)
968+
{
969+
rescaled_predictions_current_validation = predictions_current_validation / scaling_factor_for_log_link_function;
970+
}
971+
972+
bool using_default{validation_tuning_metric=="default"};
973+
bool using_mse{validation_tuning_metric=="mse"};
974+
bool using_mae{validation_tuning_metric=="mae"};
975+
bool using_rankability{validation_tuning_metric=="rankability"};
976+
if(using_default)
977+
{
978+
if(link_function_is_log)
979+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
980+
else
981+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
982+
}
983+
else if(using_mse)
984+
{
985+
if(link_function_is_log)
986+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
987+
else
988+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
989+
}
990+
else if(using_mae)
991+
{
992+
if(link_function_is_log)
993+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
994+
else
995+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
996+
}
997+
else if(using_rankability)
998+
{
999+
if(link_function_is_log)
1000+
validation_error_steps[boosting_step]=-calculate_rankability(y_validation,rescaled_predictions_current_validation,sample_weight_validation,random_state);
1001+
else
1002+
validation_error_steps[boosting_step]=-calculate_rankability(y_validation,predictions_current_validation,sample_weight_validation,random_state);
1003+
}
1004+
else
1005+
{
1006+
throw std::runtime_error(validation_tuning_metric + " is an invalid validation_tuning_metric.");
1007+
}
1008+
9701009
bool validation_error_is_invalid{std::isinf(validation_error_steps[boosting_step])};
9711010
if(validation_error_is_invalid)
9721011
{
@@ -1080,7 +1119,6 @@ void APLRRegressor::revert_scaling_if_using_log_link_function()
10801119
if(link_function=="log")
10811120
{
10821121
y_train/=scaling_factor_for_log_link_function;
1083-
y_validation/=scaling_factor_for_log_link_function;
10841122
intercept+=std::log(1/scaling_factor_for_log_link_function);
10851123
for (size_t i = 0; i < static_cast<size_t>(intercept_steps.size()); ++i)
10861124
{
@@ -1196,17 +1234,6 @@ void APLRRegressor::find_min_and_max_training_predictions_or_responses()
11961234
max_training_prediction_or_response=std::min(training_predictions.maxCoeff(), y_train.maxCoeff());
11971235
}
11981236

1199-
void APLRRegressor::calculate_validation_group_mse()
1200-
{
1201-
VectorXd validation_predictions{predict(X_validation,false)};
1202-
VectorXi validation_predictions_sorted_index{sort_indexes_ascending(validation_predictions)};
1203-
VectorXd y_validation_centered{calculate_rolling_centered_mean(y_validation,validation_predictions_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
1204-
VectorXd validation_predictions_centered{calculate_rolling_centered_mean(validation_predictions,validation_predictions_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
1205-
1206-
VectorXd squared_residuals{(y_validation_centered-validation_predictions_centered).array().pow(2)};
1207-
validation_group_mse = squared_residuals.mean();
1208-
}
1209-
12101237
void APLRRegressor::validate_that_model_can_be_used(const MatrixXd &X)
12111238
{
12121239
if(std::isnan(intercept) || number_of_base_terms==0) throw std::runtime_error("Model must be trained before predict() can be run.");
@@ -1354,9 +1381,9 @@ size_t APLRRegressor::get_m()
13541381
return m;
13551382
}
13561383

1357-
double APLRRegressor::get_validation_group_mse()
1384+
std::string APLRRegressor::get_validation_tuning_metric()
13581385
{
1359-
return validation_group_mse;
1386+
return validation_tuning_metric;
13601387
}
13611388

13621389
std::vector<size_t> APLRRegressor::get_validation_indexes()

0 commit comments

Comments
 (0)