Skip to content

Commit 79ec862

Browse files
naming predictors improved
1 parent 68dc316 commit 79ec862

File tree

1 file changed

+44
-26
lines changed

1 file changed

+44
-26
lines changed

cpp/APLRRegressor.h

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -977,37 +977,17 @@ void APLRRegressor::name_terms(const MatrixXd &X, const std::vector<std::string>
977977
//These names will be used to derive names for the actually used terms in the trained model.
978978
void APLRRegressor::set_term_names(const std::vector<std::string> &X_names)
979979
{
980-
if(std::isnan(intercept)) //model has not been trained
980+
bool model_has_not_been_trained{!std::isfinite(intercept)};
981+
if(model_has_not_been_trained)
981982
throw std::runtime_error("The model must be trained with fit() before term names can be set.");
982983

983-
for (size_t i = 0; i < terms.size(); ++i) //for each term
984+
for (size_t i = 0; i < terms.size(); ++i)
984985
{
985-
//Base name
986-
terms[i].name=X_names[terms[i].base_term];
987-
988-
//Adding cut-point and direction
989-
if(!std::isnan(terms[i].split_point)) //If not linear effect
990-
{
991-
double temp_split_point{terms[i].split_point}; //For prettier printing (+5.0 instead 0f --5.0 as an example when split_point is negative)
992-
std::string sign{"-"};
993-
if(std::isless(temp_split_point,0))
994-
{
995-
temp_split_point=-temp_split_point;
996-
sign="+";
997-
}
998-
if(terms[i].direction_right)
999-
terms[i].name="max("+terms[i].name+sign+std::to_string(temp_split_point)+",0)";
1000-
else
1001-
terms[i].name="min("+terms[i].name+sign+std::to_string(temp_split_point)+",0)";
1002-
}
1003-
1004-
//Adding given terms
1005-
for (size_t j = 0; j < terms[i].given_terms.size(); ++j) //for each given term
986+
terms[i].name = compute_raw_base_term_name(terms[i], X_names[terms[i].base_term]);
987+
for (size_t j = 0; j < terms[i].given_terms.size(); ++j)
1006988
{
1007-
terms[i].name+=" * I("+terms[i].given_terms[j].name+"!=0)";
989+
terms[i].name += " * I("+compute_raw_given_term_name(terms[i].given_terms[j], X_names)+"!=0)";
1008990
}
1009-
1010-
//Adding interaction level
1011991
terms[i].name="P"+std::to_string(i)+". Interaction level: "+std::to_string(terms[i].get_interaction_level())+". "+terms[i].name;
1012992
}
1013993

@@ -1023,6 +1003,44 @@ void APLRRegressor::set_term_names(const std::vector<std::string> &X_names)
10231003
}
10241004
}
10251005

1006+
std::string APLRRegressor::compute_raw_base_term_name(const Term &term, const std::string &X_name)
1007+
{
1008+
std::string name{""};
1009+
bool is_linear_effect{std::isnan(term.split_point)};
1010+
if(is_linear_effect)
1011+
name=X_name;
1012+
else
1013+
{
1014+
double temp_split_point{term.split_point};
1015+
std::string sign{"-"};
1016+
if(std::isless(temp_split_point,0))
1017+
{
1018+
temp_split_point=-temp_split_point;
1019+
sign="+";
1020+
}
1021+
if(term.direction_right)
1022+
name="max("+X_name+sign+std::to_string(temp_split_point)+",0)";
1023+
else
1024+
name="min("+X_name+sign+std::to_string(temp_split_point)+",0)";
1025+
}
1026+
return name;
1027+
}
1028+
1029+
std::string APLRRegressor::compute_raw_given_term_name(const Term &term, const std::vector<std::string> &X_names)
1030+
{
1031+
std::string name{compute_raw_base_term_name(term, X_names[term.base_term])};
1032+
bool term_has_interactions{term.given_terms.size()>0};
1033+
if(term_has_interactions)
1034+
{
1035+
for (size_t i = 0; i < term.given_terms.size(); ++i)
1036+
{
1037+
name += "*" + compute_raw_given_term_name(term.given_terms[i], X_names);
1038+
}
1039+
}
1040+
1041+
return name;
1042+
}
1043+
10261044
void APLRRegressor::calculate_feature_importance_on_validation_set()
10271045
{
10281046
feature_importance=VectorXd::Constant(number_of_base_terms,0);

0 commit comments

Comments
 (0)