Skip to content

Commit fe836a0

Browse files
refactor
1 parent 9a8290e commit fe836a0

File tree

4 files changed

+55
-15
lines changed

4 files changed

+55
-15
lines changed

cpp/APLRRegressor.h

Lines changed: 44 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class APLRRegressor
8888
void prune_terms(size_t boosting_step);
8989
void update_coefficient_steps(size_t boosting_step);
9090
void calculate_and_validate_validation_error(size_t boosting_step);
91-
void calculate_validation_error(size_t boosting_step, const VectorXd &predictions);
91+
double calculate_validation_error(const VectorXd &predictions);
9292
void update_term_eligibility();
9393
void print_summary_after_boosting_step(size_t boosting_step);
9494
void print_final_summary();
@@ -99,6 +99,7 @@ class APLRRegressor
9999
void calculate_feature_importance_on_validation_set();
100100
void find_min_and_max_training_predictions_or_responses();
101101
void cleanup_after_fit();
102+
void check_term_integrity();
102103
void validate_that_model_can_be_used(const MatrixXd &X);
103104
void throw_error_if_loss_function_does_not_exist();
104105
void throw_error_if_link_function_does_not_exist();
@@ -272,6 +273,7 @@ void APLRRegressor::fit(const MatrixXd &X, const VectorXd &y, const VectorXd &sa
272273
calculate_feature_importance_on_validation_set();
273274
find_min_and_max_training_predictions_or_responses();
274275
cleanup_after_fit();
276+
check_term_integrity();
275277
}
276278

277279
void APLRRegressor::throw_error_if_loss_function_does_not_exist()
@@ -1316,7 +1318,7 @@ void APLRRegressor::update_coefficient_steps(size_t boosting_step)
13161318

13171319
void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step)
13181320
{
1319-
calculate_validation_error(boosting_step, predictions_current_validation);
1321+
validation_error_steps[boosting_step] = calculate_validation_error(predictions_current_validation);
13201322
bool validation_error_is_invalid{!std::isfinite(validation_error_steps[boosting_step])};
13211323
if (validation_error_is_invalid)
13221324
{
@@ -1326,15 +1328,15 @@ void APLRRegressor::calculate_and_validate_validation_error(size_t boosting_step
13261328
}
13271329
}
13281330

1329-
void APLRRegressor::calculate_validation_error(size_t boosting_step, const VectorXd &predictions)
1331+
double APLRRegressor::calculate_validation_error(const VectorXd &predictions)
13301332
{
13311333
if (validation_tuning_metric == "default")
13321334
{
13331335
if (loss_function == "custom_function")
13341336
{
13351337
try
13361338
{
1337-
validation_error_steps[boosting_step] = calculate_custom_loss_function(y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
1339+
return calculate_custom_loss_function(y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
13381340
}
13391341
catch (const std::exception &e)
13401342
{
@@ -1343,28 +1345,28 @@ void APLRRegressor::calculate_validation_error(size_t boosting_step, const Vecto
13431345
}
13441346
}
13451347
else
1346-
validation_error_steps[boosting_step] = calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, loss_function, dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
1348+
return calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, loss_function, dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
13471349
}
13481350
else if (validation_tuning_metric == "mse")
1349-
validation_error_steps[boosting_step] = calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, MSE_LOSS_FUNCTION), sample_weight_validation);
1351+
return calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, MSE_LOSS_FUNCTION), sample_weight_validation);
13501352
else if (validation_tuning_metric == "mae")
1351-
validation_error_steps[boosting_step] = calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, "mae"), sample_weight_validation);
1353+
return calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, "mae"), sample_weight_validation);
13521354
else if (validation_tuning_metric == "negative_gini")
1353-
validation_error_steps[boosting_step] = -calculate_gini(y_validation, predictions, sample_weight_validation);
1355+
return -calculate_gini(y_validation, predictions, sample_weight_validation);
13541356
else if (validation_tuning_metric == "rankability")
1355-
validation_error_steps[boosting_step] = -calculate_rankability(y_validation, predictions, sample_weight_validation, random_state);
1357+
return -calculate_rankability(y_validation, predictions, sample_weight_validation, random_state);
13561358
else if (validation_tuning_metric == "group_mse")
13571359
{
13581360
bool group_is_not_provided{group_validation.rows() == 0};
13591361
if (group_is_not_provided)
13601362
throw std::runtime_error("When validation_tuning_metric is group_mse then the group argument in fit() must be provided.");
1361-
validation_error_steps[boosting_step] = calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, "group_mse", dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
1363+
return calculate_mean_error(calculate_errors(y_validation, predictions, sample_weight_validation, "group_mse", dispersion_parameter, group_validation, unique_groups_validation, quantile), sample_weight_validation);
13621364
}
13631365
else if (validation_tuning_metric == "custom_function")
13641366
{
13651367
try
13661368
{
1367-
validation_error_steps[boosting_step] = calculate_custom_validation_error_function(y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
1369+
return calculate_custom_validation_error_function(y_validation, predictions, sample_weight_validation, group_validation, other_data_validation);
13681370
}
13691371
catch (const std::exception &e)
13701372
{
@@ -1667,6 +1669,37 @@ void APLRRegressor::cleanup_after_fit()
16671669
other_data_validation.resize(0, 0);
16681670
}
16691671

1672+
void APLRRegressor::check_term_integrity()
1673+
{
1674+
for (auto &term : terms)
1675+
{
1676+
for (auto &given_term : term.given_terms)
1677+
{
1678+
bool same_base_term{term.base_term == given_term.base_term};
1679+
if (same_base_term)
1680+
{
1681+
bool given_term_has_no_split_point{!std::isfinite(given_term.split_point)};
1682+
bool given_term_has_the_same_direction_right{term.direction_right == given_term.direction_right};
1683+
bool given_term_has_incorrect_split_point;
1684+
if (term.direction_right)
1685+
{
1686+
given_term_has_incorrect_split_point = std::islessequal(given_term.split_point, term.split_point);
1687+
}
1688+
else
1689+
{
1690+
given_term_has_incorrect_split_point = std::isgreaterequal(given_term.split_point, term.split_point);
1691+
}
1692+
if (given_term_has_no_split_point)
1693+
throw std::runtime_error("Bug: Interaction in term " + term.name + " has no split point.");
1694+
if (given_term_has_the_same_direction_right)
1695+
throw std::runtime_error("Bug: Interaction in term " + term.name + " has an incorrect direction_right.");
1696+
if (given_term_has_incorrect_split_point)
1697+
throw std::runtime_error("Bug: Interaction in term " + term.name + " has an incorrect split_point.");
1698+
}
1699+
}
1700+
}
1701+
}
1702+
16701703
VectorXd APLRRegressor::predict(const MatrixXd &X, bool cap_predictions_to_minmax_in_training)
16711704
{
16721705
validate_that_model_can_be_used(X);

cpp/term.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,9 @@ void Term::prune_given_terms()
652652
for (size_t i = 0; i < given_terms.size(); ++i)
653653
{
654654
bool keep_given_term{true};
655-
bool removing_given_term_with_same_base_term_and_direction{base_term == given_terms[i].base_term && direction_right == given_terms[i].direction_right};
656-
bool removing_linear_given_term{base_term == given_terms[i].base_term && std::isfinite(split_point) && !std::isfinite(given_terms[i].split_point)};
655+
bool base_term_is_equal{base_term == given_terms[i].base_term};
656+
bool removing_given_term_with_same_base_term_and_direction{base_term_is_equal && direction_right == given_terms[i].direction_right};
657+
bool removing_linear_given_term{base_term_is_equal && !std::isfinite(given_terms[i].split_point)};
657658
if (removing_given_term_with_same_base_term_and_direction)
658659
{
659660
keep_given_term = false;

examples/train_aplr_regression_cross_validation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@
9191
# Coefficient shape for the first predictor. Will be empty if the first predictor is not used as a main effect in the model
9292
coefficient_shape = best_model.get_coefficient_shape_function(predictor_index=0)
9393
coefficient_shape = pd.DataFrame(
94-
{"predictor_value": coefficient_shape.keys(), "coefficient": coefficient_shape.values()}
94+
{
95+
"predictor_value": coefficient_shape.keys(),
96+
"coefficient": coefficient_shape.values(),
97+
}
9598
)
9699

97100

examples/train_aplr_regression_validation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@
102102
# Coefficient shape for the third predictor. Will be empty if the third predictor is not used as a main effect in the model
103103
coefficient_shape = best_model.get_coefficient_shape_function(predictor_index=2)
104104
coefficient_shape = pd.DataFrame(
105-
{"predictor_value": coefficient_shape.keys(), "coefficient": coefficient_shape.values()}
105+
{
106+
"predictor_value": coefficient_shape.keys(),
107+
"coefficient": coefficient_shape.values(),
108+
}
106109
)
107110

108111

0 commit comments

Comments
 (0)