Skip to content

Commit bbcb663

Browse files
new and improved fitting with other families and link functions
1 parent 06693b3 commit bbcb663

17 files changed

+111
-104
lines changed

API_REFERENCE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Used to randomly split training observations into training and validation if ***
1717
Determines the loss function used. Allowed values are "gaussian", "binomial", "poisson", "gamma" and "tweedie". This is used together with ***link_function***. Please note that the implementation of values other than "gaussian" is experimental.
1818

1919
#### link_function (default = "identity")
20-
Determines how the linear predictor is transformed to predictions. Allowed values are "identity", "logit", "log", "inverse" and "tweedie". These are canonical link functions for the "gaussian", "binomial", "poisson", "gamma" and "tweedie" ***family*** respectively. Canonical links usually work fine given that the data is appropriate for the selected combination of ***family*** and ***link_function***. Other combinations of ***family*** and ***link_function*** may or may not work (the model may fit poorly to the data if the wrong combination is used). Please note that the implementation of values other than "identity" is experimental.
20+
Determines how the linear predictor is transformed to predictions. Allowed values are "identity", "logit" and "log". For logistic regression use ***family***="binomial" and ***link_function***="logit". For a multiplicative model use the "log" ***link_function*** and a ***family*** that is not "binomial". The ***family*** "poisson", "gamma" or "tweedie" should only be used with the "log" ***link_function***. Invalid combinations of ***family*** and ***link_function*** may result in a warning message when fitting the model and/or a poor model fit. Please note that the implementation of values other than "identity" is experimental.
2121

2222
#### n_jobs (default = 0)
2323
Multi-threading parameter. If ***0*** then uses all available cores for multi-threading. Any other positive integer specifies the number of cores to use (***1*** means single-threading).

cpp/APLRRegressor.h

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class APLRRegressor
4646
bool abort_boosting;
4747
VectorXd linear_predictor_current;
4848
VectorXd linear_predictor_current_validation;
49+
double scaling_factor_for_log_link_function;
4950

5051
//Methods
5152
void validate_input_to_fit(const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names, const std::vector<size_t> &validation_set_indexes);
@@ -81,11 +82,13 @@ class APLRRegressor
8182
VectorXd calculate_linear_predictor(const MatrixXd &X);
8283
void update_linear_predictor_and_predictors();
8384
void throw_error_if_response_contains_invalid_values(const VectorXd &y);
84-
void throw_error_if_response_is_not_between_0_and_1(const VectorXd &y);
85-
void throw_error_if_response_is_negative(const VectorXd &y);
86-
void throw_error_if_response_is_not_greater_than_zero(const VectorXd &y);
85+
void throw_error_if_response_is_not_between_0_and_1(const VectorXd &y,const std::string &error_message);
86+
void throw_error_if_response_is_negative(const VectorXd &y, const std::string &error_message);
87+
void throw_error_if_response_is_not_greater_than_zero(const VectorXd &y, const std::string &error_message);
8788
void throw_error_if_tweedie_power_is_invalid();
8889
VectorXd differentiate_predictions();
90+
void scale_training_observations_if_using_log_link_function();
91+
void revert_scaling_if_using_log_link_function();
8992

9093
public:
9194
//Fields
@@ -187,6 +190,7 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
187190
update_coefficients_for_all_steps();
188191
print_final_summary();
189192
find_optimal_m_and_update_model_accordingly();
193+
revert_scaling_if_using_log_link_function();
190194
name_terms(X, X_names);
191195
calculate_feature_importance_on_validation_set();
192196
cleanup_after_fit();
@@ -218,17 +222,13 @@ void APLRRegressor::throw_error_if_link_function_does_not_exist()
218222
link_function_exists=true;
219223
else if(link_function=="log")
220224
link_function_exists=true;
221-
else if(link_function=="tweedie")
222-
link_function_exists=true;
223-
else if(link_function=="inverse")
224-
link_function_exists=true;
225225
if(!link_function_exists)
226226
throw std::runtime_error("Link function "+link_function+" is not available in APLR.");
227227
}
228228

229229
void APLRRegressor::throw_error_if_tweedie_power_is_invalid()
230230
{
231-
bool tweedie_power_equals_invalid_points{check_if_approximately_equal(tweedie_power,1.0) || check_if_approximately_equal(tweedie_power,2.0)};
231+
bool tweedie_power_equals_invalid_points{is_approximately_equal(tweedie_power,1.0) || is_approximately_equal(tweedie_power,2.0)};
232232
bool tweedie_power_is_in_invalid_range{std::isless(tweedie_power,1.0)};
233233
bool tweedie_power_is_invalid{tweedie_power_equals_invalid_poits || tweedie_power_is_in_invalid_range};
234234
if(tweedie_power_is_invalid)
@@ -262,34 +262,47 @@ void APLRRegressor::throw_error_if_validation_set_indexes_has_invalid_indexes(co
262262

263263
void APLRRegressor::throw_error_if_response_contains_invalid_values(const VectorXd &y)
264264
{
265-
if(link_function=="logit")
266-
throw_error_if_response_is_not_between_0_and_1(y);
267-
else if(link_function=="log" || (link_function=="tweedie" && std::isgreater(tweedie_power,1) && std::isless(tweedie_power,2)) )
268-
throw_error_if_response_is_negative(y);
269-
else if(link_function=="inverse" || (link_function=="tweedie" && std::isgreater(tweedie_power,2)) )
270-
throw_error_if_response_is_not_greater_than_zero(y);
265+
if(link_function=="logit" || family=="binomial")
266+
{
267+
std::string error_message{"Response values for the logit link function or binomial family cannot be less than zero or greater than one."};
268+
throw_error_if_response_is_not_between_0_and_1(y,error_message);
269+
}
270+
else if(family=="gamma" || (family=="tweedie" && std::isgreater(tweedie_power,2)) )
271+
{
272+
std::string error_message;
273+
if(family=="tweedie")
274+
error_message="Response values for the "+family+" family when tweedie_power>2 must be greater than zero.";
275+
else
276+
error_message="Response values for the "+family+" family must be greater than zero.";
277+
throw_error_if_response_is_not_greater_than_zero(y,error_message);
278+
}
279+
else if(link_function=="log" || family=="poisson" || (family=="tweedie" && std::isless(tweedie_power,2) && std::isgreater(tweedie_power,1)))
280+
{
281+
std::string error_message{"Response values for the log link function or poisson family or tweedie family when tweedie_power<2 cannot be less than zero."};
282+
throw_error_if_response_is_negative(y,error_message);
283+
}
271284
}
272285

273-
void APLRRegressor::throw_error_if_response_is_not_between_0_and_1(const VectorXd &y)
286+
void APLRRegressor::throw_error_if_response_is_not_between_0_and_1(const VectorXd &y, const std::string &error_message)
274287
{
275288
bool response_is_less_than_zero{(y.array()<0.0).any()};
276289
bool response_is_greater_than_one{(y.array()>1.0).any()};
277290
if(response_is_less_than_zero || response_is_greater_than_one)
278-
throw std::runtime_error("Response values for "+link_function+" link functions cannot be less than zero or greater than one.");
291+
throw std::runtime_error(error_message);
279292
}
280293

281-
void APLRRegressor::throw_error_if_response_is_negative(const VectorXd &y)
294+
void APLRRegressor::throw_error_if_response_is_negative(const VectorXd &y, const std::string &error_message)
282295
{
283296
bool response_is_less_than_zero{(y.array()<0.0).any()};
284297
if(response_is_less_than_zero)
285-
throw std::runtime_error("Response values for "+link_function+" link functions cannot be less than zero.");
298+
throw std::runtime_error(error_message);
286299
}
287300

288-
void APLRRegressor::throw_error_if_response_is_not_greater_than_zero(const VectorXd &y)
301+
void APLRRegressor::throw_error_if_response_is_not_greater_than_zero(const VectorXd &y, const std::string &error_message)
289302
{
290303
bool response_is_not_greater_than_zero{(y.array()<=0.0).any()};
291304
if(response_is_not_greater_than_zero)
292-
throw std::runtime_error("Response values for "+link_function+" link functions must be greater than zero.");
305+
throw std::runtime_error(error_message);
293306

294307
}
295308

@@ -363,6 +376,25 @@ void APLRRegressor::define_training_and_validation_sets(const MatrixXd &X,const
363376
sample_weight_validation[i]=sample_weight[validation_indexes[i]];
364377
}
365378
}
379+
380+
scale_training_observations_if_using_log_link_function();
381+
}
382+
383+
void APLRRegressor::scale_training_observations_if_using_log_link_function()
384+
{
385+
if(link_function=="log")
386+
{
387+
double inverse_scaling_factor{y_train.maxCoeff()/std::exp(1)};
388+
bool inverse_scaling_factor_is_not_zero{!is_approximately_zero(inverse_scaling_factor)};
389+
if(inverse_scaling_factor_is_not_zero)
390+
{
391+
scaling_factor_for_log_link_function=1/inverse_scaling_factor;
392+
y_train*=scaling_factor_for_log_link_function;
393+
y_validation*=scaling_factor_for_log_link_function;
394+
}
395+
else
396+
scaling_factor_for_log_link_function=1.0;
397+
}
366398
}
367399

368400
void APLRRegressor::initialize()
@@ -407,7 +439,7 @@ bool APLRRegressor::check_if_base_term_has_only_one_unique_value(size_t base_ter
407439
bool term_has_one_unique_value{true};
408440
for (size_t i = 1; i < rows; ++i)
409441
{
410-
bool observation_is_equal_to_previous{check_if_approximately_equal(X_train.col(base_term)[i], X_train.col(base_term)[i-1])};
442+
bool observation_is_equal_to_previous{is_approximately_equal(X_train.col(base_term)[i], X_train.col(base_term)[i-1])};
411443
if(!observation_is_equal_to_previous)
412444
{
413445
term_has_one_unique_value=false;
@@ -450,20 +482,7 @@ VectorXd APLRRegressor::differentiate_predictions()
450482
return 1.0/4.0 * (linear_predictor_current.array()/2.0).cosh().array().pow(-2);
451483
else if(link_function=="log")
452484
{
453-
double scaling{linear_predictor_current.maxCoeff()};
454-
return (linear_predictor_current.array()-scaling).array().exp();
455-
}
456-
else if(link_function=="tweedie")
457-
{
458-
VectorXd transformed_linear_predictor{transform_linear_predictor_to_negative(linear_predictor_current)};
459-
double scaling{std::pow((1-tweedie_power)*transformed_linear_predictor.mean(),-tweedie_power/(1-tweedie_power))};
460-
return scaling*((1-tweedie_power)*transformed_linear_predictor.array()).pow(tweedie_power/(1-tweedie_power));
461-
}
462-
else if(link_function=="inverse")
463-
{
464-
VectorXd transformed_linear_predictor{transform_linear_predictor_to_negative(linear_predictor_current)};
465-
double scaling{std::pow(transformed_linear_predictor.mean(),2)};
466-
return scaling * transformed_linear_predictor.array().pow(-2);
485+
return linear_predictor_current.array().exp();
467486
}
468487
return VectorXd(0);
469488
}
@@ -768,10 +787,7 @@ void APLRRegressor::select_the_best_term_and_update_errors(size_t boosting_step)
768787
if(validation_error_is_invalid)
769788
{
770789
abort_boosting=true;
771-
std::string warning_message{"Warning: Encountered numerical problems when calculating prediction errors in the previous boosting step. Not continuing with further boosting steps."};
772-
bool show_additional_warning{family=="poisson" || family=="tweedie" || family=="gamma" || (link_function!="identity" && link_function!="logit")};
773-
if(show_additional_warning)
774-
warning_message+=" For this combination of family and link_function, a reason may be too large or too small response values.";
790+
std::string warning_message{"Warning: Encountered numerical problems when calculating prediction errors in the previous boosting step. Not continuing with further boosting steps. One potential reason is if the combination of family and link_function is invalid."};
775791
std::cout<<warning_message<<"\n";
776792
}
777793
}
@@ -854,15 +870,15 @@ void APLRRegressor::update_coefficients_for_all_steps()
854870
//Filling down coefficient_steps for the intercept
855871
for (size_t j = 0; j < m; ++j) //For each boosting step
856872
{
857-
if(j>0 && check_if_approximately_zero(intercept_steps[j]) && !check_if_approximately_zero(intercept_steps[j-1]))
873+
if(j>0 && is_approximately_zero(intercept_steps[j]) && !is_approximately_zero(intercept_steps[j-1]))
858874
intercept_steps[j]=intercept_steps[j-1];
859875
}
860876
//Filling down coefficient_steps for each term in the model
861877
for (size_t i = 0; i < terms.size(); ++i) //For each term
862878
{
863879
for (size_t j = 0; j < m; ++j) //For each boosting step
864880
{
865-
if(j>0 && check_if_approximately_zero(terms[i].coefficient_steps[j]) && !check_if_approximately_zero(terms[i].coefficient_steps[j-1]))
881+
if(j>0 && is_approximately_zero(terms[i].coefficient_steps[j]) && !is_approximately_zero(terms[i].coefficient_steps[j-1]))
866882
terms[i].coefficient_steps[j]=terms[i].coefficient_steps[j-1];
867883
}
868884
}
@@ -893,12 +909,20 @@ void APLRRegressor::find_optimal_m_and_update_model_accordingly()
893909
terms_new.reserve(terms.size());
894910
for (size_t i = 0; i < terms.size(); ++i)
895911
{
896-
if(!check_if_approximately_zero(terms[i].coefficient))
912+
if(!is_approximately_zero(terms[i].coefficient))
897913
terms_new.push_back(terms[i]);
898914
}
899915
terms=std::move(terms_new);
900916
}
901917

918+
void APLRRegressor::revert_scaling_if_using_log_link_function()
919+
{
920+
if(link_function=="log")
921+
{
922+
intercept+=std::log(1/scaling_factor_for_log_link_function);
923+
}
924+
}
925+
902926
void APLRRegressor::name_terms(const MatrixXd &X, const std::vector<std::string> &X_names)
903927
{
904928
if(X_names.size()==0) //If nothing in X_names

cpp/constants.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#pragma once
22
#include <limits>
33

4-
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
5-
const double SMALL_NEGATIVE_VALUE{-0.000001};
4+
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };

cpp/functions.h

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using namespace Eigen;
1515
//implements relative method - do not use for comparing with zero
1616
//use this most of the time, tolerance needs to be meaningful in your context
1717
template<typename TReal>
18-
static bool check_if_approximately_equal(TReal a, TReal b, TReal tolerance = std::numeric_limits<TReal>::epsilon())
18+
static bool is_approximately_equal(TReal a, TReal b, TReal tolerance = std::numeric_limits<TReal>::epsilon())
1919
{
2020
if(std::isinf(a) && std::isinf(b) && std::signbit(a)==std::signbit(b))
2121
return true;
@@ -33,7 +33,7 @@ static bool check_if_approximately_equal(TReal a, TReal b, TReal tolerance = std
3333
//supply tolerance that is meaningful in your context
3434
//for example, default tolerance may not work if you are comparing double with float
3535
template<typename TReal>
36-
static bool check_if_approximately_zero(TReal a, TReal tolerance = std::numeric_limits<TReal>::epsilon())
36+
static bool is_approximately_zero(TReal a, TReal tolerance = std::numeric_limits<TReal>::epsilon())
3737
{
3838
if (std::fabs(a) <= tolerance)
3939
return true;
@@ -146,18 +146,6 @@ double calculate_sum_error(const VectorXd &errors)
146146
return error;
147147
}
148148

149-
VectorXd transform_linear_predictor_to_negative(const VectorXd &linear_predictor)
150-
{
151-
VectorXd transformed_linear_predictor{linear_predictor};
152-
for (size_t i = 0; i < static_cast<size_t>(transformed_linear_predictor.rows()); ++i)
153-
{
154-
bool row_is_positive{std::isgreaterequal(transformed_linear_predictor[i],0.0)};
155-
if(row_is_positive)
156-
transformed_linear_predictor[i]=SMALL_NEGATIVE_VALUE;
157-
}
158-
return transformed_linear_predictor;
159-
}
160-
161149
VectorXd transform_linear_predictor_to_predictions(const VectorXd &linear_predictor, const std::string &link_function="identity", double tweedie_power=1.5)
162150
{
163151
if(link_function=="identity")
@@ -169,10 +157,6 @@ VectorXd transform_linear_predictor_to_predictions(const VectorXd &linear_predic
169157
}
170158
else if(link_function=="log")
171159
return linear_predictor.array().exp();
172-
else if(link_function=="tweedie")
173-
return (transform_linear_predictor_to_negative(linear_predictor).array() * (1-tweedie_power)).array().pow(1/(1-tweedie_power));
174-
else if(link_function=="inverse")
175-
return -1.0 / transform_linear_predictor_to_negative(linear_predictor).array();
176160
return VectorXd(0);
177161
}
178162

@@ -274,10 +258,10 @@ size_t calculate_max_index_in_vector(T &vector)
274258
}
275259

276260
template <typename T> //type must be an Eigen Matrix or Vector
277-
bool check_if_matrix_has_nan_or_infinite_elements(const T &x)
261+
bool matrix_has_nan_or_infinite_elements(const T &x)
278262
{
279-
bool matrix_has_nan_or_infinite_elements{!x.allFinite()};
280-
if(matrix_has_nan_or_infinite_elements)
263+
bool has_nan_or_infinite_elements{!x.allFinite()};
264+
if(has_nan_or_infinite_elements)
281265
return true;
282266
else
283267
return false;
@@ -289,8 +273,8 @@ void throw_error_if_matrix_has_nan_or_infinite_elements(const T &x, const std::s
289273
bool matrix_is_empty{x.size()==0};
290274
if(matrix_is_empty) return;
291275

292-
bool matrix_has_nan_or_infinite_elements{check_if_matrix_has_nan_or_infinite_elements(x)};
293-
if(matrix_has_nan_or_infinite_elements)
276+
bool has_nan_or_infinite_elements{matrix_has_nan_or_infinite_elements(x)};
277+
if(has_nan_or_infinite_elements)
294278
{
295279
throw std::runtime_error(matrix_name + " has nan or infinite elements.");
296280
}

cpp/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ int main()
4646
//Saving results
4747
save_data("output.csv",predictions);
4848
std::cout<<"min validation_error "<<model.validation_error_steps.minCoeff()<<"\n\n";
49-
std::cout<<check_if_approximately_equal(model.validation_error_steps.minCoeff(),7.02559,0.00001)<<"\n";
49+
std::cout<<is_approximately_equal(model.validation_error_steps.minCoeff(),7.02559,0.00001)<<"\n";
5050

5151
std::cout<<"mean prediction "<<predictions.mean()<<"\n\n";
52-
std::cout<<check_if_approximately_equal(predictions.mean(),23.9213,0.0001)<<"\n";
52+
std::cout<<is_approximately_equal(predictions.mean(),23.9213,0.0001)<<"\n";
5353

5454
std::cout<<"best_m: "<<model.m<<"\n";
5555

cpp/term.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Term::~Term()
115115
//Compare everything except given_terms
116116
bool Term::equals_not_comparing_given_terms(const Term &p1,const Term &p2)
117117
{
118-
bool split_point_and_direction{(check_if_approximately_equal(p1.split_point,p2.split_point) && p1.direction_right==p2.direction_right) || (std::isnan(p1.split_point) && std::isnan(p2.split_point))};
118+
bool split_point_and_direction{(is_approximately_equal(p1.split_point,p2.split_point) && p1.direction_right==p2.direction_right) || (std::isnan(p1.split_point) && std::isnan(p2.split_point))};
119119
bool base_term{p1.base_term==p2.base_term};
120120
return split_point_and_direction && base_term;
121121
}
@@ -191,7 +191,7 @@ void Term::calculate_given_terms_indices(const MatrixXd &X)
191191
VectorXd values_given_term{given_terms[j].calculate(X)};
192192
for (size_t i = 0; i < static_cast<size_t>(X.rows()); ++i) //for each row
193193
{
194-
if(check_if_approximately_zero(values_given_term[i])) //if zeroed out by given term
194+
if(is_approximately_zero(values_given_term[i])) //if zeroed out by given term
195195
{
196196
given_terms_indices.zeroed[count_zeroed]=i;
197197
++count_zeroed;
@@ -227,7 +227,7 @@ VectorXd Term::calculate(const MatrixXd &X)
227227
VectorXd values_given_term{given_terms[j].calculate(X)};
228228
for (size_t i = 0; i < static_cast<size_t>(values.size()); ++i) //for each row
229229
{
230-
if(check_if_approximately_zero(values_given_term[i]))
230+
if(is_approximately_zero(values_given_term[i]))
231231
values[i]=0;
232232
}
233233
}
@@ -365,7 +365,7 @@ void Term::setup_bins()
365365
potential_start_indexes.reserve(sorted_vectors.values_sorted.size());
366366
for (size_t i = start_row; i <= end_row; ++i)
367367
{
368-
bool is_eligible_start_index{i>0 && !check_if_approximately_equal(sorted_vectors.values_sorted[i],sorted_vectors.values_sorted[i-1])};
368+
bool is_eligible_start_index{i>0 && !is_approximately_equal(sorted_vectors.values_sorted[i],sorted_vectors.values_sorted[i-1])};
369369
if(is_eligible_start_index)
370370
potential_start_indexes.push_back(i);
371371
}

0 commit comments

Comments
 (0)