@@ -86,7 +86,6 @@ class APLRRegressor
8686 void name_terms (const MatrixXd &X, const std::vector<std::string> &X_names);
8787 void calculate_feature_importance_on_validation_set ();
8888 void find_min_and_max_training_predictions_or_responses ();
89- void calculate_validation_group_mse ();
9089 void cleanup_after_fit ();
9190 void validate_that_model_can_be_used (const MatrixXd &X);
9291 void throw_error_if_family_does_not_exist ();
@@ -134,16 +133,15 @@ class APLRRegressor
134133 double tweedie_power;
135134 double min_training_prediction_or_response;
136135 double max_training_prediction_or_response;
137- double validation_group_mse;
138- size_t group_size_for_validation_group_mse;
139136 std::vector<size_t > validation_indexes;
137+ std::string validation_tuning_metric;
140138
141139 // Methods
142140 APLRRegressor (size_t m=1000 ,double v=0.1 ,uint_fast32_t random_state=std::numeric_limits<uint_fast32_t >::lowest(),std::string family=" gaussian" ,
143141 std::string link_function=" identity" , size_t n_jobs=0 , double validation_ratio=0.2 ,double intercept=NAN_DOUBLE,
144142 size_t reserved_terms_times_num_x=100 , size_t bins=300 ,size_t verbosity=0 ,size_t max_interaction_level=1 ,size_t max_interactions=100000 ,
145143 size_t min_observations_in_split=20 , size_t ineligible_boosting_steps_added=10 , size_t max_eligible_terms=5 ,double tweedie_power=1.5 ,
146- size_t group_size_for_validation_group_mse= 100 );
144+ std::string validation_tuning_metric= " default " );
147145 APLRRegressor (const APLRRegressor &other);
148146 ~APLRRegressor ();
149147 void fit (const MatrixXd &X,const VectorXd &y,const VectorXd &sample_weight=VectorXd(0 ),const std::vector<std::string> &X_names={},const std::vector<size_t > &validation_set_indexes={},
@@ -161,22 +159,22 @@ class APLRRegressor
161159 double get_intercept ();
162160 VectorXd get_intercept_steps ();
163161 size_t get_m ();
164- double get_validation_group_mse ();
162+ std::string get_validation_tuning_metric ();
165163 std::vector<size_t > get_validation_indexes ();
166164};
167165
168166// Regular constructor
169167APLRRegressor::APLRRegressor (size_t m,double v,uint_fast32_t random_state,std::string family,std::string link_function,size_t n_jobs,
170168 double validation_ratio,double intercept,size_t reserved_terms_times_num_x,size_t bins,size_t verbosity,size_t max_interaction_level,
171169 size_t max_interactions,size_t min_observations_in_split,size_t ineligible_boosting_steps_added,size_t max_eligible_terms,double tweedie_power,
172- size_t group_size_for_validation_group_mse ):
170+ std::string validation_tuning_metric ):
173171 reserved_terms_times_num_x{reserved_terms_times_num_x},intercept{intercept},m{m},v{v},
174172 family{family},link_function{link_function},validation_ratio{validation_ratio},n_jobs{n_jobs},random_state{random_state},
175173 bins{bins},verbosity{verbosity},max_interaction_level{max_interaction_level},
176174 intercept_steps{VectorXd (0 )},max_interactions{max_interactions},interactions_eligible{0 },validation_error_steps{VectorXd (0 )},
177175 min_observations_in_split{min_observations_in_split},ineligible_boosting_steps_added{ineligible_boosting_steps_added},
178176 max_eligible_terms{max_eligible_terms},number_of_base_terms{0 },tweedie_power{tweedie_power},min_training_prediction_or_response{NAN_DOUBLE},
179- max_training_prediction_or_response{NAN_DOUBLE},validation_group_mse{NAN_DOUBLE},group_size_for_validation_group_mse{group_size_for_validation_group_mse },
177+ max_training_prediction_or_response{NAN_DOUBLE}, validation_tuning_metric{validation_tuning_metric },
180178 validation_indexes{std::vector<size_t >(0 )}
181179{
182180}
@@ -192,8 +190,8 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other):
192190 min_observations_in_split{other.min_observations_in_split },ineligible_boosting_steps_added{other.ineligible_boosting_steps_added },
193191 max_eligible_terms{other.max_eligible_terms },number_of_base_terms{other.number_of_base_terms },
194192 feature_importance{other.feature_importance },tweedie_power{other.tweedie_power },min_training_prediction_or_response{other.min_training_prediction_or_response },
195- max_training_prediction_or_response{other.max_training_prediction_or_response },validation_group_mse {other.validation_group_mse },
196- group_size_for_validation_group_mse{other. group_size_for_validation_group_mse }, validation_indexes{other.validation_indexes }
193+ max_training_prediction_or_response{other.max_training_prediction_or_response },validation_tuning_metric {other.validation_tuning_metric },
194+ validation_indexes{other.validation_indexes }
197195{
198196}
199197
@@ -225,7 +223,6 @@ void APLRRegressor::fit(const MatrixXd &X,const VectorXd &y,const VectorXd &samp
225223 name_terms (X, X_names);
226224 calculate_feature_importance_on_validation_set ();
227225 find_min_and_max_training_predictions_or_responses ();
228- calculate_validation_group_mse ();
229226 cleanup_after_fit ();
230227}
231228
@@ -447,7 +444,6 @@ void APLRRegressor::scale_training_observations_if_using_log_link_function()
447444 {
448445 scaling_factor_for_log_link_function=1 /inverse_scaling_factor;
449446 y_train*=scaling_factor_for_log_link_function;
450- y_validation*=scaling_factor_for_log_link_function;
451447 }
452448 else
453449 scaling_factor_for_log_link_function=1.0 ;
@@ -966,7 +962,50 @@ void APLRRegressor::add_new_term(size_t boosting_step)
966962
967963void APLRRegressor::calculate_and_validate_validation_error (size_t boosting_step)
968964{
969- validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
965+ VectorXd rescaled_predictions_current_validation (0 );
966+ bool link_function_is_log{link_function==" log" };
967+ if (link_function_is_log)
968+ {
969+ rescaled_predictions_current_validation = predictions_current_validation / scaling_factor_for_log_link_function;
970+ }
971+
972+ bool using_default{validation_tuning_metric==" default" };
973+ bool using_mse{validation_tuning_metric==" mse" };
974+ bool using_mae{validation_tuning_metric==" mae" };
975+ bool using_rankability{validation_tuning_metric==" rankability" };
976+ if (using_default)
977+ {
978+ if (link_function_is_log)
979+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,rescaled_predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
980+ else
981+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions_current_validation,sample_weight_validation,family,tweedie_power),sample_weight_validation);
982+ }
983+ else if (using_mse)
984+ {
985+ if (link_function_is_log)
986+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
987+ else
988+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
989+ }
990+ else if (using_mae)
991+ {
992+ if (link_function_is_log)
993+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_absolute_errors (y_validation,rescaled_predictions_current_validation,sample_weight_validation),sample_weight_validation);
994+ else
995+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_absolute_errors (y_validation,predictions_current_validation,sample_weight_validation),sample_weight_validation);
996+ }
997+ else if (using_rankability)
998+ {
999+ if (link_function_is_log)
1000+ validation_error_steps[boosting_step]=-calculate_rankability (y_validation,rescaled_predictions_current_validation,sample_weight_validation,random_state);
1001+ else
1002+ validation_error_steps[boosting_step]=-calculate_rankability (y_validation,predictions_current_validation,sample_weight_validation,random_state);
1003+ }
1004+ else
1005+ {
1006+ throw std::runtime_error (validation_tuning_metric + " is an invalid validation_tuning_metric." );
1007+ }
1008+
9701009 bool validation_error_is_invalid{std::isinf (validation_error_steps[boosting_step])};
9711010 if (validation_error_is_invalid)
9721011 {
@@ -1080,7 +1119,6 @@ void APLRRegressor::revert_scaling_if_using_log_link_function()
10801119 if (link_function==" log" )
10811120 {
10821121 y_train/=scaling_factor_for_log_link_function;
1083- y_validation/=scaling_factor_for_log_link_function;
10841122 intercept+=std::log (1 /scaling_factor_for_log_link_function);
10851123 for (size_t i = 0 ; i < static_cast <size_t >(intercept_steps.size ()); ++i)
10861124 {
@@ -1196,17 +1234,6 @@ void APLRRegressor::find_min_and_max_training_predictions_or_responses()
11961234 max_training_prediction_or_response=std::min (training_predictions.maxCoeff (), y_train.maxCoeff ());
11971235}
11981236
1199- void APLRRegressor::calculate_validation_group_mse ()
1200- {
1201- VectorXd validation_predictions{predict (X_validation,false )};
1202- VectorXi validation_predictions_sorted_index{sort_indexes_ascending (validation_predictions)};
1203- VectorXd y_validation_centered{calculate_rolling_centered_mean (y_validation,validation_predictions_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
1204- VectorXd validation_predictions_centered{calculate_rolling_centered_mean (validation_predictions,validation_predictions_sorted_index,group_size_for_validation_group_mse,sample_weight_validation)};
1205-
1206- VectorXd squared_residuals{(y_validation_centered-validation_predictions_centered).array ().pow (2 )};
1207- validation_group_mse = squared_residuals.mean ();
1208- }
1209-
12101237void APLRRegressor::validate_that_model_can_be_used (const MatrixXd &X)
12111238{
12121239 if (std::isnan (intercept) || number_of_base_terms==0 ) throw std::runtime_error (" Model must be trained before predict() can be run." );
@@ -1354,9 +1381,9 @@ size_t APLRRegressor::get_m()
13541381 return m;
13551382}
13561383
1357- double APLRRegressor::get_validation_group_mse ()
1384+ std::string APLRRegressor::get_validation_tuning_metric ()
13581385{
1359- return validation_group_mse ;
1386+ return validation_tuning_metric ;
13601387}
13611388
13621389std::vector<size_t > APLRRegressor::get_validation_indexes ()
0 commit comments