@@ -136,6 +136,7 @@ class APLRRegressor
136136 double max_training_prediction_or_response;
137137 double validation_group_mse;
138138 size_t group_size_for_validation_group_mse;
139+ std::vector<size_t > validation_indexes;
139140
140141 // Methods
141142 APLRRegressor (size_t m=1000 ,double v=0.1 ,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t >::lowest(),std::string family=" gaussian" ,
@@ -161,6 +162,7 @@ class APLRRegressor
161162 VectorXd get_intercept_steps ();
162163 size_t get_m ();
163164 double get_validation_group_mse ();
165+ std::vector<size_t > get_validation_indexes ();
164166};
165167
166168// Regular constructor
@@ -174,7 +176,8 @@ APLRRegressor::APLRRegressor(size_t m,double v,uint_fast32_t random_state,std::s
174176 intercept_steps{VectorXd (0 )},max_interactions{max_interactions},interactions_eligible{0 },validation_error_steps{VectorXd (0 )},
175177 min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
176178 max_eligible_terms{max_eligible_terms},number_of_base_terms{0 },tweedie_power{tweedie_power},min_training_prediction_or_response{NAN_DOUBLE},
177- max_training_prediction_or_response{NAN_DOUBLE},validation_group_mse{NAN_DOUBLE},group_size_for_validation_group_mse{group_size_for_validation_group_mse}
179+ max_training_prediction_or_response{NAN_DOUBLE},validation_group_mse{NAN_DOUBLE},group_size_for_validation_group_mse{group_size_for_validation_group_mse},
180+ validation_indexes{std::vector<size_t >(0 )}
178181{
179182}
180183
@@ -190,7 +193,7 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
190193 max_eligible_terms{other.max_eligible_terms },number_of_base_terms{other.number_of_base_terms },
191194 feature_importance{other.feature_importance },tweedie_power{other.tweedie_power },min_training_prediction_or_response{other.min_training_prediction_or_response },
192195 max_training_prediction_or_response{other.max_training_prediction_or_response },validation_group_mse{other.validation_group_mse },
193- group_size_for_validation_group_mse{other.group_size_for_validation_group_mse }
196+ group_size_for_validation_group_mse{other.group_size_for_validation_group_mse },validation_indexes{other. validation_indexes }
194197{
195198}
196199
@@ -363,20 +366,20 @@ void APLRRegressor::define_training_and_validation_sets(const MatrixXd &X,const
363366{
364367 size_t y_size{static_cast <size_t >(y.size ())};
365368 std::vector<size_t > train_indexes;
366- std::vector<size_t > validation_indexes;
367369 bool use_validation_set_indexes{validation_set_indexes.size ()>0 };
368370 if (use_validation_set_indexes)
369371 {
370372 std::vector<size_t > all_indexes (y_size);
371373 std::iota (std::begin (all_indexes),std::end (all_indexes),0 );
372374 validation_indexes=validation_set_indexes;
373375 train_indexes.reserve (y_size-validation_indexes.size ());
374- std::remove_copy_if (all_indexes.begin (),all_indexes.end (),std::back_inserter (train_indexes),[&validation_indexes ](const size_t &arg)
376+ std::remove_copy_if (all_indexes.begin (),all_indexes.end (),std::back_inserter (train_indexes),[this ](const size_t &arg)
375377 { return (std::find (validation_indexes.begin (),validation_indexes.end (),arg) != validation_indexes.end ());});
376378 }
377379 else
378380 {
379381 train_indexes.reserve (y_size);
382+ validation_indexes = std::vector<size_t >(0 );
380383 validation_indexes.reserve (y_size);
381384 std::mt19937 mersenne{random_state};
382385 std::uniform_real_distribution<double > distribution (0.0 ,1.0 );
@@ -1354,4 +1357,9 @@ size_t APLRRegressor::get_m()
13541357double APLRRegressor::get_validation_group_mse ()
13551358{
13561359 return validation_group_mse;
1360+ }
1361+
1362+ std::vector<size_t > APLRRegressor::get_validation_indexes ()
1363+ {
1364+ return validation_indexes;
13571365}
0 commit comments