@@ -88,7 +88,7 @@ class APLRRegressor
8888 void prune_terms (size_t boosting_step);
8989 void update_coefficient_steps (size_t boosting_step);
9090 void calculate_and_validate_validation_error (size_t boosting_step);
91- void calculate_validation_error (size_t boosting_step, const VectorXd &predictions);
91+ double calculate_validation_error (const VectorXd &predictions);
9292 void update_term_eligibility ();
9393 void print_summary_after_boosting_step (size_t boosting_step);
9494 void print_final_summary ();
@@ -99,6 +99,7 @@ class APLRRegressor
9999 void calculate_feature_importance_on_validation_set ();
100100 void find_min_and_max_training_predictions_or_responses ();
101101 void cleanup_after_fit ();
102+ void check_term_integrity ();
102103 void validate_that_model_can_be_used (const MatrixXd &X);
103104 void throw_error_if_loss_function_does_not_exist ();
104105 void throw_error_if_link_function_does_not_exist ();
@@ -272,6 +273,7 @@ void APLRRegressor::fit(const MatrixXd &X, const VectorXd &y, const VectorXd &sa
272273 calculate_feature_importance_on_validation_set ();
273274 find_min_and_max_training_predictions_or_responses ();
274275 cleanup_after_fit ();
276+ check_term_integrity ();
275277}
276278
277279void APLRRegressor::throw_error_if_loss_function_does_not_exist ()
@@ -1316,7 +1318,7 @@ void APLRRegressor::update_coefficient_steps(size_t boosting_step)
13161318
13171319void APLRRegressor::calculate_and_validate_validation_error (size_t boosting_step)
13181320{
1319- calculate_validation_error ( boosting_step, predictions_current_validation);
1321+ validation_error_steps[ boosting_step] = calculate_validation_error ( predictions_current_validation);
13201322 bool validation_error_is_invalid{!std::isfinite (validation_error_steps[boosting_step])};
13211323 if (validation_error_is_invalid)
13221324 {
@@ -1326,15 +1328,15 @@ void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step
13261328 }
13271329}
13281330
1329- void APLRRegressor::calculate_validation_error (size_t boosting_step, const VectorXd &predictions)
1331+ double APLRRegressor::calculate_validation_error (const VectorXd &predictions)
13301332{
13311333 if (validation_tuning_metric == " default" )
13321334 {
13331335 if (loss_function == " custom_function" )
13341336 {
13351337 try
13361338 {
1337- validation_error_steps[boosting_step] = calculate_custom_loss_function (y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
1339+ return calculate_custom_loss_function (y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
13381340 }
13391341 catch (const std::exception &e)
13401342 {
@@ -1343,28 +1345,28 @@ void APLRRegressor::calculate_validation_error(size_t boosting_step, const Vecto
13431345 }
13441346 }
13451347 else
1346- validation_error_steps[boosting_step] = calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, loss_function, dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
1348+ return calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, loss_function, dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
13471349 }
13481350 else if (validation_tuning_metric == " mse" )
1349- validation_error_steps[boosting_step] = calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, MSE_LOSS_FUNCTION), sample_weight_validation);
1351+ return calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, MSE_LOSS_FUNCTION), sample_weight_validation);
13501352 else if (validation_tuning_metric == " mae" )
1351- validation_error_steps[boosting_step] = calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, " mae" ), sample_weight_validation);
1353+ return calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, " mae" ), sample_weight_validation);
13521354 else if (validation_tuning_metric == " negative_gini" )
1353- validation_error_steps[boosting_step] = -calculate_gini (y_validation, predictions, sample_weight_validation);
1355+ return -calculate_gini (y_validation, predictions, sample_weight_validation);
13541356 else if (validation_tuning_metric == " rankability" )
1355- validation_error_steps[boosting_step] = -calculate_rankability (y_validation, predictions, sample_weight_validation, random_state);
1357+ return -calculate_rankability (y_validation, predictions, sample_weight_validation, random_state);
13561358 else if (validation_tuning_metric == " group_mse" )
13571359 {
13581360 bool group_is_not_provided{group_validation.rows () == 0 };
13591361 if (group_is_not_provided)
13601362 throw std::runtime_error (" When validation_tuning_metric is group_mse then the group argument in fit() must be provided." );
1361- validation_error_steps[boosting_step] = calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, " group_mse" , dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
1363+ return calculate_mean_error (calculate_errors (y_validation, predictions, sample_weight_validation, " group_mse" , dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
13621364 }
13631365 else if (validation_tuning_metric == " custom_function" )
13641366 {
13651367 try
13661368 {
1367- validation_error_steps[boosting_step] = calculate_custom_validation_error_function (y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
1369+ return calculate_custom_validation_error_function (y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
13681370 }
13691371 catch (const std::exception &e)
13701372 {
@@ -1667,6 +1669,37 @@ void APLRRegressor::cleanup_after_fit()
16671669 other_data_validation.resize (0 , 0 );
16681670}
16691671
1672+ // Post-fit sanity check (called at the end of fit()): validates every interaction
1672+ // (given_term) that shares the same base_term as its parent term. Throws
1672+ // std::runtime_error with a "Bug: ..." message on the first violation found,
1672+ // signalling an internal inconsistency rather than a user error.
1672+ void APLRRegressor::check_term_integrity ()
1673+ {
1674+ for (auto &term : terms)
1675+ {
1676+ for (auto &given_term : term.given_terms )
1677+ {
1678+ // Checks below only apply when the interaction acts on the same
1678+ // underlying predictor as its parent term.
1678+ bool same_base_term{term.base_term == given_term.base_term };
1679+ if (same_base_term)
1680+ {
1681+ // A NaN/inf split_point means the interaction has no split point at all.
1681+ bool given_term_has_no_split_point{!std::isfinite (given_term.split_point )};
1682+ // NOTE(review): an interaction on the same base term is apparently
1682+ // required to point in the OPPOSITE direction of its parent — the
1682+ // code throws when the directions are equal. TODO confirm intent.
1683+ bool given_term_has_the_same_direction_right{term.direction_right == given_term.direction_right };
1684+ bool given_term_has_incorrect_split_point;
1685+ if (term.direction_right )
1686+ {
1686+ // Parent goes right: the interaction's split point must lie strictly
1686+ // to the right of the parent's. islessequal/isgreaterequal are
1686+ // NaN-quiet (return false on NaN), so the finite-ness violation
1686+ // above is reported first rather than masked here.
1687+ given_term_has_incorrect_split_point = std::islessequal (given_term.split_point , term.split_point );
1688+ }
1689+ else
1690+ {
1690+ // Parent goes left: split point must lie strictly to the left.
1691+ given_term_has_incorrect_split_point = std::isgreaterequal (given_term.split_point , term.split_point );
1692+ }
1693+ if (given_term_has_no_split_point)
1694+ throw std::runtime_error (" Bug: Interaction in term " + term.name + " has no split point." );
1695+ if (given_term_has_the_same_direction_right)
1696+ throw std::runtime_error (" Bug: Interaction in term " + term.name + " has an incorrect direction_right." );
1697+ if (given_term_has_incorrect_split_point)
1698+ throw std::runtime_error (" Bug: Interaction in term " + term.name + " has an incorrect split_point." );
1699+ }
1700+ }
1701+ }
1702+ }
1702+
16701703VectorXd APLRRegressor::predict (const MatrixXd &X, bool cap_predictions_to_minmax_in_training)
16711704{
16721705 validate_that_model_can_be_used (X);
0 commit comments