Skip to content

Commit 3df7b6d

Browse files
capping of predictions
1 parent 3a1361d commit 3df7b6d

12 files changed

+56
-19
lines changed

API_REFERENCE.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ Specifies the intercept term of the model if you want to predict before doing an
3232
Specifies the maximum number of bins to discretize the data into when searching for the best split. The default value works well according to empirical results. This hyperparameter is intended for reducing computational costs.
3333

3434
#### max_interaction_level (default = 1)
35-
Specifies the maximum allowed depth of interaction terms. ***0*** means that interactions are not allowed. This hyperparameter should be tuned. Please note that occasionally a too high value produces a model that performs poorly on an independent test set despite looking good when tuning hyperparameters, typically because of a few outlier predictions. If this happens then capping of predictions should be considered. Alternatively, ***max_interaction_level*** may be decreased until the problem disappears.
35+
Specifies the maximum allowed depth of interaction terms. ***0*** means that interactions are not allowed. This hyperparameter should be tuned. Please note that occasionally a too high value produces a model that performs poorly on an independent test set despite looking good when tuning hyperparameters, typically because of a few outlier predictions. To alleviate this, the ***predict*** method by default caps predictions to limits calculated on the training data (if you need the model to extrapolate then switch off the default capping). Alternatively, ***max_interaction_level*** may be decreased until the problem disappears.
3636

3737
#### max_interactions (default = 100000)
3838
The maximum number of interactions allowed. A lower value may be used to reduce computational time.
@@ -75,7 +75,7 @@ An optional list of strings containing names for each predictor in ***X***. Nami
7575
An optional list of integers specifying the indexes of observations to be used for validation instead of training. If this is specified then ***validation_ratio*** is not used. Specifying ***validation_set_indexes*** may be useful for example when modelling time series data (you can place more recent observations in the validation set).
7676

7777

78-
## Method: predict(X:npt.ArrayLike)
78+
## Method: predict(X:npt.ArrayLike, cap_predictions_to_minmax_in_training:bool=True)
7979

8080
***Returns a numpy vector containing predictions of the data in X. Requires that the model has been fitted with the fit method.***
8181

@@ -84,6 +84,9 @@ An optional list of integers specifying the indexes of observations to be used f
8484
#### X
8585
A numpy matrix with predictor values.
8686

87+
#### cap_predictions_to_minmax_in_training
88+
If ***True*** then predictions are capped so that they are not less than the minimum and not greater than the maximum prediction in the training dataset. This is recommended especially if ***max_interaction_level*** is high. However, if you need the model to extrapolate then set this parameter to ***False***.
89+
8790

8891
## Method: set_term_names(X_names:List[str])
8992

aplr/aplr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def fit(self, X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np
5050
self.__set_params_cpp()
5151
self.APLRRegressor.fit(X,y,sample_weight,X_names,validation_set_indexes)
5252

53-
def predict(self,X:npt.ArrayLike)->npt.ArrayLike:
54-
return self.APLRRegressor.predict(X)
53+
def predict(self, X:npt.ArrayLike, cap_predictions_to_minmax_in_training:bool=True)->npt.ArrayLike:
54+
return self.APLRRegressor.predict(X, cap_predictions_to_minmax_in_training)
5555

5656
def set_term_names(self, X_names:List[str]):
5757
self.APLRRegressor.set_term_names(X_names)

cpp/APLRRegressor.h

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ class APLRRegressor
7575
void find_optimal_m_and_update_model_accordingly();
7676
void name_terms(const MatrixXd &X, const std::vector<std::string> &X_names);
7777
void calculate_feature_importance_on_validation_set();
78+
void find_min_and_max_training_predictions();
7879
void cleanup_after_fit();
7980
void validate_that_model_can_be_used(const MatrixXd &X);
8081
void throw_error_if_family_does_not_exist();
@@ -89,6 +90,7 @@ class APLRRegressor
8990
VectorXd differentiate_predictions();
9091
void scale_training_observations_if_using_log_link_function();
9192
void revert_scaling_if_using_log_link_function();
93+
void cap_predictions_to_minmax_in_training(VectorXd &predictions);
9294

9395
public:
9496
//Fields
@@ -118,6 +120,8 @@ class APLRRegressor
118120
size_t number_of_base_terms;
119121
VectorXd feature_importance; //Populated in fit() using validation set. Rows are in the same order as in X.
120122
double tweedie_power;
123+
double min_training_prediction;
124+
double max_training_prediction;
121125

122126
//Methods
123127
APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string family="gaussian",
@@ -127,7 +131,7 @@ class APLRRegressor
127131
APLRRegressor(const APLRRegressor &other);
128132
~APLRRegressor();
129133
void fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0),const std::vector<std::string> &X_names={},const std::vector<size_t> &validation_set_indexes={});
130-
VectorXd predict(const MatrixXd &X);
134+
VectorXd predict(const MatrixXd &X, bool cap_predictions_to_minmax_in_training=true);
131135
void set_term_names(const std::vector<std::string> &X_names);
132136
MatrixXd calculate_local_feature_importance(const MatrixXd &X);
133137
MatrixXd calculate_local_feature_importance_for_terms(const MatrixXd &X);
@@ -151,7 +155,8 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
151155
bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
152156
intercept_steps{VectorXd(0)},max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
153157
min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
154-
max_eligible_terms{max_eligible_terms},number_of_base_terms{0},tweedie_power{tweedie_power}
158+
max_eligible_terms{max_eligible_terms},number_of_base_terms{0},tweedie_power{tweedie_power},min_training_prediction{NAN_DOUBLE},
159+
max_training_prediction{NAN_DOUBLE}
155160
{
156161
}
157162

@@ -165,7 +170,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
165170
max_interactions{other.max_interactions},interactions_eligible{other.interactions_eligible},validation_error_steps{other.validation_error_steps},
166171
min_observations_in_split{other.min_observations_in_split},ineligible_boosting_steps_added{other.ineligible_boosting_steps_added},
167172
max_eligible_terms{other.max_eligible_terms},number_of_base_terms{other.number_of_base_terms},
168-
feature_importance{other.feature_importance},tweedie_power{other.tweedie_power}
173+
feature_importance{other.feature_importance},tweedie_power{other.tweedie_power},min_training_prediction{other.min_training_prediction},
174+
max_training_prediction{other.max_training_prediction}
169175
{
170176
}
171177

@@ -193,6 +199,7 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
193199
revert_scaling_if_using_log_link_function();
194200
name_terms(X, X_names);
195201
calculate_feature_importance_on_validation_set();
202+
find_min_and_max_training_predictions();
196203
cleanup_after_fit();
197204
}
198205

@@ -1019,6 +1026,13 @@ MatrixXd APLRRegressor::calculate_local_feature_importance(const MatrixXd &X)
10191026
return output;
10201027
}
10211028

1029+
void APLRRegressor::find_min_and_max_training_predictions()
1030+
{
1031+
VectorXd training_predictions{predict(X_train,false)};
1032+
min_training_prediction=training_predictions.minCoeff();
1033+
max_training_prediction=training_predictions.maxCoeff();
1034+
}
1035+
10221036
void APLRRegressor::validate_that_model_can_be_used(const MatrixXd &X)
10231037
{
10241038
if(std::isnan(intercept) || number_of_base_terms==0) throw std::runtime_error("Model must be trained before predict() can be run.");
@@ -1056,13 +1070,18 @@ void APLRRegressor::cleanup_after_fit()
10561070
}
10571071
}
10581072

1059-
VectorXd APLRRegressor::predict(const MatrixXd &X)
1073+
VectorXd APLRRegressor::predict(const MatrixXd &X, bool cap_predictions_to_minmax_in_training)
10601074
{
10611075
validate_that_model_can_be_used(X);
10621076

10631077
VectorXd linear_predictor{calculate_linear_predictor(X)};
10641078
VectorXd predictions{transform_linear_predictor_to_predictions(linear_predictor,link_function,tweedie_power)};
10651079

1080+
if(cap_predictions_to_minmax_in_training)
1081+
{
1082+
this->cap_predictions_to_minmax_in_training(predictions);
1083+
}
1084+
10661085
return predictions;
10671086
}
10681087

@@ -1077,6 +1096,17 @@ VectorXd APLRRegressor::calculate_linear_predictor(const MatrixXd &X)
10771096
return predictions;
10781097
}
10791098

1099+
void APLRRegressor::cap_predictions_to_minmax_in_training(VectorXd &predictions)
1100+
{
1101+
for (size_t i = 0; i < static_cast<size_t>(predictions.rows()); ++i)
1102+
{
1103+
if(std::isgreater(predictions[i],max_training_prediction))
1104+
predictions[i]=max_training_prediction;
1105+
else if(std::isless(predictions[i],min_training_prediction))
1106+
predictions[i]=min_training_prediction;
1107+
}
1108+
}
1109+
10801110
MatrixXd APLRRegressor::calculate_local_feature_importance_for_terms(const MatrixXd &X)
10811111
{
10821112
validate_that_model_can_be_used(X);

cpp/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ int main()
4949
std::cout<<is_approximately_equal(model.validation_error_steps.minCoeff(),7.02559,0.00001)<<"\n";
5050

5151
std::cout<<"mean prediction "<<predictions.mean()<<"\n\n";
52-
std::cout<<is_approximately_equal(predictions.mean(),23.9213,0.0001)<<"\n";
52+
std::cout<<is_approximately_equal(predictions.mean(),23.9133,0.0001)<<"\n";
5353

5454
std::cout<<"best_m: "<<model.m<<"\n";
5555

cpp/pythonbinding.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ PYBIND11_MODULE(aplr_cpp, m) {
2020
py::arg("tweedie_power")=1.5)
2121
.def("fit", &APLRRegressor::fit,py::arg("X"),py::arg("y"),py::arg("sample_weight")=VectorXd(0),py::arg("X_names")=std::vector<std::string>(),
2222
py::arg("validation_set_indexes")=std::vector<size_t>(),py::call_guard<py::scoped_ostream_redirect,py::scoped_estream_redirect>())
23-
.def("predict", &APLRRegressor::predict,py::arg("X"))
23+
.def("predict", &APLRRegressor::predict,py::arg("X"),py::arg("bool cap_predictions_to_minmax_in_training")=true)
2424
.def("set_term_names", &APLRRegressor::set_term_names,py::arg("X_names"))
2525
.def("calculate_local_feature_importance",&APLRRegressor::calculate_local_feature_importance,py::arg("X"))
2626
.def("calculate_local_feature_importance_for_terms",&APLRRegressor::calculate_local_feature_importance_for_terms,py::arg("X"))
@@ -57,16 +57,18 @@ PYBIND11_MODULE(aplr_cpp, m) {
5757
.def_readwrite("number_of_base_terms",&APLRRegressor::number_of_base_terms)
5858
.def_readwrite("feature_importance",&APLRRegressor::feature_importance)
5959
.def_readwrite("tweedie_power",&APLRRegressor::tweedie_power)
60+
.def_readwrite("min_training_prediction",&APLRRegressor::min_training_prediction)
61+
.def_readwrite("max_training_prediction",&APLRRegressor::max_training_prediction)
6062
.def(py::pickle(
6163
[](const APLRRegressor &a) { // __getstate__
6264
/* Return a tuple that fully encodes the state of the object */
6365
return py::make_tuple(a.m,a.v,a.random_state,a.family,a.n_jobs,a.validation_ratio,a.intercept,a.bins,a.verbosity,
6466
a.max_interaction_level,a.max_interactions,a.validation_error_steps,a.term_names,a.term_coefficients,a.terms,a.intercept_steps,
6567
a.interactions_eligible,a.min_observations_in_split,a.ineligible_boosting_steps_added,a.max_eligible_terms,
66-
a.number_of_base_terms,a.feature_importance,a.link_function,a.tweedie_power);
68+
a.number_of_base_terms,a.feature_importance,a.link_function,a.tweedie_power,a.min_training_prediction,a.max_training_prediction);
6769
},
6870
[](py::tuple t) { // __setstate__
69-
if (t.size() != 24)
71+
if (t.size() != 26)
7072
throw std::runtime_error("Invalid state!");
7173

7274
/* Create a new C++ instance */
@@ -85,6 +87,8 @@ PYBIND11_MODULE(aplr_cpp, m) {
8587
a.max_eligible_terms=t[19].cast<size_t>();
8688
a.number_of_base_terms=t[20].cast<size_t>();
8789
a.feature_importance=t[21].cast<VectorXd>();
90+
a.min_training_prediction=t[24].cast<double>();
91+
a.max_training_prediction=t[25].cast<double>();
8892

8993
return a;
9094
}

cpp/test ALRRegressor gamma.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main()
5252
save_data("data/output.csv",predictions);
5353

5454
std::cout<<predictions.mean()<<"\n\n";
55-
tests.push_back(is_approximately_equal(predictions.mean(),23.6098,0.00001));
55+
tests.push_back(is_approximately_equal(predictions.mean(),23.6065,0.00001));
5656

5757
//std::cout<<model.validation_error_steps<<"\n\n";
5858

cpp/test ALRRegressor inversegaussian.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ int main()
5353
save_data("data/output.csv",predictions);
5454

5555
std::cout<<predictions.mean()<<"\n\n";
56-
tests.push_back(is_approximately_equal(predictions.mean(),23.4164,0.00001));
56+
tests.push_back(is_approximately_equal(predictions.mean(),23.4135,0.00001));
5757

5858
//std::cout<<model.validation_error_steps<<"\n\n";
5959

cpp/test ALRRegressor logit.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main()
5252
save_data("data/output.csv",predictions);
5353

5454
std::cout<<predictions.mean()<<"\n\n";
55-
tests.push_back(is_approximately_equal(predictions.mean(),0.104267,0.00001));
55+
tests.push_back(is_approximately_equal(predictions.mean(),0.103821,0.00001));
5656

5757
//std::cout<<model.validation_error_steps<<"\n\n";
5858

cpp/test ALRRegressor poisson.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main()
5252
save_data("data/output.csv",predictions);
5353

5454
std::cout<<predictions.mean()<<"\n\n";
55-
tests.push_back(is_approximately_equal(predictions.mean(),1.89421,0.00001));
55+
tests.push_back(is_approximately_equal(predictions.mean(),1.89378,0.00001));
5656

5757
//std::cout<<model.validation_error_steps<<"\n\n";
5858

cpp/test ALRRegressor poissongamma.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ int main()
5353
save_data("data/output.csv",predictions);
5454

5555
std::cout<<predictions.mean()<<"\n\n";
56-
tests.push_back(is_approximately_equal(predictions.mean(),1.88935,0.00001));
56+
tests.push_back(is_approximately_equal(predictions.mean(),1.88887,0.00001));
5757

5858
//std::cout<<model.validation_error_steps<<"\n\n";
5959

0 commit comments

Comments
 (0)