@@ -50,7 +50,6 @@ class APLRRegressor
5050 std::set<int > unique_groups_train;
5151 std::set<int > unique_groups_validation;
5252 std::vector<int > interaction_constraints;
53- bool update_intercept_only_once;
5453
5554 void validate_input_to_fit (const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names,
5655 const std::vector<size_t > &validation_set_indexes, const std::vector<size_t > &prioritized_predictors_indexes,
@@ -78,7 +77,7 @@ class APLRRegressor
7877 void add_necessary_given_terms_to_interaction (Term &interaction, Term &existing_model_term);
7978 void find_sorted_indexes_for_errors_for_interactions_to_consider ();
8079 void add_promising_interactions_and_select_the_best_one ();
81- void update_intercept (size_t boosting_step );
80+ void update_intercept ();
8281 void select_the_best_term_and_update_errors (size_t boosting_step, bool not_evaluating_prioritized_predictors=true );
8382 void update_terms (size_t boosting_step);
8483 void update_gradient_and_errors ();
@@ -128,7 +127,6 @@ class APLRRegressor
128127 std::vector<std::string> term_names;
129128 VectorXd term_coefficients;
130129 size_t max_interaction_level;
131- VectorXd intercept_steps;
132130 size_t max_interactions;
133131 size_t interactions_eligible;
134132 VectorXd validation_error_steps;
@@ -165,7 +163,6 @@ class APLRRegressor
165163 VectorXd get_validation_error_steps ();
166164 VectorXd get_feature_importance ();
167165 double get_intercept ();
168- VectorXd get_intercept_steps ();
169166 size_t get_optimal_m ();
170167 std::string get_validation_tuning_metric ();
171168 std::vector<size_t > get_validation_indexes ();
@@ -178,7 +175,7 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
178175 reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
179176 loss_function{loss_function},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
180177 bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
181- intercept_steps{ VectorXd ( 0 )}, max_interactions{max_interactions},interactions_eligible{0 },validation_error_steps{VectorXd (0 )},
178+ max_interactions{max_interactions},interactions_eligible{0 },validation_error_steps{VectorXd (0 )},
182179 min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
183180 max_eligible_terms{max_eligible_terms},number_of_base_terms{0 },dispersion_parameter{dispersion_parameter},min_training_prediction_or_response{NAN_DOUBLE},
184181 max_training_prediction_or_response{NAN_DOUBLE}, validation_tuning_metric{validation_tuning_metric},
@@ -191,8 +188,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
191188 loss_function{other.loss_function },link_function{other.link_function },validation_ratio{other.validation_ratio },
192189 n_jobs{other.n_jobs },random_state{other.random_state },bins{other.bins },
193190 verbosity{other.verbosity },term_names{other.term_names },term_coefficients{other.term_coefficients },
194- max_interaction_level{other.max_interaction_level },intercept_steps {other.intercept_steps },
195- max_interactions{other. max_interactions }, interactions_eligible{other.interactions_eligible },validation_error_steps{other.validation_error_steps },
191+ max_interaction_level{other.max_interaction_level },max_interactions {other.max_interactions },
192+ interactions_eligible{other.interactions_eligible },validation_error_steps{other.validation_error_steps },
196193 min_observations_in_split{other.min_observations_in_split },ineligible_boosting_steps_added{other.ineligible_boosting_steps_added },
197194 max_eligible_terms{other.max_eligible_terms },number_of_base_terms{other.number_of_base_terms },
198195 feature_importance{other.feature_importance },dispersion_parameter{other.dispersion_parameter },min_training_prediction_or_response{other.min_training_prediction_or_response },
@@ -289,6 +286,12 @@ void APLRRegressor::throw_error_if_dispersion_parameter_is_invalid()
289286 }
290287}
291288
289+ void APLRRegressor::throw_error_if_m_is_invalid ()
290+ {
291+ if (m<1 )
292+ throw std::runtime_error (" The maximum number of boosting steps, m, must be at least 1." );
293+ }
294+
292295void APLRRegressor::validate_input_to_fit (const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,
293296 const std::vector<std::string> &X_names, const std::vector<size_t > &validation_set_indexes,
294297 const std::vector<size_t > &prioritized_predictors_indexes, const std::vector<int > &monotonic_constraints, const VectorXi &group,
@@ -546,21 +549,7 @@ void APLRRegressor::initialize(const std::vector<size_t> &prioritized_predictors
546549 terms.clear ();
547550 terms.reserve (X_train.cols ()*reserved_terms_times_num_x);
548551
549- if (loss_function == " group_mse" )
550- {
551- update_intercept_only_once = true ;
552- if (sample_weight_train.size ()==0 )
553- intercept = y_train.mean ();
554- else
555- intercept = (y_train.array ()*sample_weight_train.array ()).sum ()/sample_weight_train.array ().sum ();
556- }
557- else
558- {
559- update_intercept_only_once = false ;
560- intercept=0 ;
561- }
562- intercept_steps=VectorXd::Constant (m, intercept);
563-
552+ double initial_prediction{0.0 };
564553
565554 terms_eligible_current.reserve (X_train.cols ()*reserved_terms_times_num_x);
566555 size_t X_train_cols{static_cast <size_t >(X_train.cols ())};
@@ -602,9 +591,9 @@ void APLRRegressor::initialize(const std::vector<size_t> &prioritized_predictors
602591 }
603592 }
604593
605- linear_predictor_current=VectorXd::Constant (y_train.size (),intercept );
594+ linear_predictor_current=VectorXd::Constant (y_train.size (),initial_prediction );
606595 linear_predictor_null_model=linear_predictor_current;
607- linear_predictor_current_validation=VectorXd::Constant (y_validation.size (),intercept );
596+ linear_predictor_current_validation=VectorXd::Constant (y_validation.size (),initial_prediction );
608597 predictions_current=transform_linear_predictor_to_predictions (linear_predictor_current,link_function);
609598 predictions_current_validation=transform_linear_predictor_to_predictions (linear_predictor_current_validation,link_function);
610599
@@ -717,17 +706,31 @@ VectorXd APLRRegressor::differentiate_predictions()
717706void APLRRegressor::execute_boosting_steps ()
718707{
719708 abort_boosting = false ;
720- for (size_t boosting_step = 0 ; boosting_step < m; ++boosting_step)
709+ update_intercept ();
710+ for (size_t boosting_step = 1 ; boosting_step < m; ++boosting_step)
721711 {
722712 execute_boosting_step (boosting_step);
723713 if (abort_boosting) break ;
724714 }
725715}
726716
717+ void APLRRegressor::update_intercept ()
718+ {
719+ double intercept_update;
720+ if (sample_weight_train.size ()==0 )
721+ intercept=neg_gradient_current.mean ();
722+ else
723+ intercept=(neg_gradient_current.array ()*sample_weight_train.array ()).sum ()/sample_weight_train.array ().sum ();
724+ linear_predictor_update=VectorXd::Constant (neg_gradient_current.size (),intercept);
725+ linear_predictor_update_validation=VectorXd::Constant (y_validation.size (),intercept);
726+ update_linear_predictor_and_predictions ();
727+ update_gradient_and_errors ();
728+ calculate_and_validate_validation_error (0 );
729+ print_summary_after_boosting_step (0 );
730+ }
731+
727732void APLRRegressor::execute_boosting_step (size_t boosting_step)
728733{
729- if (!update_intercept_only_once)
730- update_intercept (boosting_step);
731734 bool prioritize_predictors{!abort_boosting && prioritized_predictors_indexes.size ()>0 };
732735 if (prioritize_predictors)
733736 {
@@ -758,25 +761,6 @@ void APLRRegressor::execute_boosting_step(size_t boosting_step)
758761 print_summary_after_boosting_step (boosting_step);
759762}
760763
761- void APLRRegressor::update_intercept (size_t boosting_step)
762- {
763- double intercept_update;
764- if (sample_weight_train.size ()==0 )
765- intercept_update=v*neg_gradient_current.mean ();
766- else
767- intercept_update=v*(neg_gradient_current.array ()*sample_weight_train.array ()).sum ()/sample_weight_train.array ().sum ();
768- linear_predictor_update=VectorXd::Constant (neg_gradient_current.size (),intercept_update);
769- linear_predictor_update_validation=VectorXd::Constant (y_validation.size (),intercept_update);
770- update_linear_predictor_and_predictions ();
771- update_gradient_and_errors ();
772- calculate_and_validate_validation_error (boosting_step);
773- if (!abort_boosting)
774- {
775- intercept+=intercept_update;
776- intercept_steps[boosting_step]=intercept;
777- }
778- }
779-
780764void APLRRegressor::update_linear_predictor_and_predictions ()
781765{
782766 linear_predictor_current+=linear_predictor_update;
@@ -1219,13 +1203,6 @@ void APLRRegressor::print_summary_after_boosting_step(size_t boosting_step)
12191203
12201204void APLRRegressor::update_coefficients_for_all_steps ()
12211205{
1222- for (size_t j = 0 ; j < m; ++j)
1223- {
1224- bool fill_down_coefficient_steps{j>0 && is_approximately_zero (intercept_steps[j]) && !is_approximately_zero (intercept_steps[j-1 ])};
1225- if (fill_down_coefficient_steps)
1226- intercept_steps[j]=intercept_steps[j-1 ];
1227- }
1228-
12291206 for (size_t i = 0 ; i < terms.size (); ++i)
12301207 {
12311208 for (size_t j = 0 ; j < m; ++j)
@@ -1249,7 +1226,6 @@ void APLRRegressor::find_optimal_m_and_update_model_accordingly()
12491226{
12501227 size_t best_boosting_step_index;
12511228 validation_error_steps.minCoeff (&best_boosting_step_index);
1252- intercept=intercept_steps[best_boosting_step_index];
12531229 for (size_t i = 0 ; i < terms.size (); ++i)
12541230 {
12551231 terms[i].coefficient = terms[i].coefficient_steps [best_boosting_step_index];
@@ -1274,10 +1250,6 @@ void APLRRegressor::revert_scaling_if_using_log_link_function()
12741250 {
12751251 y_train/=scaling_factor_for_log_link_function;
12761252 intercept+=std::log (1 /scaling_factor_for_log_link_function);
1277- for (Eigen::Index i = 0 ; i < intercept_steps.size (); ++i)
1278- {
1279- intercept_steps[i]+=std::log (1 /scaling_factor_for_log_link_function);
1280- }
12811253 }
12821254}
12831255
@@ -1451,13 +1423,13 @@ VectorXd APLRRegressor::predict(const MatrixXd &X, bool cap_predictions_to_minma
14511423
14521424VectorXd APLRRegressor::calculate_linear_predictor (const MatrixXd &X)
14531425{
1454- VectorXd predictions {VectorXd::Constant (X.rows (),intercept)};
1426+ VectorXd linear_predictor {VectorXd::Constant (X.rows (),intercept)};
14551427 for (size_t i = 0 ; i < terms.size (); ++i)
14561428 {
14571429 VectorXd contrib{terms[i].calculate_contribution_to_linear_predictor (X)};
1458- predictions +=contrib;
1430+ linear_predictor +=contrib;
14591431 }
1460- return predictions ;
1432+ return linear_predictor ;
14611433}
14621434
14631435void APLRRegressor::cap_predictions_to_minmax_in_training (VectorXd &predictions)
@@ -1531,11 +1503,6 @@ double APLRRegressor::get_intercept()
15311503 return intercept;
15321504}
15331505
1534- VectorXd APLRRegressor::get_intercept_steps ()
1535- {
1536- return intercept_steps;
1537- }
1538-
15391506size_t APLRRegressor::get_optimal_m ()
15401507{
15411508 return m_optimal;
@@ -1549,10 +1516,4 @@ std::string APLRRegressor::get_validation_tuning_metric()
15491516std::vector<size_t > APLRRegressor::get_validation_indexes ()
15501517{
15511518 return validation_indexes;
1552- }
1553-
1554- void APLRRegressor::throw_error_if_m_is_invalid ()
1555- {
1556- if (m<1 )
1557- throw std::runtime_error (" The maximum number of boosting steps, m, must be at least 1." );
15581519}
0 commit comments