
Commit a2f2ce0

refactor
1 parent 0a531fc commit a2f2ce0

4 files changed, +20 -29 lines changed


cpp/APLRRegressor.h

Lines changed: 2 additions & 2 deletions
@@ -231,8 +231,8 @@ class APLRRegressor
     const MatrixXd &other_data = MatrixXd(0, 0));
     VectorXd predict(const MatrixXd &X, bool cap_predictions_to_minmax_in_training = true);
     void set_term_names(const std::vector<std::string> &X_names);
-    VectorXd calculate_feature_importance(const MatrixXd &X, const VectorXd &sample_weight);
-    VectorXd calculate_term_importance(const MatrixXd &X, const VectorXd &sample_weight);
+    VectorXd calculate_feature_importance(const MatrixXd &X, const VectorXd &sample_weight = VectorXd(0));
+    VectorXd calculate_term_importance(const MatrixXd &X, const VectorXd &sample_weight = VectorXd(0));
     MatrixXd calculate_local_feature_contribution(const MatrixXd &X);
     MatrixXd calculate_local_term_contribution(const MatrixXd &X);
     MatrixXd calculate_terms(const MatrixXd &X);
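
With the new VectorXd(0) defaults, both importance methods can be called without an explicit sample_weight, which is what the added test in cpp/tests.cpp further down relies on. A minimal usage sketch (not part of the commit; `model` is assumed to be an already-fitted APLRRegressor and `X_test` an Eigen feature matrix):

#include "APLRRegressor.h" // assumed include path; provides APLRRegressor and the Eigen MatrixXd/VectorXd names used below

// Sketch only (not library code): with the new default argument, feature
// importance on a held-out set can be computed without passing sample weights.
VectorXd unweighted_feature_importance(APLRRegressor &model, const MatrixXd &X_test)
{
    return model.calculate_feature_importance(X_test); // previously required an explicit sample_weight vector
}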

cpp/functions.h

Lines changed: 11 additions & 13 deletions
@@ -509,19 +509,17 @@ std::vector<double> remove_duplicate_elements_from_vector(const std::vector<doub
     return output;
 }
 
-double calculate_standard_deviation(const VectorXd &vector, const VectorXd &weight = VectorXd(0))
+double calculate_standard_deviation(const VectorXd &vector, const VectorXd &sample_weight = VectorXd(0))
 {
-    bool weight_is_provided{weight.size() > 0};
-    double variance;
-    if (weight_is_provided)
-    {
-        double sum_weight{weight.sum()};
-        double weighted_average_of_vector{(vector.array() * weight.array()).sum() / sum_weight};
-        variance = (weight.array() * (vector.array() - weighted_average_of_vector).pow(2)).sum() / sum_weight;
-    }
+    VectorXd sample_weight_used;
+    bool sample_weight_is_provided{sample_weight.size() > 0};
+    if (sample_weight_is_provided)
+        sample_weight_used = sample_weight / sample_weight.mean();
     else
-    {
-        variance = (vector.array() - vector.mean()).pow(2).mean();
-    }
-    return std::pow(variance, 0.5);
+        sample_weight_used = VectorXd::Constant(vector.rows(), 1.0);
+    double sum_weight{sample_weight_used.sum()};
+    double weighted_average_of_vector{(vector.array() * sample_weight_used.array()).sum() / sum_weight};
+    double variance{(sample_weight_used.array() * (vector.array() - weighted_average_of_vector).pow(2)).sum() / sum_weight};
+    double standard_deviation{std::pow(variance, 0.5)};
+    return standard_deviation;
 }
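
The refactor collapses the weighted and unweighted branches into a single weighted code path: when no sample_weight is given, a constant weight vector is substituted, and a provided weight vector is divided by its mean so the weights sum to the number of observations. A standalone sketch (not part of the commit) that reproduces the refactored helper with Eigen and checks that uniform weights give the same result as the unweighted call:

#include <Eigen/Dense>
#include <cmath>
#include <iostream>

using Eigen::VectorXd;

// Copy of the refactored helper from this commit, compiled standalone for illustration.
double calculate_standard_deviation(const VectorXd &vector, const VectorXd &sample_weight = VectorXd(0))
{
    VectorXd sample_weight_used;
    bool sample_weight_is_provided{sample_weight.size() > 0};
    if (sample_weight_is_provided)
        sample_weight_used = sample_weight / sample_weight.mean(); // weights now sum to the number of rows
    else
        sample_weight_used = VectorXd::Constant(vector.rows(), 1.0); // uniform weights replace the old unweighted branch
    double sum_weight{sample_weight_used.sum()};
    double weighted_average_of_vector{(vector.array() * sample_weight_used.array()).sum() / sum_weight};
    double variance{(sample_weight_used.array() * (vector.array() - weighted_average_of_vector).pow(2)).sum() / sum_weight};
    return std::pow(variance, 0.5);
}

int main()
{
    VectorXd v(4);
    v << 1.0, 2.0, 3.0, 4.0;
    VectorXd w(4);
    w << 2.0, 2.0, 2.0, 2.0; // uniform weights should reproduce the unweighted result
    std::cout << calculate_standard_deviation(v) << "\n";    // population standard deviation, approximately 1.1180
    std::cout << calculate_standard_deviation(v, w) << "\n"; // same value via the weighted path
    return 0;
}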

cpp/term.h

Lines changed: 4 additions & 14 deletions
@@ -550,21 +550,11 @@ void Term::estimate_split_point_on_discretized_data()
 
 void Term::estimate_coefficient_and_error(const VectorXd &x, const VectorXd &y, const VectorXd &sample_weight, double error_added)
 {
-    coefficient = estimate_coefficient(x, y, sample_weight);
-    if (std::isfinite(coefficient))
+    coefficient = v * estimate_coefficient(x, y, sample_weight);
+    if (std::isfinite(coefficient) && coefficient_adheres_to_monotonic_constraint())
     {
-        coefficient *= v;
-        bool coefficient_does_not_adhere_to_monotonic_constraint{!coefficient_adheres_to_monotonic_constraint()};
-        if (coefficient_does_not_adhere_to_monotonic_constraint)
-        {
-            coefficient = 0.0;
-            split_point_search_errors_sum = std::numeric_limits<double>::infinity();
-        }
-        else
-        {
-            VectorXd predictions{x * coefficient};
-            split_point_search_errors_sum = calculate_sum_error(calculate_errors(y, predictions, sample_weight, MSE_LOSS_FUNCTION)) + error_added;
-        }
+        VectorXd predictions{x * coefficient};
+        split_point_search_errors_sum = calculate_sum_error(calculate_errors(y, predictions, sample_weight, MSE_LOSS_FUNCTION)) + error_added;
     }
     else
     {

cpp/tests.cpp

Lines changed: 3 additions & 0 deletions
@@ -117,6 +117,8 @@ class Tests
         std::cout << predictions.mean() << "\n\n";
         tests.push_back(is_approximately_equal(predictions.mean(), 20.170939369337834));
 
+        VectorXd feature_importance_on_test_set{model.calculate_feature_importance(X_test)};
+        double feature_importance_on_test_set_mean{feature_importance_on_test_set.mean()};
         double feature_importance_mean{model.get_feature_importance().mean()};
         double term_importance_mean{model.get_term_importance().mean()};
         double feature_importance_first{model.get_feature_importance()[0]};
@@ -129,6 +131,7 @@ class Tests
         std::cout << term_importance_first << "\n\n";
         std::cout << term_base_predictor_index_max << "\n\n";
         std::cout << term_interaction_level_max << "\n\n";
+        tests.push_back(is_approximately_equal(feature_importance_on_test_set_mean, 0.3745208543129413));
         tests.push_back(is_approximately_equal(feature_importance_mean, 0.37558643075803277));
         tests.push_back(is_approximately_equal(term_importance_mean, 0.12150899891033366));
         tests.push_back(is_approximately_equal(feature_importance_first, 0.74048121167747938));
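
The new assertion calls calculate_feature_importance with only X_test, exercising the new default argument, and compares the mean against a hard-coded reference via the suite's is_approximately_equal helper. That helper's body is not part of this diff; a hypothetical stand-in (name reused from the tests, tolerance and signature are assumptions, not the repository's actual code) would be a relative epsilon check along these lines:

#include <algorithm>
#include <cmath>

// Hypothetical stand-in for the test helper used above; the real helper in
// cpp/tests.cpp is not shown in this commit and may use a different tolerance.
bool is_approximately_equal(double value, double expected, double tolerance = 1e-9)
{
    // Relative comparison, falling back to an absolute check near zero.
    double scale{std::max(1.0, std::abs(expected))};
    return std::abs(value - expected) <= tolerance * scale;
}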
