Commit ff83ad8

added possibility to compare models with different family or tweedie_power parameters

1 parent 7715e27 commit ff83ad8

10 files changed: +183 −25 lines

API_REFERENCE.md

Lines changed: 12 additions & 4 deletions

@@ -1,6 +1,6 @@
 # APLRRegressor
 
-## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5)
+## class aplr.APLRRegressor(m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, group_size_for_validation_group_mse:int=100)
 
 ### Constructor parameters
 
@@ -14,7 +14,7 @@ The learning rate. Must be greater than zero and not more than one. The higher t
 Used to randomly split training observations into training and validation if ***validation_set_indexes*** is not specified when fitting.
 
 #### family (default = "gaussian")
-Determines the loss function used. Allowed values are "gaussian", "binomial", "poisson", "gamma" and "tweedie". This is used together with ***link_function***. Please note that this is not a tuning parameter because it defines how the loss function is calculated.
+Determines the loss function used. Allowed values are "gaussian", "binomial", "poisson", "gamma" and "tweedie". This is used together with ***link_function***. Please note that this is not an ordinary tuning parameter because it defines how the loss function is calculated. However, it can be tuned with ***get_validation_group_mse()*** as the tuning metric.
 
 #### link_function (default = "identity")
 Determines how the linear predictor is transformed to predictions. Allowed values are "identity", "logit" and "log". For an ordinary regression model use ***family*** "gaussian" and ***link_function*** "identity". For logistic regression use ***family*** "binomial" and ***link_function*** "logit". For a multiplicative model use the "log" ***link_function***. The "log" ***link_function*** often works best with a "poisson", "gamma" or "tweedie" ***family***, depending on the data. The ***family*** "poisson", "gamma" or "tweedie" should only be used with the "log" ***link_function***. Inappropriate combinations of ***family*** and ***link_function*** may result in a warning message when fitting the model and/or a poor model fit. Please note that values other than "identity" typically require a significantly higher ***m*** (or ***v***) in order to converge.
@@ -50,7 +50,10 @@ Limits 1) the number of terms already in the model that can be considered as int
 ***0*** does not print progress reports during fitting. ***1*** prints a summary after running the ***fit*** method. ***2*** prints a summary after each boosting step.
 
 #### tweedie_power (default = 1.5)
-Species the variance power for the "tweedie" ***family*** and ***link_function***. Please note that this is not a tuning parameter because it defines how the loss function is calculated.
+Specifies the variance power for the "tweedie" ***family*** and ***link_function***. Please note that this is not an ordinary tuning parameter because it defines how the loss function is calculated. However, it can be tuned with ***get_validation_group_mse()*** as the tuning metric.
+
+#### group_size_for_validation_group_mse (default = 100)
+APLR calculates mean squared error on grouped data in the validation set. This can be useful for comparing models that have different ***family*** or ***tweedie_power*** parameters. The maximum number of observations in each group is specified by ***group_size_for_validation_group_mse***. Some of the observations with the lowest or highest response values will belong to groups with fewer than ***group_size_for_validation_group_mse*** observations; the minimum number of observations in a group is ***group_size_for_validation_group_mse/2***. If ***group_size_for_validation_group_mse*** is equal to or higher than the number of observations in the validation set, then there will be only one group (in which case the grouped validation MSE is not very useful). ***group_size_for_validation_group_mse*** should be large enough for the Central Limit Theorem to hold (at least 60, but 100 is a safer choice). Also, the number of observations in the validation set should be substantially higher than ***group_size_for_validation_group_mse*** for the grouped validation MSE to be useful.
 
 
 ## Method: fit(X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[])
@@ -170,4 +173,9 @@ The index of the term selected. So ***0*** is the first term, ***1*** is the sec
 
 ## Method: get_m()
 
-***Returns the number of boosting steps in the model (the value that minimized validation error).***
+***Returns the number of boosting steps in the model (the value that minimized validation error).***
+
+
+## Method: get_validation_group_mse()
+
+***Returns the mean squared error on grouped data in the validation set.*** See ***group_size_for_validation_group_mse*** for more information.
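The grouped validation MSE documented above can be emulated outside the library to see what the metric measures. A minimal NumPy sketch (the function name `grouped_validation_mse` is hypothetical; it follows the rolling-window scheme used internally, unweighted case only): observations are ordered by the response, both the response and the predictions are averaged over a centered rolling window of up to `group_size` observations, and the MSE between the two smoothed vectors is returned.

```python
import numpy as np

def grouped_validation_mse(y_validation, predictions, group_size=100):
    """Sketch of APLR's grouped validation MSE: smooth y and the
    predictions over rolling groups ordered by y, then take the MSE
    between the two smoothed vectors."""
    n = y_validation.shape[0]
    if group_size >= n:
        # Only one group: compare overall means (mirrors the special case
        # where the window encompasses the whole validation set).
        return float((y_validation.mean() - predictions.mean()) ** 2)
    order = np.argsort(y_validation, kind="stable")  # ascending by response
    half = (group_size - 1) // 2
    y_centered = np.empty(n)
    pred_centered = np.empty(n)
    for i in range(n):
        lo = max(i - half, 0)            # groups shrink at both tails
        hi = min(i + half, n - 1)
        idx = order[lo:hi + 1]           # window in sorted order
        y_centered[order[i]] = y_validation[idx].mean()
        pred_centered[order[i]] = predictions[idx].mean()
    return float(np.mean((y_centered - pred_centered) ** 2))
```

Because the group averages are approximately normal for large enough groups (the Central Limit Theorem argument above), this quantity is comparable across models fitted with different loss functions, which a raw deviance-based validation error is not.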

aplr/aplr.py

Lines changed: 7 additions & 2 deletions

@@ -5,7 +5,7 @@
 
 
 class APLRRegressor():
-    def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5):
+    def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaussian", link_function:str="identity", n_jobs:int=0, validation_ratio:float=0.2, intercept:float=np.nan, bins:int=300, max_interaction_level:int=1, max_interactions:int=100000, min_observations_in_split:int=20, ineligible_boosting_steps_added:int=10, max_eligible_terms:int=5, verbosity:int=0, tweedie_power:float=1.5, group_size_for_validation_group_mse:int=100):
         self.m=m
         self.v=v
         self.random_state=random_state
@@ -22,6 +22,7 @@ def __init__(self, m:int=1000, v:float=0.1, random_state:int=0, family:str="gaus
         self.max_eligible_terms=max_eligible_terms
         self.verbosity=verbosity
         self.tweedie_power=tweedie_power
+        self.group_size_for_validation_group_mse=group_size_for_validation_group_mse
 
         #Creating aplr_cpp and setting parameters
         self.APLRRegressor=aplr_cpp.APLRRegressor()
@@ -45,6 +46,7 @@ def __set_params_cpp(self):
         self.APLRRegressor.max_eligible_terms=self.max_eligible_terms
         self.APLRRegressor.verbosity=self.verbosity
         self.APLRRegressor.tweedie_power=self.tweedie_power
+        self.APLRRegressor.group_size_for_validation_group_mse=self.group_size_for_validation_group_mse
 
     def fit(self, X:npt.ArrayLike, y:npt.ArrayLike, sample_weight:npt.ArrayLike = np.empty(0), X_names:List[str]=[], validation_set_indexes:List[int]=[]):
         self.__set_params_cpp()
@@ -89,9 +91,12 @@ def get_intercept_steps(self)->npt.ArrayLike:
     def get_m(self)->int:
         return self.APLRRegressor.get_m()
 
+    def get_validation_group_mse(self)->float:
+        return self.APLRRegressor.get_validation_group_mse()
+
     #For sklearn
     def get_params(self, deep=True):
-        return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"link_function":self.link_function,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms,"tweedie_power":self.tweedie_power}
+        return {"m": self.m, "v": self.v,"random_state":self.random_state,"family":self.family,"link_function":self.link_function,"n_jobs":self.n_jobs,"validation_ratio":self.validation_ratio,"intercept":self.intercept,"bins":self.bins,"max_interaction_level":self.max_interaction_level,"max_interactions":self.max_interactions,"verbosity":self.verbosity,"min_observations_in_split":self.min_observations_in_split,"ineligible_boosting_steps_added":self.ineligible_boosting_steps_added,"max_eligible_terms":self.max_eligible_terms,"tweedie_power":self.tweedie_power,"group_size_for_validation_group_mse":self.group_size_for_validation_group_mse}
 
     #For sklearn
     def set_params(self, **parameters):
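The aplr/aplr.py changes follow a fixed pattern: every new constructor parameter must be mirrored in three places — `__init__`, `__set_params_cpp` (which pushes parameters to the C++ backend before fitting), and `get_params` (for sklearn compatibility). A minimal sketch of that wrapper pattern, with a hypothetical stub standing in for the real `aplr_cpp.APLRRegressor` pybind11 object:

```python
class _CppBackendStub:
    """Hypothetical stand-in for the pybind11 aplr_cpp.APLRRegressor object."""
    pass

class WrapperRegressor:
    def __init__(self, m: int = 1000, group_size_for_validation_group_mse: int = 100):
        # 1) Store every parameter as a Python attribute.
        self.m = m
        self.group_size_for_validation_group_mse = group_size_for_validation_group_mse
        self.backend = _CppBackendStub()

    def _set_params_cpp(self):
        # 2) Push every parameter onto the C++ backend before fitting.
        self.backend.m = self.m
        self.backend.group_size_for_validation_group_mse = self.group_size_for_validation_group_mse

    def get_params(self, deep=True):
        # 3) Expose every parameter to sklearn. A parameter missing here is
        # silently dropped by sklearn's clone() and GridSearchCV.
        return {"m": self.m,
                "group_size_for_validation_group_mse": self.group_size_for_validation_group_mse}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self
```

Forgetting step 3 is the classic failure mode: the model still fits, but hyperparameter search over the new parameter silently uses the default.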

cpp/APLRRegressor.h

Lines changed: 28 additions & 4 deletions

@@ -76,6 +76,7 @@ class APLRRegressor
     void name_terms(const MatrixXd &X, const std::vector<std::string> &X_names);
     void calculate_feature_importance_on_validation_set();
     void find_min_and_max_training_predictions();
+    void calculate_validation_group_mse();
     void cleanup_after_fit();
     void validate_that_model_can_be_used(const MatrixXd &X);
     void throw_error_if_family_does_not_exist();
@@ -122,12 +123,15 @@ class APLRRegressor
     double tweedie_power;
     double min_training_prediction;
     double max_training_prediction;
+    double validation_group_mse;
+    size_t group_size_for_validation_group_mse;
 
     //Methods
     APLRRegressor(size_t m=1000,double v=0.1,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t>::lowest(),std::string family="gaussian",
         std::string link_function="identity", size_t n_jobs=0, double validation_ratio=0.2,double intercept=NAN_DOUBLE,
         size_t reserved_terms_times_num_x=100, size_t bins=300,size_t verbosity=0,size_t max_interaction_level=1,size_t max_interactions=100000,
-        size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5,double tweedie_power=1.5);
+        size_t min_observations_in_split=20, size_t ineligible_boosting_steps_added=10, size_t max_eligible_terms=5,double tweedie_power=1.5,
+        size_t group_size_for_validation_group_mse=100);
     APLRRegressor(const APLRRegressor &other);
     ~APLRRegressor();
     void fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0),const std::vector<std::string> &X_names={},const std::vector<size_t> &validation_set_indexes={});
@@ -144,19 +148,21 @@ class APLRRegressor
     double get_intercept();
     VectorXd get_intercept_steps();
     size_t get_m();
+    double get_validation_group_mse();
 };
 
 //Regular constructor
 APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::string family,std::string link_function,size_t n_jobs,
     double validation_ratio,double intercept,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
-    size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double tweedie_power):
+    size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double tweedie_power,
+    size_t group_size_for_validation_group_mse):
     reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
     family{family},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
     bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
     intercept_steps{VectorXd(0)},max_interactions{max_interactions},interactions_eligible{0},validation_error_steps{VectorXd(0)},
     min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
     max_eligible_terms{max_eligible_terms},number_of_base_terms{0},tweedie_power{tweedie_power},min_training_prediction{NAN_DOUBLE},
-    max_training_prediction{NAN_DOUBLE}
+    max_training_prediction{NAN_DOUBLE},validation_group_mse{NAN_DOUBLE},group_size_for_validation_group_mse{group_size_for_validation_group_mse}
 {
 }
 
@@ -171,7 +177,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
     min_observations_in_split{other.min_observations_in_split},ineligible_boosting_steps_added{other.ineligible_boosting_steps_added},
     max_eligible_terms{other.max_eligible_terms},number_of_base_terms{other.number_of_base_terms},
     feature_importance{other.feature_importance},tweedie_power{other.tweedie_power},min_training_prediction{other.min_training_prediction},
-    max_training_prediction{other.max_training_prediction}
+    max_training_prediction{other.max_training_prediction},validation_group_mse{other.validation_group_mse},
+    group_size_for_validation_group_mse{other.group_size_for_validation_group_mse}
 {
 }
 
@@ -200,6 +207,7 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
     name_terms(X, X_names);
     calculate_feature_importance_on_validation_set();
     find_min_and_max_training_predictions();
+    calculate_validation_group_mse();
     cleanup_after_fit();
 }
 
@@ -1042,6 +1050,17 @@ void APLRRegressor::find_min_and_max_training_predictions()
     max_training_prediction=training_predictions.maxCoeff();
 }
 
+void APLRRegressor::calculate_validation_group_mse()
+{
+    VectorXd validation_predictions{predict(X_validation,false)};
+    VectorXi y_validation_sorted_index{sort_indexes_ascending(y_validation)};
+    VectorXd y_validation_centered{calculate_rolling_centered_mean(y_validation,y_validation_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
+    VectorXd validation_predictions_centered{calculate_rolling_centered_mean(validation_predictions,y_validation_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
+
+    VectorXd squared_residuals{(y_validation_centered-validation_predictions_centered).array().pow(2)};
+    validation_group_mse = squared_residuals.mean();
+}
+
 void APLRRegressor::validate_that_model_can_be_used(const MatrixXd &X)
 {
     if(std::isnan(intercept) || number_of_base_terms==0) throw std::runtime_error("Model must be trained before predict() can be run.");
@@ -1186,4 +1205,9 @@ VectorXd APLRRegressor::get_intercept_steps()
 size_t APLRRegressor::get_m()
 {
     return m;
+}
+
+double APLRRegressor::get_validation_group_mse()
+{
+    return validation_group_mse;
 }

cpp/functions.h

Lines changed: 61 additions & 0 deletions

@@ -310,4 +310,65 @@ void throw_error_if_matrix_has_nan_or_infinite_elements(const T &x, const std::s
     {
         throw std::runtime_error(matrix_name + " has nan or infinite elements.");
     }
+}
+
+VectorXd calculate_rolling_centered_mean(const VectorXd &vector, const VectorXi &sorted_index, size_t rolling_window, const VectorXd &sample_weight=VectorXd(0))
+{
+    bool sample_weight_is_provided{sample_weight.rows()==vector.rows()};
+    bool rolling_window_contains_one_observation{rolling_window<=1};
+    bool rolling_window_encompasses_all_observations_in_validation_set{rolling_window >= static_cast<size_t>(vector.rows())};
+    size_t half_rolling_window{(rolling_window-1)/2};
+
+    VectorXd rolling_centered_mean;
+    if(rolling_window_contains_one_observation)
+        rolling_centered_mean = vector;
+    else if(rolling_window_encompasses_all_observations_in_validation_set)
+    {
+        if(sample_weight_is_provided)
+        {
+            double weighted_centered_mean{(vector.array() * sample_weight.array()).sum() / sample_weight.sum()};
+            rolling_centered_mean = VectorXd::Constant(vector.rows(),weighted_centered_mean);
+        }
+        else
+            rolling_centered_mean = VectorXd::Constant(vector.rows(),vector.mean());
+    }
+    else
+    {
+        rolling_centered_mean = VectorXd::Constant(vector.rows(),0);
+
+        size_t vector_size{static_cast<size_t>(sorted_index.rows())};
+        for (size_t i = 0; i < vector_size; ++i)
+        {
+            size_t min_index;
+            if(i<half_rolling_window)
+                min_index=0;
+            else
+                min_index=i-half_rolling_window;
+
+            size_t max_index{std::min(vector_size-1, i+half_rolling_window)};
+
+            double rolling_centered_weighted_sum{0};
+            if(sample_weight_is_provided)
+            {
+                double rolling_centered_sample_weight_sum{0};
+                for (size_t j = min_index; j <= max_index; ++j)
+                {
+                    rolling_centered_weighted_sum += vector[sorted_index[j]] * sample_weight[sorted_index[j]];
+                    rolling_centered_sample_weight_sum += sample_weight[sorted_index[j]];
+                }
+                rolling_centered_mean[sorted_index[i]] = rolling_centered_weighted_sum / rolling_centered_sample_weight_sum;
+            }
+            else
+            {
+                size_t observations{max_index-min_index+1};
+                for (size_t j = min_index; j <= max_index; ++j)
+                {
+                    rolling_centered_mean[sorted_index[i]] += vector[sorted_index[j]];
+                }
+                rolling_centered_mean[sorted_index[i]] /= observations;
+            }
+        }
+    }
+
+    return rolling_centered_mean;
 }
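For experimenting with the smoothing behaviour without building the C++ code, the new `calculate_rolling_centered_mean` above can be sketched in NumPy. This is a hedged port, not the library's API: same three branches (window of one observation, window encompassing everything, general windowed case), including the optional sample weights.

```python
import numpy as np

def rolling_centered_mean(vector, sorted_index, rolling_window, sample_weight=None):
    """Python sketch of calculate_rolling_centered_mean from cpp/functions.h:
    for each element, average a window of up to rolling_window neighbours
    taken in sorted_index order, writing results back in original positions."""
    n = vector.shape[0]
    if rolling_window <= 1:
        return vector.copy()                 # window of one: no smoothing
    if rolling_window >= n:
        # Window covers the whole vector: every element gets the overall
        # (weighted) mean.
        if sample_weight is not None:
            mean = np.sum(vector * sample_weight) / np.sum(sample_weight)
        else:
            mean = vector.mean()
        return np.full(n, mean)

    half = (rolling_window - 1) // 2
    out = np.zeros(n)
    for i in range(n):
        lo = 0 if i < half else i - half     # clamp window at both ends
        hi = min(n - 1, i + half)
        idx = sorted_index[lo:hi + 1]        # neighbours in sorted order
        if sample_weight is not None:
            out[sorted_index[i]] = np.sum(vector[idx] * sample_weight[idx]) / np.sum(sample_weight[idx])
        else:
            out[sorted_index[i]] = vector[idx].mean()
    return out
```

Note the design choice visible in both versions: results are indexed through `sorted_index`, so the output stays aligned with the original (unsorted) observations while the averaging happens over neighbours in response order.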
