Skip to content

Commit 09609b3

Browse files
bugfix
1 parent 783c70e commit 09609b3

File tree

5 files changed

+43
-23
lines changed

5 files changed

+43
-23
lines changed

cpp/APLRRegressor.h

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1371,6 +1371,9 @@ void APLRRegressor::determine_interactions_to_consider(const std::vector<size_t>
13711371
if (model_term_without_given_terms_can_be_a_given_term)
13721372
model_term_with_added_given_term.given_terms.push_back(model_term_without_given_terms);
13731373
add_necessary_given_terms_to_interaction(interaction, model_term_with_added_given_term);
1374+
bool interaction_only_uses_one_base_term{interaction.term_uses_just_these_predictors({interaction.base_term})};
1375+
if (interaction_only_uses_one_base_term)
1376+
continue;
13741377
if (interaction_constraints_provided)
13751378
{
13761379
bool interaction_violates_constraints{true};
@@ -2457,7 +2460,7 @@ std::map<double, double> APLRRegressor::get_main_effect_shape(size_t predictor_i
24572460
return main_effect_shape;
24582461

24592462
std::vector<double> split_points;
2460-
size_t max_potential_split_points{relevant_term_indexes.size() * 3 + 2};
2463+
size_t max_potential_split_points{(relevant_term_indexes.size() * 3 + 2) * 3};
24612464
split_points.reserve(max_potential_split_points);
24622465
for (auto &relevant_term_index : relevant_term_indexes)
24632466
{
@@ -2478,6 +2481,22 @@ std::map<double, double> APLRRegressor::get_main_effect_shape(size_t predictor_i
24782481
split_points.push_back(min_predictor_values_in_training[predictor_index]);
24792482
split_points.push_back(max_predictor_values_in_training[predictor_index]);
24802483
split_points = remove_duplicate_elements_from_vector(split_points);
2484+
2485+
VectorXd split_point_increments{VectorXd(split_points.size() - 1)};
2486+
for (Eigen::Index i = 0; i < split_point_increments.size(); ++i)
2487+
{
2488+
split_point_increments[i] = split_points[i + 1] - split_points[i];
2489+
}
2490+
double minimum_split_point_increment{split_point_increments.minCoeff()};
2491+
double increment_around_split_points{minimum_split_point_increment / DIVISOR_IN_GET_MAIN_EFFECT_SHAPE_FUNCTION};
2492+
2493+
size_t num_split_points_before_small_increments{split_points.size()};
2494+
for (size_t i = 0; i < num_split_points_before_small_increments; ++i)
2495+
{
2496+
split_points.push_back(split_points[i] - increment_around_split_points);
2497+
split_points.push_back(split_points[i] + increment_around_split_points);
2498+
}
2499+
split_points = remove_duplicate_elements_from_vector(split_points);
24812500
split_points.shrink_to_fit();
24822501

24832502
MatrixXd X{MatrixXd::Constant(split_points.size(), number_of_base_terms, 0)};

cpp/constants.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ const double NAN_DOUBLE{std::numeric_limits<double>::quiet_NaN()};
55
const int MAX_ABS_EXPONENT_TO_APPLY_ON_LINEAR_PREDICTOR_IN_LOGIT_MODEL{std::min(16, std::numeric_limits<double>::max_exponent10)};
66
const std::string MSE_LOSS_FUNCTION{"mse"};
77
const size_t MIN_CATEGORIES_IN_CLASSIFIER{2};
8+
const double DIVISOR_IN_GET_MAIN_EFFECT_SHAPE_FUNCTION{1000.0};
89
const Eigen::Index MIN_OBSERATIONS_IN_A_CV_FOLD{2};

cpp/tests.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class Tests
167167
save_as_csv_file("data/output.csv", predictions);
168168

169169
std::cout << predictions.mean() << "\n\n";
170-
tests.push_back(is_approximately_equal(predictions.mean(), 18.534016846656947));
170+
tests.push_back(is_approximately_equal(predictions.mean(), 19.067710451454566));
171171
}
172172

173173
void test_aplrregressor_cauchy_penalties()
@@ -221,7 +221,7 @@ class Tests
221221
save_as_csv_file("data/output.csv", predictions);
222222

223223
std::cout << predictions.mean() << "\n\n";
224-
tests.push_back(is_approximately_equal(predictions.mean(), 20.146282076477394));
224+
tests.push_back(is_approximately_equal(predictions.mean(), 20.809163574542939));
225225
}
226226

227227
void test_aplrregressor_cauchy_linear_effects_only_first()
@@ -275,7 +275,7 @@ class Tests
275275
save_as_csv_file("data/output.csv", predictions);
276276

277277
std::cout << predictions.mean() << "\n\n";
278-
tests.push_back(is_approximately_equal(predictions.mean(), 17.964887018234787));
278+
tests.push_back(is_approximately_equal(predictions.mean(), 17.965154984786622));
279279
}
280280

281281
void test_aplrregressor_cauchy_group_mse_validation()
@@ -466,7 +466,7 @@ class Tests
466466
save_as_csv_file("data/output.csv", predictions);
467467

468468
std::cout << predictions.mean() << "\n\n";
469-
tests.push_back(is_approximately_equal(predictions.mean(), 20.873594934501561));
469+
tests.push_back(is_approximately_equal(predictions.mean(), 20.979930894644177));
470470
}
471471

472472
void test_aplrregressor_custom_loss_and_validation()
@@ -526,7 +526,7 @@ class Tests
526526
save_as_csv_file("data/output.csv", predictions);
527527

528528
std::cout << predictions.mean() << "\n\n";
529-
tests.push_back(is_approximately_equal(predictions.mean(), 23.91507568241019));
529+
tests.push_back(is_approximately_equal(predictions.mean(), 23.87336747209412));
530530
}
531531

532532
void test_aplrregressor_custom_loss()
@@ -585,7 +585,7 @@ class Tests
585585
save_as_csv_file("data/output.csv", predictions);
586586

587587
std::cout << predictions.mean() << "\n\n";
588-
tests.push_back(is_approximately_equal(predictions.mean(), 23.703500296203778, 0.00001));
588+
tests.push_back(is_approximately_equal(predictions.mean(), 24.301339246925711, 0.00001));
589589
}
590590

591591
void test_aplrregressor_gamma_custom_link()
@@ -913,7 +913,7 @@ class Tests
913913
save_as_csv_file("data/output.csv", predictions);
914914

915915
std::cout << predictions.mean() << "\n\n";
916-
tests.push_back(is_approximately_equal(predictions.mean(), 20.82771158964184));
916+
tests.push_back(is_approximately_equal(predictions.mean(), 20.849747430496922));
917917
}
918918

919919
void test_aplrregressor_group_mse_cycle()
@@ -957,7 +957,7 @@ class Tests
957957
save_as_csv_file("data/output.csv", predictions);
958958

959959
std::cout << predictions.mean() << "\n\n";
960-
tests.push_back(is_approximately_equal(predictions.mean(), 23.526475166355244));
960+
tests.push_back(is_approximately_equal(predictions.mean(), 23.529085584946195));
961961
}
962962

963963
void test_aplrregressor_int_constr()
@@ -1010,7 +1010,7 @@ class Tests
10101010
save_as_csv_file("data/output.csv", predictions);
10111011

10121012
std::cout << predictions.mean() << "\n\n";
1013-
tests.push_back(is_approximately_equal(predictions.mean(), 23.576830262038001));
1013+
tests.push_back(is_approximately_equal(predictions.mean(), 23.657546542794449));
10141014
}
10151015

10161016
void test_aplrregressor_inversegaussian()
@@ -1171,7 +1171,7 @@ class Tests
11711171
save_as_csv_file("data/output.csv", predictions);
11721172

11731173
std::cout << predictions.mean() << "\n\n";
1174-
tests.push_back(is_approximately_equal(predictions.mean(), 23.563270291507191));
1174+
tests.push_back(is_approximately_equal(predictions.mean(), 23.602543167509292));
11751175
}
11761176

11771177
void test_aplrregressor_monotonic()
@@ -1224,7 +1224,7 @@ class Tests
12241224
save_as_csv_file("data/output.csv", predictions);
12251225

12261226
std::cout << predictions.mean() << "\n\n";
1227-
tests.push_back(is_approximately_equal(predictions.mean(), 23.47597042545404));
1227+
tests.push_back(is_approximately_equal(predictions.mean(), 23.34283475003015));
12281228
}
12291229

12301230
void test_aplrregressor_monotonic_ignore_interactions()
@@ -1492,7 +1492,7 @@ class Tests
14921492
save_as_csv_file("data/output.csv", predictions);
14931493

14941494
std::cout << predictions.mean() << "\n\n";
1495-
tests.push_back(is_approximately_equal(predictions.mean(), 23.610872525541577));
1495+
tests.push_back(is_approximately_equal(predictions.mean(), 23.646255799722155));
14961496
}
14971497

14981498
void test_aplrregressor_weibull()
@@ -1597,19 +1597,19 @@ class Tests
15971597

15981598
VectorXd predictions{model.predict(X_test)};
15991599
MatrixXd li{model.calculate_local_feature_contribution(X_test)};
1600-
VectorXd li_for_particular_terms{model.calculate_local_contribution_from_selected_terms(X_train, {5, 1})};
1600+
VectorXd li_for_particular_terms{model.calculate_local_contribution_from_selected_terms(X_train, {1, 8})};
16011601

16021602
// Saving results
16031603
save_as_csv_file("data/output.csv", predictions);
16041604

16051605
std::cout << predictions.mean() << "\n\n";
1606-
tests.push_back(is_approximately_equal(predictions.mean(), 23.703500296203778, 0.00001));
1606+
tests.push_back(is_approximately_equal(predictions.mean(), 24.301339246925711, 0.00001));
16071607

16081608
std::map<double, double> main_effect_shape = model.get_main_effect_shape(1);
1609-
bool main_effect_shape_has_correct_length{main_effect_shape.size() == 11};
1610-
bool main_effect_shape_value_test{is_approximately_equal(main_effect_shape.begin()->second, -0.44924570143235887)};
1609+
bool main_effect_shape_has_correct_length{main_effect_shape.size() == 9};
1610+
bool main_effect_shape_value_test{is_approximately_equal(main_effect_shape.begin()->second, 0)};
16111611
bool li_for_particular_terms_has_correct_size{li_for_particular_terms.rows() == X_train.rows()};
1612-
bool li_for_particular_terms_mean_is_correct{is_approximately_equal(li_for_particular_terms.mean(), 0.30321952178814915)};
1612+
bool li_for_particular_terms_mean_is_correct{is_approximately_equal(li_for_particular_terms.mean(), -0.52786383485971788)};
16131613
tests.push_back(main_effect_shape_has_correct_length);
16141614
tests.push_back(main_effect_shape_value_test);
16151615
tests.push_back(li_for_particular_terms_has_correct_size);
@@ -2023,15 +2023,15 @@ class Tests
20232023

20242024
std::cout << "cv_error\n"
20252025
<< model.get_cv_error() << "\n\n";
2026-
tests.push_back(is_approximately_equal(model.get_cv_error(), 0.16317052975361318, 0.000001));
2026+
tests.push_back(is_approximately_equal(model.get_cv_error(), 0.15984656957508173, 0.000001));
20272027

20282028
std::cout << "predicted_class_prob_mean\n"
20292029
<< predicted_class_probabilities.mean() << "\n\n";
20302030
tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.5, 0.00001));
20312031

20322032
std::cout << "local_feature_importance_mean\n"
20332033
<< local_feature_importance.mean() << "\n\n";
2034-
tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.054997728196581296, 0.00001));
2034+
tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.052181259967961045, 0.00001));
20352035
}
20362036

20372037
void test_aplrclassifier_two_class_predictor_specific_penalties_and_learning_rates()
@@ -2093,15 +2093,15 @@ class Tests
20932093

20942094
std::cout << "cv_error\n"
20952095
<< model.get_cv_error() << "\n\n";
2096-
tests.push_back(is_approximately_equal(model.get_cv_error(), 0.16984042158451909, 0.000001));
2096+
tests.push_back(is_approximately_equal(model.get_cv_error(), 0.17250319103503037, 0.000001));
20972097

20982098
std::cout << "predicted_class_prob_mean\n"
20992099
<< predicted_class_probabilities.mean() << "\n\n";
21002100
tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.5, 0.00001));
21012101

21022102
std::cout << "local_feature_importance_mean\n"
21032103
<< local_feature_importance.mean() << "\n\n";
2104-
tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.076147629914484025, 0.00001));
2104+
tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.07920242388299352, 0.00001));
21052105
}
21062106

21072107
void test_aplrclassifier_two_class_max_terms()
Binary file not shown.

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
setuptools.setup(
2727
name="aplr",
28-
version="10.2.0",
28+
version="10.2.1",
2929
description="Automatic Piecewise Linear Regression",
3030
ext_modules=[sfc_module],
3131
author="Mathias von Ottenbreit",

0 commit comments

Comments
 (0)