Skip to content

Commit ca47773

Browse files
deprecation of intercept field and bugfixes
1 parent 5bb158a commit ca47773

File tree

8 files changed

+109
-57
lines changed

8 files changed

+109
-57
lines changed

API_REFERENCE_FOR_REGRESSION.md

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# APLRRegressor
22

3-
## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5)
3+
## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5)
44

55
### Constructor parameters
66

@@ -25,9 +25,6 @@ Multi-threading parameter. If ***0*** then uses all available cores for multi-th
2525
#### validation_ratio (default = 0.2)
2626
The ratio of training observations to use for validation instead of training. The number of boosting steps is automatically tuned to minimize validation error.
2727

28-
#### intercept (default = nan)
29-
Specifies the intercept term of the model if you want to predict before doing any training. However, when the ***fit*** method is run then the intercept is estimated based on data and whatever was specified as ***intercept*** when instantiating ***APLRRegressor*** gets overwritten.
30-
3128
#### bins (default = 300)
3229
Specifies the maximum number of bins to discretize the data into when searching for the best split. The default value works well according to empirical results. This hyperparameter is intended for reducing computational costs. Must be greater than 1.
3330

aplr/aplr.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55

66

77
class APLRRegressor():
8-
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5):
8+
def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, loss_function:str="mse", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, dispersion_parameter:float=1.5, validation_tuning_metric:str="default", quantile:float=0.5):
99
self.m=m
1010
self.v=v
1111
self.random_state=random_state
1212
self.loss_function=loss_function
1313
self.link_function=link_function
1414
self.n_jobs=n_jobs
1515
self.validation_ratio=validation_ratio
16-
self.intercept=intercept
1716
self.bins=bins
1817
self.max_interaction_level=max_interaction_level
1918
self.max_interactions=max_interactions
@@ -38,7 +37,6 @@ def __set_params_cpp(self):
3837
self.APLRRegressor.link_function=self.link_function
3938
self.APLRRegressor.n_jobs=self.n_jobs
4039
self.APLRRegressor.validation_ratio=self.validation_ratio
41-
self.APLRRegressor.intercept=self.intercept
4240
self.APLRRegressor.bins=self.bins
4341
self.APLRRegressor.max_interaction_level=self.max_interaction_level
4442
self.APLRRegressor.max_interactions=self.max_interactions
@@ -106,7 +104,6 @@ def get_params(self, deep=True):
106104
"link_function":self.link_function,
107105
"n_jobs":self.n_jobs,
108106
"validation_ratio":self.validation_ratio,
109-
"intercept":self.intercept,
110107
"bins":self.bins,
111108
"max_interaction_level":self.max_interaction_level,
112109
"max_interactions":self.max_interactions,

cpp/APLRClassifier.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ void APLRClassifier::fit(const MatrixXd &X,const std::vector<std::string> &y,con
101101
bool two_class_case{categories.size()==2};
102102
if(two_class_case)
103103
{
104-
logit_models[categories[0]] = APLRRegressor(m,v,random_state,"binomial","logit",n_jobs,validation_ratio,NAN_DOUBLE,reserved_terms_times_num_x,
104+
logit_models[categories[0]] = APLRRegressor(m,v,random_state,"binomial","logit",n_jobs,validation_ratio,reserved_terms_times_num_x,
105105
bins,verbosity,max_interaction_level,max_interactions,min_observations_in_split,ineligible_boosting_steps_added,
106106
max_eligible_terms,1.5,"default",0.5);
107107
logit_models[categories[0]].fit(X,response_values[categories[0]],sample_weight,X_names,validation_indexes,prioritized_predictors_indexes,
@@ -114,7 +114,7 @@ void APLRClassifier::fit(const MatrixXd &X,const std::vector<std::string> &y,con
114114
{
115115
for (auto &category:categories)
116116
{
117-
logit_models[category] = APLRRegressor(m,v,random_state,"binomial","logit",n_jobs,validation_ratio,NAN_DOUBLE,reserved_terms_times_num_x,
117+
logit_models[category] = APLRRegressor(m,v,random_state,"binomial","logit",n_jobs,validation_ratio,reserved_terms_times_num_x,
118118
bins,verbosity,max_interaction_level,max_interactions,min_observations_in_split,ineligible_boosting_steps_added,
119119
max_eligible_terms,1.5,"default",0.5);
120120
logit_models[category].fit(X,response_values[category],sample_weight,X_names,validation_indexes,prioritized_predictors_indexes,

cpp/APLRRegressor.h

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class APLRRegressor
110110
void cap_predictions_to_minmax_in_training(VectorXd &predictions);
111111
std::string compute_raw_base_term_name(const Term &term, const std::string &X_name);
112112
void throw_error_if_m_is_invalid();
113+
bool model_has_not_been_trained();
113114

114115
public:
115116
double intercept;
@@ -143,7 +144,7 @@ class APLRRegressor
143144
double quantile;
144145

145146
APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string loss_function="mse",
146-
std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,double intercept=NAN_DOUBLE,
147+
std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,
147148
size_t reserved_terms_times_num_x=100, size_t bins=300,size_t verbosity=0,size_t max_interaction_level=1,size_t max_interactions=100000,
148149
size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5,double dispersion_parameter=1.5,
149150
std::string validation_tuning_metric="default", double quantile=0.5);
@@ -169,10 +170,10 @@ class APLRRegressor
169170
};
170171

171172
APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::string loss_function,std::string link_function,size_t n_jobs,
172-
double validation_ratio,double intercept,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
173+
double validation_ratio,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
173174
size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double dispersion_parameter,
174175
std::string validation_tuning_metric, double quantile):
175-
reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
176+
reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{NAN_DOUBLE},m{m},v{v},
176177
loss_function{loss_function},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
177178
bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
178179
max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
@@ -1274,8 +1275,7 @@ void APLRRegressor::name_terms(const MatrixXd &X, const std::vector<std::string>
12741275

12751276
void APLRRegressor::set_term_names(const std::vector<std::string> &X_names)
12761277
{
1277-
bool model_has_not_been_trained{!std::isfinite(intercept)};
1278-
if(model_has_not_been_trained)
1278+
if(model_has_not_been_trained())
12791279
throw std::runtime_error("The model must be trained with fit() before term names can be set.");
12801280

12811281
for (size_t i = 0; i < terms.size(); ++i)
@@ -1306,6 +1306,11 @@ void APLRRegressor::set_term_names(const std::vector<std::string> &X_names)
13061306
}
13071307
}
13081308

1309+
bool APLRRegressor::model_has_not_been_trained()
1310+
{
1311+
return !std::isfinite(intercept);
1312+
}
1313+
13091314
std::string APLRRegressor::compute_raw_base_term_name(const Term &term, const std::string &X_name)
13101315
{
13111316
std::string name{""};
@@ -1363,7 +1368,7 @@ void APLRRegressor::find_min_and_max_training_predictions_or_responses()
13631368

13641369
void APLRRegressor::validate_that_model_can_be_used(const MatrixXd &X)
13651370
{
1366-
if(std::isnan(intercept) || number_of_base_terms==0) throw std::runtime_error("Model must be trained before predict() can be run.");
1371+
if(model_has_not_been_trained()) throw std::runtime_error("The model must be trained with fit() before predict() can be run.");
13671372
if(X.rows()==0) throw std::runtime_error("X cannot have zero rows.");
13681373
size_t cols_provided{static_cast<size_t>(X.cols())};
13691374
if(cols_provided!=number_of_base_terms) throw std::runtime_error("X must have "+std::to_string(number_of_base_terms) +" columns but "+std::to_string(cols_provided)+" were provided.");

0 commit comments

Comments
 (0)