
Commit d611b57

interactions
1 parent 4ca2628 commit d611b57


6 files changed: 57 additions, 46 deletions


API_REFERENCE_FOR_CLASSIFICATION.md

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ Specifies the maximum number of bins to discretize the data into when searching
 Specifies the maximum allowed depth of interaction terms. ***0*** means that interactions are not allowed. This hyperparameter should be tuned.
 
 #### max_interactions (default = 100000)
-The maximum number of interactions allowed. A lower value may be used to reduce computational time.
+The maximum number of interactions allowed in each underlying model. A lower value may be used to reduce computational time.
 
 #### min_observations_in_split (default = 20)
 The minimum effective number of observations that a term in the model must rely on. This hyperparameter should be tuned. Larger values are more appropriate for larger datasets. Larger values result in more robust models (lower variance), potentially at the expense of increased bias.

API_REFERENCE_FOR_REGRESSION.md

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ Specifies the maximum number of bins to discretize the data into when searching
 Specifies the maximum allowed depth of interaction terms. ***0*** means that interactions are not allowed. This hyperparameter should be tuned.
 
 #### max_interactions (default = 100000)
-The maximum number of interactions allowed. A lower value may be used to reduce computational time.
+The maximum number of interactions allowed in each underlying model. A lower value may be used to reduce computational time.
 
 #### min_observations_in_split (default = 20)
 The minimum effective number of observations that a term in the model must rely on. This hyperparameter should be tuned. Larger values are more appropriate for larger datasets. Larger values result in more robust models (lower variance), potentially at the expense of increased bias.
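
Both reference files describe the same settings; the only wording change is that max_interactions is now documented as a cap per underlying model. Purely as an illustration of the two documented hyperparameters, here is a minimal standalone C++ sketch; the AplrSettings struct and validate function are hypothetical stand-ins for this example, not the library's actual API.

#include <cstddef>
#include <iostream>
#include <stdexcept>

// Hypothetical bundle of the two hyperparameters documented above; the names
// and defaults mirror the reference text, the struct itself is illustrative only.
struct AplrSettings
{
    std::size_t max_interactions{100000};      // cap on interactions in each underlying model
    std::size_t min_observations_in_split{20}; // minimum effective observations behind a term
};

// Illustrative sanity check: a lower max_interactions trades accuracy for speed,
// a higher min_observations_in_split trades potential bias for lower variance.
void validate(const AplrSettings &settings)
{
    if (settings.min_observations_in_split == 0)
        throw std::invalid_argument("min_observations_in_split must be positive");
}

int main()
{
    AplrSettings settings;
    settings.max_interactions = 1000;        // reduce computational time
    settings.min_observations_in_split = 50; // more robust terms on a larger dataset
    validate(settings);
    std::cout << settings.max_interactions << " " << settings.min_observations_in_split << "\n";
}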

cpp/APLRRegressor.h

Lines changed: 2 additions & 1 deletion
@@ -1524,7 +1524,8 @@ void APLRRegressor::add_promising_interactions_and_select_the_best_one()
             bool is_best_interaction{j == 0};
             if (is_best_interaction)
                 best_term_index = terms_eligible_current.size() - 1;
-            ++interactions_eligible;
+            if (interactions_to_consider[sorted_indexes_of_errors_for_interactions_to_consider[j]].get_interaction_level() > 0)
+                ++interactions_eligible;
         }
         else
             break;
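
The behavioral change in this hunk: a newly added eligible term now counts against the interaction budget (interactions_eligible, bounded by max_interactions) only when its interaction level is positive, so candidates that resolve to main effects no longer consume the budget. A simplified standalone sketch of that counting rule follows; TermStub and count_eligible_interactions are illustrative names, not the library's types.

#include <cstddef>
#include <iostream>
#include <vector>

// Stand-in for the library's term type, reduced to what the counting rule needs.
struct TermStub
{
    std::size_t interaction_level; // 0 means a pure main effect
};

// Count how many candidate terms consume the interaction budget.
// After this commit, only terms with interaction_level > 0 are counted,
// so main-effect candidates no longer use up max_interactions.
std::size_t count_eligible_interactions(const std::vector<TermStub> &candidates)
{
    std::size_t interactions_eligible{0};
    for (const TermStub &candidate : candidates)
    {
        if (candidate.interaction_level > 0)
            ++interactions_eligible;
    }
    return interactions_eligible;
}

int main()
{
    std::vector<TermStub> candidates{{0}, {1}, {2}, {0}};
    std::cout << count_eligible_interactions(candidates) << "\n"; // prints 2
}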

cpp/term.h

Lines changed: 11 additions & 1 deletion
@@ -715,7 +715,17 @@ VectorXd Term::calculate_contribution_to_linear_predictor(const MatrixXd &X)
 
 size_t Term::get_interaction_level()
 {
-    return given_terms.size();
+    std::vector<size_t> terms_used;
+    terms_used.reserve(1 + given_terms.size());
+    terms_used.push_back(base_term);
+    for (auto &given_term : given_terms)
+    {
+        terms_used.push_back(given_term.base_term);
+    }
+    std::set<size_t> unique_predictors_used{get_unique_integers(terms_used)};
+    size_t interaction_level{unique_predictors_used.size() - 1};
+
+    return interaction_level;
 }
 
 bool Term::get_can_be_used_as_a_given_term()
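
With this change, get_interaction_level() reports the number of distinct predictors the term touches (its own base_term plus the base_terms of its given terms) minus one, instead of simply given_terms.size(). A term whose given terms reuse its own predictor therefore gets a lower level than before, which is consistent with the expected value for p5il dropping from 3 to 2 in cpp/tests.cpp below. A standalone sketch of the counting rule, using std::set directly in place of the library's get_unique_integers helper:

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

// Sketch of the new rule: interaction level = number of distinct predictors
// used by the term (base term plus the base terms of its given terms) minus one.
std::size_t interaction_level(std::size_t base_term, const std::vector<std::size_t> &given_base_terms)
{
    std::set<std::size_t> unique_predictors{given_base_terms.begin(), given_base_terms.end()};
    unique_predictors.insert(base_term);
    return unique_predictors.size() - 1;
}

int main()
{
    // A term on predictor 2 with given terms on predictors 2 and 5:
    // the old definition (given_terms.size()) would report 2,
    // the new definition reports 1 because only two distinct predictors are involved.
    std::cout << interaction_level(2, {2, 5}) << "\n"; // prints 1
}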

cpp/tests.cpp

Lines changed: 41 additions & 41 deletions
@@ -167,7 +167,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 18.507179601395375));
+        tests.push_back(is_approximately_equal(predictions.mean(), 18.534016846656947));
     }
 
     void test_aplrregressor_cauchy_penalties()
@@ -221,7 +221,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 20.037964498663037));
+        tests.push_back(is_approximately_equal(predictions.mean(), 20.146282076477394));
     }
 
     void test_aplrregressor_cauchy_linear_effects_only_first()
@@ -275,7 +275,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 17.971623307698852));
+        tests.push_back(is_approximately_equal(predictions.mean(), 17.964887018234787));
     }
 
     void test_aplrregressor_cauchy_group_mse_validation()
@@ -333,7 +333,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 20.170939369337834));
+        tests.push_back(is_approximately_equal(predictions.mean(), 20.096177156192478));
 
         VectorXd feature_importance_on_test_set{model.calculate_feature_importance(X_test)};
         double feature_importance_on_test_set_mean{feature_importance_on_test_set.mean()};
@@ -349,12 +349,12 @@ class Tests
         std::cout << term_importance_first << "\n\n";
         std::cout << term_base_predictor_index_max << "\n\n";
         std::cout << term_interaction_level_max << "\n\n";
-        tests.push_back(is_approximately_equal(feature_importance_on_test_set_mean, 0.3745208543129413));
-        tests.push_back(is_approximately_equal(feature_importance_mean, 0.37558643075803277));
-        tests.push_back(is_approximately_equal(term_importance_mean, 0.12150899891033366));
-        tests.push_back(is_approximately_equal(feature_importance_first, 0.74048121167747938));
-        tests.push_back(is_approximately_equal(term_importance_first, 0.85496610382351568));
-        tests.push_back(term_base_predictor_index_max == 5);
+        tests.push_back(is_approximately_equal(feature_importance_on_test_set_mean, 0.37735253878466402));
+        tests.push_back(is_approximately_equal(feature_importance_mean, 0.37820511700233239));
+        tests.push_back(is_approximately_equal(term_importance_mean, 0.12843198080249971));
+        tests.push_back(is_approximately_equal(feature_importance_first, 0.73797521341670724));
+        tests.push_back(is_approximately_equal(term_importance_first, 1.0431553101537596));
+        tests.push_back(term_base_predictor_index_max == 6);
         tests.push_back(term_interaction_level_max == 1);
     }
 
@@ -412,7 +412,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 20.104856610039558));
+        tests.push_back(is_approximately_equal(predictions.mean(), 19.804431518585918));
     }
 
     void test_aplrregressor_cauchy()
@@ -466,7 +466,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 20.850035037781723));
+        tests.push_back(is_approximately_equal(predictions.mean(), 20.873594934501561));
     }
 
     void test_aplrregressor_custom_loss_and_validation()
@@ -526,7 +526,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.944797684016745));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.91507568241019));
     }
 
     void test_aplrregressor_custom_loss()
@@ -585,7 +585,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.7035, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.703500296203778, 0.00001));
     }
 
     void test_aplrregressor_gamma_custom_link()
@@ -640,7 +640,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.5266, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.526613939603266, 0.00001));
     }
 
     void test_aplrregressor_gamma_custom_validation()
@@ -695,7 +695,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.5512, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.551175298027964, 0.00001));
     }
 
     void test_aplrregressor_gamma_gini_weighted()
@@ -749,7 +749,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.3198, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.319789512734854, 0.00001));
     }
 
     void test_aplrregressor_gamma_gini()
@@ -803,7 +803,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.3198, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.319789512734854, 0.00001));
     }
 
     void test_aplrregressor_gamma()
@@ -857,7 +857,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.5512, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.551175298027964, 0.00001));
     }
 
     void test_aplrregressor_group_mse()
@@ -957,7 +957,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.533140818895273));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.526475166355244));
     }
 
     void test_aplrregressor_int_constr()
@@ -1010,7 +1010,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.606326522845816));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.576830262038001));
     }
 
     void test_aplrregressor_inversegaussian()
@@ -1065,7 +1065,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.3198, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.31977985222057, 0.00001));
     }
 
     void test_aplrregressor_logit()
@@ -1118,7 +1118,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 0.0875969, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 0.087596882912220717, 0.00001));
     }
 
     void test_aplrregressor_mae()
@@ -1171,7 +1171,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.557834093496929));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.563270291507191));
     }
 
     void test_aplrregressor_monotonic()
@@ -1224,7 +1224,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.337125831228487));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.47597042545404));
     }
 
     void test_aplrregressor_monotonic_ignore_interactions()
@@ -1278,7 +1278,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 24.3013, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 24.301339246925711, 0.00001));
     }
 
     void test_aplrregressor_negative_binomial()
@@ -1332,7 +1332,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 1.8694, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 1.8694002118421278, 0.00001));
     }
 
     void test_aplrregressor_poisson()
@@ -1385,7 +1385,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 1.88727, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 1.8872692088161898, 0.00001));
     }
 
     void test_aplrregressor_poissongamma()
@@ -1439,7 +1439,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 1.88553, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 1.8855344167602603, 0.00001));
     }
 
     void test_aplrregressor_quantile()
@@ -1492,7 +1492,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.65630148869738));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.610872525541577));
     }
 
     void test_aplrregressor_weibull()
@@ -1546,7 +1546,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.6406, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.640555263512187, 0.00001));
     }
 
     void test_aplrregressor()
@@ -1603,7 +1603,7 @@ class Tests
         save_as_csv_file("data/output.csv", predictions);
 
         std::cout << predictions.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(predictions.mean(), 23.7035, 0.00001));
+        tests.push_back(is_approximately_equal(predictions.mean(), 23.703500296203778, 0.00001));
 
         std::map<double, double> main_effect_shape = model.get_main_effect_shape(1);
         bool main_effect_shape_has_correct_length{main_effect_shape.size() == 11};
@@ -1688,15 +1688,15 @@ class Tests
 
         std::cout << "cv_error\n"
                   << model.get_cv_error() << "\n\n";
-        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.246477, 0.000001));
+        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.24647671959943313, 0.000001));
 
         std::cout << "predicted_class_prob_mean\n"
                   << predicted_class_probabilities.mean() << "\n\n";
         tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.2, 0.00001));
 
         std::cout << "local_feature_importance_mean\n"
                   << local_feature_importance.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.15805, 0.00001));
+        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.1580504780375889, 0.00001));
     }
 
     void test_aplrclassifier_multi_class()
@@ -1911,15 +1911,15 @@ class Tests
 
         std::cout << "cv_error\n"
                   << model.get_cv_error() << "\n\n";
-        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.22862513689095387));
+        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.23802511407945728));
 
         std::cout << "predicted_class_prob_mean\n"
                   << predicted_class_probabilities.mean() << "\n\n";
         tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.5, 0.00001));
 
         std::cout << "local_feature_importance_mean\n"
                   << local_feature_importance.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.14062146736733369));
+        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.13431844066700888));
     }
 
     void test_aplrclassifier_two_class()
@@ -2049,15 +2049,15 @@ class Tests
 
         std::cout << "cv_error\n"
                   << model.get_cv_error() << "\n\n";
-        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.16172925575492014, 0.000001));
+        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.16317052975361318, 0.000001));
 
         std::cout << "predicted_class_prob_mean\n"
                   << predicted_class_probabilities.mean() << "\n\n";
         tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.5, 0.00001));
 
         std::cout << "local_feature_importance_mean\n"
                   << local_feature_importance.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.12241915926968391, 0.00001));
+        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.12221717377018071, 0.00001));
     }
 
     void test_aplrclassifier_two_class_predictor_specific_penalties_and_learning_rates()
@@ -2119,15 +2119,15 @@ class Tests
 
         std::cout << "cv_error\n"
                   << model.get_cv_error() << "\n\n";
-        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.17021028319567164, 0.000001));
+        tests.push_back(is_approximately_equal(model.get_cv_error(), 0.16984042158451909, 0.000001));
 
         std::cout << "predicted_class_prob_mean\n"
                   << predicted_class_probabilities.mean() << "\n\n";
         tests.push_back(is_approximately_equal(predicted_class_probabilities.mean(), 0.5, 0.00001));
 
         std::cout << "local_feature_importance_mean\n"
                   << local_feature_importance.mean() << "\n\n";
-        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.18697312064762112, 0.00001));
+        tests.push_back(is_approximately_equal(local_feature_importance.mean(), 0.18613865090207235, 0.00001));
     }
 
     void test_aplrclassifier_two_class_max_terms()
@@ -2368,7 +2368,7 @@ class Tests
         size_t p7il{p7.get_interaction_level()};
         size_t p8il{p8.get_interaction_level()};
         tests.push_back(pil == 2 ? true : false);
-        tests.push_back(p5il == 3 ? true : false);
+        tests.push_back(p5il == 2 ? true : false);
         tests.push_back(p7il == 1 ? true : false);
         tests.push_back(p8il == 0 ? true : false);
     }
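
Nearly every change in cpp/tests.cpp is an updated expected value passed to is_approximately_equal, which compares a computed statistic against a reference within a tolerance. The helper's definition is not part of this diff; the sketch below shows a typical implementation, with the default tolerance of 1e-10 being an assumption for illustration only.

#include <cmath>
#include <iostream>

// Minimal sketch of an approximate-equality helper like the one the tests call.
// The real signature and default tolerance are not shown in this diff.
bool is_approximately_equal(double computed, double expected, double tolerance = 1e-10)
{
    return std::fabs(computed - expected) <= tolerance;
}

int main()
{
    std::cout << std::boolalpha
              << is_approximately_equal(23.7035, 23.703500296203778, 0.00001) << "\n" // true
              << is_approximately_equal(23.7035, 23.703500296203778) << "\n";         // false
}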

setup.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 
 setuptools.setup(
     name="aplr",
-    version="10.0.0",
+    version="10.1.0",
     description="Automatic Piecewise Linear Regression",
     ext_modules=[sfc_module],
     author="Mathias von Ottenbreit",
