Skip to content

Commit 6859bbf

Browse files
refactor
1 parent ea1c70d commit 6859bbf

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

cpp/APLRRegressor.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class APLRRegressor
3131
VectorXd predictions_current;
3232
VectorXd predictions_current_validation;
3333
VectorXd neg_gradient_current;
34-
VectorXd neg_gradient_nullmodel_errors;
3534
double neg_gradient_nullmodel_errors_sum;
3635
size_t best_term_index;
3736
VectorXd linear_predictor_update;
@@ -630,8 +629,7 @@ void APLRRegressor::update_linear_predictor_and_predictions()
630629
void APLRRegressor::update_gradient_and_errors()
631630
{
632631
neg_gradient_current=calculate_neg_gradient_current();
633-
neg_gradient_nullmodel_errors=calculate_errors(neg_gradient_current,linear_predictor_null_model,sample_weight_train);
634-
neg_gradient_nullmodel_errors_sum=calculate_sum_error(neg_gradient_nullmodel_errors);
632+
neg_gradient_nullmodel_errors_sum=calculate_sum_error(calculate_errors(neg_gradient_current,linear_predictor_null_model,sample_weight_train,FAMILY_GAUSSIAN));
635633
}
636634

637635
std::vector<size_t> APLRRegressor::find_terms_eligible_current_indexes_for_a_base_term(size_t base_term)
@@ -908,7 +906,7 @@ void APLRRegressor::select_the_best_term_and_update_errors(size_t boosting_step,
908906

909907
linear_predictor_update=terms_eligible_current[best_term_index].calculate_contribution_to_linear_predictor(X_train);
910908
linear_predictor_update_validation=terms_eligible_current[best_term_index].calculate_contribution_to_linear_predictor(X_validation);
911-
double error_after_updating_term=calculate_sum_error(calculate_errors(neg_gradient_current,linear_predictor_update,sample_weight_train));
909+
double error_after_updating_term=calculate_sum_error(calculate_errors(neg_gradient_current,linear_predictor_update,sample_weight_train,FAMILY_GAUSSIAN));
912910
bool no_improvement{std::isgreaterequal(error_after_updating_term,neg_gradient_nullmodel_errors_sum)};
913911
if(no_improvement)
914912
{
@@ -985,7 +983,7 @@ void APLRRegressor::calculate_validation_error(size_t boosting_step, const Vecto
985983
if(validation_tuning_metric=="default")
986984
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,family,tweedie_power),sample_weight_validation);
987985
else if(validation_tuning_metric=="mse")
988-
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation),sample_weight_validation);
986+
validation_error_steps[boosting_step]=calculate_mean_error(calculate_errors(y_validation,predictions,sample_weight_validation,FAMILY_GAUSSIAN),sample_weight_validation);
989987
else if(validation_tuning_metric=="mae")
990988
validation_error_steps[boosting_step]=calculate_mean_error(calculate_absolute_errors(y_validation,predictions,sample_weight_validation),sample_weight_validation);
991989
else if(validation_tuning_metric=="rankability")
@@ -1236,7 +1234,6 @@ void APLRRegressor::cleanup_after_fit()
12361234
predictions_current.resize(0);
12371235
predictions_current_validation.resize(0);
12381236
neg_gradient_current.resize(0);
1239-
neg_gradient_nullmodel_errors.resize(0);
12401237
linear_predictor_update.resize(0);
12411238
linear_predictor_update_validation.resize(0);
12421239
distributed_terms.clear();

cpp/constants.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
#include <limits>
33

44
const double NAN_DOUBLE{ std::numeric_limits<double>::quiet_NaN() };
5-
const int MAX_ABS_EXPONENT_TO_APPLY_ON_LINEAR_PREDICTOR_IN_LOGIT_MODEL{std::min(16, std::numeric_limits<double>::max_exponent10)};
5+
const int MAX_ABS_EXPONENT_TO_APPLY_ON_LINEAR_PREDICTOR_IN_LOGIT_MODEL{std::min(16, std::numeric_limits<double>::max_exponent10)};
6+
const std::string FAMILY_GAUSSIAN{"gaussian"};

cpp/term.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,8 @@ void Term::discretize_data_by_bin()
505505

506506
void Term::estimate_split_point_on_discretized_data()
507507
{
508-
errors_initial=calculate_errors(negative_gradient_discretized,VectorXd::Constant(negative_gradient_discretized.size(),0.0),sample_weight_discretized);
508+
errors_initial=calculate_errors(negative_gradient_discretized,VectorXd::Constant(negative_gradient_discretized.size(),0.0),
509+
sample_weight_discretized,FAMILY_GAUSSIAN);
509510
error_initial=calculate_sum_error(errors_initial);
510511

511512
double split_point_temp;
@@ -640,7 +641,8 @@ void Term::estimate_coefficient_and_error_on_all_data()
640641
if(coefficient_adheres_to_monotonic_constraint())
641642
{
642643
VectorXd predictions{sorted_vectors.values_sorted*coefficient};
643-
split_point_search_errors_sum=calculate_sum_error(calculate_errors(sorted_vectors.negative_gradient_sorted,predictions,sorted_vectors.sample_weight_sorted))+error_where_given_terms_are_zero;
644+
split_point_search_errors_sum=calculate_sum_error(calculate_errors(sorted_vectors.negative_gradient_sorted,predictions,
645+
sorted_vectors.sample_weight_sorted,FAMILY_GAUSSIAN))+error_where_given_terms_are_zero;
644646
}
645647
else
646648
{

0 commit comments

Comments
 (0)