Skip to content

Commit f8fa60d

Browse files
separated family into family and link functions. added more link functions
1 parent 650e2f7 commit f8fa60d

13 files changed: +106 −52 lines changed

aplr/aplr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66

77
class APLRRegressor():
8-
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=100, max_interactions:int=0, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0):
8+
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=100, max_interactions:int=0, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0):
99
self.m=m
1010
self.v=v
1111
self.random_state=random_state
1212
self.family=family
13+
self.link_function=link_function
1314
self.n_jobs=n_jobs
1415
self.validation_ratio=validation_ratio
1516
self.intercept=intercept
@@ -31,6 +32,7 @@ def __set_params_cpp(self):
3132
self.APLRRegressor.v=self.v
3233
self.APLRRegressor.random_state=self.random_state
3334
self.APLRRegressor.family=self.family
35+
self.APLRRegressor.link_function=self.link_function
3436
self.APLRRegressor.n_jobs=self.n_jobs
3537
self.APLRRegressor.validation_ratio=self.validation_ratio
3638
self.APLRRegressor.intercept=self.intercept
@@ -87,7 +89,7 @@ def get_m(self)->int:
8789

8890
#For sklearn
8991
def get_params(self, deep=True):
90-
return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms}
92+
return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"link_function":self.link_function,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms}
9193

9294
#For sklearn
9395
def set_params(self, **parameters):

cpp/APLRRegressor.h

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ class APLRRegressor
7777
void cleanup_after_fit();
7878
void validate_that_model_can_be_used(const MatrixXd &X);
7979
void throw_error_if_family_does_not_exist();
80+
void throw_error_if_link_function_does_not_exist();
8081
VectorXd calculate_linear_predictor(const MatrixXd &X);
8182
void update_linear_predictor_and_predictors();
8283
void throw_error_if_response_contains_invalid_values(const VectorXd &y);
@@ -91,6 +92,7 @@ class APLRRegressor
9192
size_t m; //Boosting steps to run. Can shrink to auto tuned value after running fit().
9293
double v; //Learning rate.
9394
std::string family;
95+
std::string link_function;
9496
double validation_ratio;
9597
size_t n_jobs; //0:using all available cores. 1:no multithreading. >1: Using a specified number of cores but not more than is available.
9698
uint_fast32_t random_state; //For train/validation split. If std::numeric_limits<uint_fast32_t>::lowest() then will randomly set a seed
@@ -112,10 +114,10 @@ class APLRRegressor
112114
VectorXd feature_importance; //Populated in fit() using validation set. Rows are in the same order as in X.
113115

114116
//Methods
115-
APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string family="gaussian",size_t n_jobs=0,
116-
double validation_ratio=0.2,double intercept=NAN_DOUBLE,size_t reserved_terms_times_num_x=100,size_t bins=300,size_t verbosity=0,
117-
size_t max_interaction_level=100,size_t max_interactions=0,size_t min_observations_in_split=20,size_t ineligible_boosting_steps_added=10,
118-
size_t max_eligible_terms=5);
117+
APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string family="gaussian",
118+
std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,double intercept=NAN_DOUBLE,
119+
size_t reserved_terms_times_num_x=100, size_t bins=300,size_t verbosity=0,size_t max_interaction_level=100,size_t max_interactions=0,
120+
size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5);
119121
APLRRegressor(const APLRRegressor &other);
120122
~APLRRegressor();
121123
void fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0),const std::vector<std::string> &X_names={},const std::vector<size_t> &validation_set_indexes={});
@@ -135,11 +137,11 @@ class APLRRegressor
135137
};
136138

137139
//Regular constructor
138-
APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::string family,size_t n_jobs,double validation_ratio,double intercept,
139-
size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,size_t max_interactions,size_t min_observations_in_split,
140-
size_t ineligible_boosting_steps_added,size_t max_eligible_terms):
140+
APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::string family,std::string link_function,size_t n_jobs,
141+
double validation_ratio,double intercept,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
142+
size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms):
141143
reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
142-
family{family},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
144+
family{family},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
143145
bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
144146
intercept_steps{VectorXd(0)},max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
145147
min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
@@ -150,7 +152,7 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
150152
//Copy constructor
151153
APLRRegressor::APLRRegressor(const APLRRegressor &other):
152154
reserved_terms_times_num_x{other.reserved_terms_times_num_x},intercept{other.intercept},terms{other.terms},m{other.m},v{other.v},
153-
family{other.family},validation_ratio{other.validation_ratio},
155+
family{other.family},link_function{other.link_function},validation_ratio{other.validation_ratio},
154156
n_jobs{other.n_jobs},random_state{other.random_state},bins{other.bins},
155157
verbosity{other.verbosity},term_names{other.term_names},term_coefficients{other.term_coefficients},
156158
max_interaction_level{other.max_interaction_level},intercept_steps{other.intercept_steps},
@@ -173,6 +175,7 @@ APLRRegressor::~APLRRegressor()
173175
void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names,const std::vector<size_t> &validation_set_indexes)
174176
{
175177
throw_error_if_family_does_not_exist();
178+
throw_error_if_link_function_does_not_exist();
176179
validate_input_to_fit(X,y,sample_weight,X_names,validation_set_indexes);
177180
define_training_and_validation_sets(X,y,sample_weight,validation_set_indexes);
178181
initialize();
@@ -212,11 +215,11 @@ void APLRRegressor::throw_error_if_validation_set_indexes_has_invalid_indexes(co
212215

213216
void APLRRegressor::throw_error_if_response_contains_invalid_values(const VectorXd &y)
214217
{
215-
if(family=="logit")
218+
if(link_function=="logit")
216219
throw_error_if_response_is_not_between_0_and_1(y);
217-
else if(family=="poisson" || family=="poissongamma")
220+
else if(link_function=="log" || link_function=="inverseroot")
218221
throw_error_if_response_is_negative(y);
219-
else if(family=="gamma" || family=="inversegaussian")
222+
else if(link_function=="inverse" || link_function=="inversesquare")
220223
throw_error_if_response_is_not_greater_than_zero(y);
221224
}
222225

@@ -225,21 +228,21 @@ void APLRRegressor::throw_error_if_response_is_not_between_0_and_1(const VectorX
225228
bool response_is_less_than_zero{(y.array()<0.0).any()};
226229
bool response_is_greater_than_one{(y.array()>1.0).any()};
227230
if(response_is_less_than_zero || response_is_greater_than_one)
228-
throw std::runtime_error("Response values for "+family+" models cannot be less than zero or greater than one.");
231+
throw std::runtime_error("Response values for "+link_function+" link functions cannot be less than zero or greater than one.");
229232
}
230233

231234
void APLRRegressor::throw_error_if_response_is_negative(const VectorXd &y)
232235
{
233236
bool response_is_less_than_zero{(y.array()<0.0).any()};
234237
if(response_is_less_than_zero)
235-
throw std::runtime_error("Response values for "+family+" models cannot be less than zero.");
238+
throw std::runtime_error("Response values for "+link_function+" link functions cannot be less than zero.");
236239
}
237240

238241
void APLRRegressor::throw_error_if_response_is_not_greater_than_zero(const VectorXd &y)
239242
{
240243
bool response_is_not_greater_than_zero{(y.array()<=0.0).any()};
241244
if(response_is_not_greater_than_zero)
242-
throw std::runtime_error("Response values for "+family+" models must be greater than zero.");
245+
throw std::runtime_error("Response values for "+link_function+" link functions must be greater than zero.");
243246

244247
}
245248

@@ -337,11 +340,11 @@ void APLRRegressor::initialize()
337340
}
338341
}
339342

340-
linear_predictor_current=VectorXd::Constant(y_train.size(),0);
343+
linear_predictor_current=VectorXd::Constant(y_train.size(),intercept);
341344
linear_predictor_null_model=linear_predictor_current;
342-
linear_predictor_current_validation=VectorXd::Constant(y_validation.size(),0);
343-
predictions_current=transform_linear_predictor_to_predictions(linear_predictor_current,family);
344-
predictions_current_validation=transform_linear_predictor_to_predictions(linear_predictor_current_validation,family);
345+
linear_predictor_current_validation=VectorXd::Constant(y_validation.size(),intercept);
346+
predictions_current=transform_linear_predictor_to_predictions(linear_predictor_current,link_function);
347+
predictions_current_validation=transform_linear_predictor_to_predictions(linear_predictor_current_validation,link_function);
345348

346349
validation_error_steps.resize(m);
347350
validation_error_steps.setConstant(std::numeric_limits<double>::infinity());
@@ -379,14 +382,14 @@ VectorXd APLRRegressor::calculate_neg_gradient_current(const VectorXd &y,const V
379382
VectorXd output;
380383
if(family=="gaussian")
381384
output=y-predictions_current;
382-
else if(family=="logit")
385+
else if(family=="binomial")
383386
output=y.array() / predictions_current.array() - (y.array()-1.0) / (predictions_current.array()-1.0);
384387
else if(family=="poisson")
385388
output=y.array() / predictions_current.array() - 1;
386389
else if(family=="gamma")
387390
output=(y.array() - predictions_current.array()) / predictions_current.array() / predictions_current.array();
388391
else if(family=="poissongamma")
389-
output=(y.array() / predictions_current.array().pow(1.5) - predictions_current.array().pow(-0.5));
392+
output=y.array() / predictions_current.array().pow(1.5) - predictions_current.array().pow(-0.5);
390393
else if(family=="inversegaussian")
391394
output=y.array() / predictions_current.array().pow(3.0) - predictions_current.array().pow(-2.0);
392395
return output;
@@ -692,8 +695,9 @@ void APLRRegressor::select_the_best_term_and_update_errors(size_t boosting_step)
692695
if(validation_error_is_invalid)
693696
{
694697
abort_boosting=true;
695-
std::string warning_message{"Warning: Encountered numerical problems when calculating prediction errors."};
696-
if(family=="poisson" || family=="poissongamma" ||family=="gamma" || family=="inversegaussian")
698+
std::string warning_message{"Warning: Encountered numerical problems when calculating prediction errors in the previous boosting step. Not continuing with further boosting steps."};
699+
bool show_additional_warning{family=="poisson" || family=="poissongamma" || family=="gamma" || family=="inversegaussian" || (link_function!="identity" && link_function!="logit")};
700+
if(show_additional_warning)
697701
warning_message+=" A reason may be too large response values.";
698702
std::cout<<warning_message<<"\n";
699703
}
@@ -703,8 +707,8 @@ void APLRRegressor::update_linear_predictor_and_predictors()
703707
{
704708
linear_predictor_current+=linear_predictor_update;
705709
linear_predictor_current_validation+=linear_predictor_update_validation;
706-
predictions_current=transform_linear_predictor_to_predictions(linear_predictor_current,family);
707-
predictions_current_validation=transform_linear_predictor_to_predictions(linear_predictor_current_validation,family);
710+
predictions_current=transform_linear_predictor_to_predictions(linear_predictor_current,link_function);
711+
predictions_current_validation=transform_linear_predictor_to_predictions(linear_predictor_current_validation,link_function);
708712
}
709713

710714
void APLRRegressor::update_gradient_and_errors()
@@ -960,7 +964,7 @@ VectorXd APLRRegressor::predict(const MatrixXd &X)
960964
validate_that_model_can_be_used(X);
961965

962966
VectorXd linear_predictor{calculate_linear_predictor(X)};
963-
VectorXd predictions{transform_linear_predictor_to_predictions(linear_predictor,family)};
967+
VectorXd predictions{transform_linear_predictor_to_predictions(linear_predictor,link_function)};
964968

965969
return predictions;
966970
}
@@ -1053,7 +1057,7 @@ void APLRRegressor::throw_error_if_family_does_not_exist()
10531057
bool family_exists{false};
10541058
if(family=="gaussian")
10551059
family_exists=true;
1056-
else if(family=="logit")
1060+
else if(family=="binomial")
10571061
family_exists=true;
10581062
else if(family=="poisson")
10591063
family_exists=true;
@@ -1065,4 +1069,23 @@ void APLRRegressor::throw_error_if_family_does_not_exist()
10651069
family_exists=true;
10661070
if(!family_exists)
10671071
throw std::runtime_error("Family "+family+" is not available in APLR.");
1072+
}
1073+
1074+
void APLRRegressor::throw_error_if_link_function_does_not_exist()
1075+
{
1076+
bool link_function_exists{false};
1077+
if(link_function=="identity")
1078+
link_function_exists=true;
1079+
else if(link_function=="logit")
1080+
link_function_exists=true;
1081+
else if(link_function=="log")
1082+
link_function_exists=true;
1083+
else if(link_function=="inverseroot")
1084+
link_function_exists=true;
1085+
else if(link_function=="inverse")
1086+
link_function_exists=true;
1087+
else if(link_function=="inversesquare")
1088+
link_function_exists=true;
1089+
if(!link_function_exists)
1090+
throw std::runtime_error("Link function "+link_function+" is not available in APLR.");
10681091
}

cpp/constants.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#pragma once
22
#include <limits>
33

4-
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
4+
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
5+
const double SMALL_NEGATIVE_VALUE{-0.000001};
6+
const double SMALL_POSITIVE_VALUE{0.000001};

cpp/functions.h

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ VectorXd calculate_gaussian_errors(const VectorXd &y,const VectorXd &predicted)
4747
return errors;
4848
}
4949

50-
VectorXd calculate_logit_errors(const VectorXd &y,const VectorXd &predicted)
50+
VectorXd calculate_binomial_errors(const VectorXd &y,const VectorXd &predicted)
5151
{
5252
VectorXd errors{-y.array() * predicted.array().log() - (1.0-y.array()).array() * (1.0-predicted.array()).log()};
5353
return errors;
@@ -84,8 +84,8 @@ VectorXd calculate_errors(const VectorXd &y,const VectorXd &predicted,const Vect
8484
VectorXd errors;
8585
if(family=="gaussian")
8686
errors=calculate_gaussian_errors(y,predicted);
87-
else if(family=="logit")
88-
errors=calculate_logit_errors(y,predicted);
87+
else if(family=="binomial")
88+
errors=calculate_binomial_errors(y,predicted);
8989
else if(family=="poisson")
9090
errors=calculate_poisson_errors(y,predicted);
9191
else if(family=="gamma")
@@ -135,17 +135,35 @@ double calculate_error(const VectorXd &errors,const VectorXd &sample_weight=Vect
135135
return error;
136136
}
137137

138-
VectorXd transform_linear_predictor_to_predictions(const VectorXd &linear_predictor, const std::string &family="gaussian")
138+
VectorXd transform_linear_predictor_to_negative(const VectorXd &linear_predictor)
139139
{
140-
if(family=="gaussian")
140+
VectorXd transformed_linear_predictor{linear_predictor};
141+
for (size_t i = 0; i < static_cast<size_t>(transformed_linear_predictor.rows()); ++i)
142+
{
143+
bool row_is_positive{std::isgreaterequal(transformed_linear_predictor[i],0.0)};
144+
if(row_is_positive)
145+
transformed_linear_predictor[i]=SMALL_NEGATIVE_VALUE;
146+
}
147+
return transformed_linear_predictor;
148+
}
149+
150+
VectorXd transform_linear_predictor_to_predictions(const VectorXd &linear_predictor, const std::string &link_function="identity")
151+
{
152+
if(link_function=="identity")
141153
return linear_predictor;
142-
else if(family=="logit")
154+
else if(link_function=="logit")
143155
{
144156
VectorXd exp_of_linear_predictor{linear_predictor.array().exp()};
145157
return exp_of_linear_predictor.array() / (1.0 + exp_of_linear_predictor.array());
146158
}
147-
else if(family=="poisson" || family=="poissongamma" || family=="gamma" || family=="inversegaussian")
159+
else if(link_function=="log")
148160
return linear_predictor.array().exp();
161+
else if(link_function=="inverseroot")
162+
return 4.0 * transform_linear_predictor_to_negative(linear_predictor).array().pow(-2);
163+
else if(link_function=="inverse")
164+
return -1.0 / transform_linear_predictor_to_negative(linear_predictor).array();
165+
else if(link_function=="inversesquare")
166+
return 1.0 / std::sqrt(2) / (-transform_linear_predictor_to_negative(linear_predictor).array()).pow(0.5);
149167
return VectorXd(0);
150168
}
151169

0 commit comments

Comments (0)