@@ -50,6 +50,7 @@ class APLRRegressor
5050 std::set<int > unique_groups_train;
5151 std::set<int > unique_groups_validation;
5252 std::vector<int > interaction_constraints;
53+ bool update_intercept_only_once;
5354
5455 void validate_input_to_fit (const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight,const std::vector<std::string> &X_names,
5556 const std::vector<size_t > &validation_set_indexes, const std::vector<size_t > &prioritized_predictors_indexes,
@@ -109,6 +110,7 @@ class APLRRegressor
109110 void revert_scaling_if_using_log_link_function ();
110111 void cap_predictions_to_minmax_in_training (VectorXd &predictions);
111112 std::string compute_raw_base_term_name (const Term &term, const std::string &X_name);
113+ void throw_error_if_m_is_invalid ();
112114
113115public:
114116 double intercept;
@@ -210,6 +212,7 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
210212 throw_error_if_loss_function_does_not_exist ();
211213 throw_error_if_link_function_does_not_exist ();
212214 throw_error_if_dispersion_parameter_is_invalid ();
215+ throw_error_if_m_is_invalid ();
213216 validate_input_to_fit (X,y,sample_weight,X_names,validation_set_indexes,prioritized_predictors_indexes,monotonic_constraints,group,
214217 interaction_constraints);
215218 define_training_and_validation_sets (X,y,sample_weight,validation_set_indexes,group);
@@ -543,8 +546,21 @@ void APLRRegressor::initialize(const std::vector<size_t> &prioritized_predictors
543546 terms.clear ();
544547 terms.reserve (X_train.cols ()*reserved_terms_times_num_x);
545548
546- intercept=0 ;
547- intercept_steps=VectorXd::Constant (m,0 );
549+ if (loss_function == " group_mse" )
550+ {
551+ update_intercept_only_once = true ;
552+ if (sample_weight_train.size ()==0 )
553+ intercept = y_train.mean ();
554+ else
555+ intercept = (y_train.array ()*sample_weight_train.array ()).sum ()/sample_weight_train.array ().sum ();
556+ }
557+ else
558+ {
559+ update_intercept_only_once = false ;
560+ intercept=0 ;
561+ }
562+ intercept_steps=VectorXd::Constant (m, intercept);
563+
548564
549565 terms_eligible_current.reserve (X_train.cols ()*reserved_terms_times_num_x);
550566 size_t X_train_cols{static_cast <size_t >(X_train.cols ())};
@@ -710,7 +726,8 @@ void APLRRegressor::execute_boosting_steps()
710726
711727void APLRRegressor::execute_boosting_step (size_t boosting_step)
712728{
713- update_intercept (boosting_step);
729+ if (!update_intercept_only_once)
730+ update_intercept (boosting_step);
714731 bool prioritize_predictors{!abort_boosting && prioritized_predictors_indexes.size ()>0 };
715732 if (prioritize_predictors)
716733 {
@@ -1532,4 +1549,10 @@ std::string APLRRegressor::get_validation_tuning_metric()
15321549std::vector<size_t > APLRRegressor::get_validation_indexes ()
15331550{
15341551 return validation_indexes;
1552+ }
1553+
1554+ void APLRRegressor::throw_error_if_m_is_invalid ()
1555+ {
1556+ if (m<1 )
1557+ throw std::runtime_error (" The maximum number of boosting steps, m, must be at least 1." );
15351558}
0 commit comments