Skip to content

Commit 7870300

Browse files
Fixed a bug that sometimes formed unnecessarily complex interactions
1 parent 1a9f5b4 commit 7870300

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

cpp/APLRRegressor.h

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -733,16 +733,28 @@ VectorXi APLRRegressor::find_indexes_for_terms_to_consider_as_interaction_partne
733733
size_t number_of_terms_to_consider_as_interaction_partners{find_out_how_many_terms_to_consider_as_interaction_partners()};
734734
VectorXd split_point_errors(terms.size());
735735
VectorXi indexes_for_terms_to_consider_as_interaction_partners(terms.size());
736+
size_t count{0};
736737
for (size_t i = 0; i < terms.size(); ++i)
737738
{
738-
split_point_errors[i]=terms[i].split_point_search_errors_sum;
739-
indexes_for_terms_to_consider_as_interaction_partners[i]=i;
739+
if(terms[i].can_be_used_as_a_given_term)
740+
{
741+
split_point_errors[count] = terms[i].split_point_search_errors_sum;
742+
indexes_for_terms_to_consider_as_interaction_partners[count] = i;
743+
++count;
744+
}
740745
}
741-
bool selecting_the_terms_with_lowest_previous_errors_is_necessary{max_eligible_terms<terms.size()};
746+
split_point_errors.conservativeResize(count);
747+
indexes_for_terms_to_consider_as_interaction_partners.conservativeResize(count);
748+
bool selecting_the_terms_with_lowest_previous_errors_is_necessary{number_of_terms_to_consider_as_interaction_partners < indexes_for_terms_to_consider_as_interaction_partners.size()};
742749
if(selecting_the_terms_with_lowest_previous_errors_is_necessary)
743750
{
744-
indexes_for_terms_to_consider_as_interaction_partners=sort_indexes_ascending(split_point_errors);
745-
indexes_for_terms_to_consider_as_interaction_partners.conservativeResize(number_of_terms_to_consider_as_interaction_partners);
751+
VectorXi sorted_indexes{sort_indexes_ascending(split_point_errors)};
752+
VectorXi temp_indexes(number_of_terms_to_consider_as_interaction_partners);
753+
for (size_t i = 0; i < number_of_terms_to_consider_as_interaction_partners; ++i)
754+
{
755+
temp_indexes[i] = indexes_for_terms_to_consider_as_interaction_partners[sorted_indexes[i]];
756+
}
757+
indexes_for_terms_to_consider_as_interaction_partners=std::move(temp_indexes);
746758
}
747759
return indexes_for_terms_to_consider_as_interaction_partners;
748760
}
@@ -756,6 +768,7 @@ size_t APLRRegressor::find_out_how_many_terms_to_consider_as_interaction_partner
756768
}
757769
return terms_to_consider;
758770
}
771+
759772
void APLRRegressor::find_sorted_indexes_for_errors_for_interactions_to_consider()
760773
{
761774
VectorXd errors_for_interactions_to_consider(interactions_to_consider.size());

cpp/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ int main()
4646
//Saving results
4747
save_data("output.csv",predictions);
4848
std::cout<<"min validation_error "<<model.validation_error_steps.minCoeff()<<"\n\n";
49-
std::cout<<is_approximately_equal(model.validation_error_steps.minCoeff(),5.82111,0.00001)<<"\n";
49+
std::cout<<is_approximately_equal(model.validation_error_steps.minCoeff(),6.32895,0.00001)<<"\n";
5050

5151
std::cout<<"mean prediction "<<predictions.mean()<<"\n\n";
52-
std::cout<<is_approximately_equal(predictions.mean(),23.5815,0.0001)<<"\n";
52+
std::cout<<is_approximately_equal(predictions.mean(),23.587,0.0001)<<"\n";
5353

5454
std::cout<<"best_m: "<<model.m<<"\n";
5555

cpp/term.h

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class Term
5757
void cleanup_after_fit();
5858
void cleanup_when_this_term_was_added_as_a_given_term();
5959
void make_term_ineligible();
60+
void determine_if_can_be_used_as_a_given_term(const VectorXd &x);
6061

6162
public:
6263
//fields
@@ -75,6 +76,7 @@ class Term
7576
size_t ineligible_boosting_steps;
7677
VectorXd values_discretized; //Discretized values based on split_point=nan
7778
VectorXd sample_weight_discretized; //Discretized sample_weight based on split_point=nan
79+
bool can_be_used_as_a_given_term;
7880

7981
//methods
8082
Term(size_t base_term=0,const std::vector<Term> &given_terms=std::vector<Term>(0),double split_point=NAN_DOUBLE,bool direction_right=false,double coefficient=0);
@@ -97,14 +99,15 @@ class Term
9799
//Regular constructor
98100
Term::Term(size_t base_term,const std::vector<Term> &given_terms,double split_point,bool direction_right,double coefficient):
99101
name{""},base_term{base_term},given_terms{given_terms},split_point{split_point},direction_right{direction_right},coefficient{coefficient},
100-
split_point_search_errors_sum{std::numeric_limits<double>::infinity()},ineligible_boosting_steps{0}
102+
split_point_search_errors_sum{std::numeric_limits<double>::infinity()},ineligible_boosting_steps{0},can_be_used_as_a_given_term{false}
101103
{
102104
}
103105

104106
//Copy constructor
105107
Term::Term(const Term &other):
106108
name{other.name},base_term{other.base_term},given_terms{other.given_terms},split_point{other.split_point},direction_right{other.direction_right},
107-
coefficient{other.coefficient},coefficient_steps{other.coefficient_steps},split_point_search_errors_sum{other.split_point_search_errors_sum},ineligible_boosting_steps{0}
109+
coefficient{other.coefficient},coefficient_steps{other.coefficient_steps},split_point_search_errors_sum{other.split_point_search_errors_sum},
110+
ineligible_boosting_steps{0},can_be_used_as_a_given_term{other.can_be_used_as_a_given_term}
108111
{
109112
}
110113

@@ -180,6 +183,7 @@ void Term::estimate_split_point(const MatrixXd &X,const VectorXd &negative_gradi
180183
estimate_split_point_on_discretized_data();
181184
estimate_coefficient_and_error_on_all_data();
182185
cleanup_after_estimate_split_point();
186+
determine_if_can_be_used_as_a_given_term(X.col(base_term));
183187
}
184188

185189
//Calculate indices that get zeroed out during calculate() because of given terms. Also calculates indices of those observations that do not.
@@ -642,6 +646,20 @@ void Term::cleanup_after_estimate_split_point()
642646
errors_initial.resize(0);
643647
}
644648

649+
// Sets the can_be_used_as_a_given_term flag for this term.
// Evaluates the term on x with calculate_without_interactions and sets the
// flag to true if at least one resulting value is approximately zero (as
// judged by is_approximately_zero); otherwise the flag is left false.
// x: input values for this term; estimate_split_point passes X.col(base_term).
void Term::determine_if_can_be_used_as_a_given_term(const VectorXd &x)
650+
{
651+
VectorXd values{calculate_without_interactions(x)};
652+
can_be_used_as_a_given_term = false;
653+
for (auto &value:values)
654+
{
655+
if(is_approximately_zero(value))
656+
{
657+
can_be_used_as_a_given_term=true;
658+
// One approximately-zero value is enough, so stop scanning.
659+
break;
660+
}
661+
}
662+
}
662+
645663
void Term::cleanup_after_fit()
646664
{
647665
bins_start_index.clear();

cpp/test ALRRegressor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ int main()
5252
save_data("data/output.csv",predictions);
5353

5454
std::cout<<predictions.mean()<<"\n\n";
55-
tests.push_back(is_approximately_equal(predictions.mean(),23.6889,0.00001));
55+
tests.push_back(is_approximately_equal(predictions.mean(),23.5049,0.00001));
5656

5757
//std::cout<<model.validation_error_steps<<"\n\n";
5858

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
setuptools.setup(
1717
name='aplr',
18-
version='1.10.0',
18+
version='1.10.1',
1919
description='Automatic Piecewise Linear Regression',
2020
ext_modules=[sfc_module],
2121
author="Mathias von Ottenbreit",

0 commit comments

Comments
 (0)