@@ -31,7 +31,6 @@ class APLRRegressor
3131 VectorXd predictions_current;
3232 VectorXd predictions_current_validation;
3333 VectorXd neg_gradient_current;
34- VectorXd neg_gradient_nullmodel_errors;
3534 double neg_gradient_nullmodel_errors_sum;
3635 size_t best_term_index;
3736 VectorXd linear_predictor_update;
@@ -630,8 +629,7 @@ void APLRRegressor::update_linear_predictor_and_predictions()
630629void APLRRegressor::update_gradient_and_errors ()
631630{
632631 neg_gradient_current=calculate_neg_gradient_current ();
633- neg_gradient_nullmodel_errors=calculate_errors (neg_gradient_current,linear_predictor_null_model,sample_weight_train);
634- neg_gradient_nullmodel_errors_sum=calculate_sum_error (neg_gradient_nullmodel_errors);
632+ neg_gradient_nullmodel_errors_sum=calculate_sum_error (calculate_errors (neg_gradient_current,linear_predictor_null_model,sample_weight_train,FAMILY_GAUSSIAN));
635633}
636634
637635std::vector<size_t > APLRRegressor::find_terms_eligible_current_indexes_for_a_base_term (size_t base_term)
@@ -908,7 +906,7 @@ void APLRRegressor::select_the_best_term_and_update_errors(size_t boosting_step,
908906
909907 linear_predictor_update=terms_eligible_current[best_term_index].calculate_contribution_to_linear_predictor (X_train);
910908 linear_predictor_update_validation=terms_eligible_current[best_term_index].calculate_contribution_to_linear_predictor (X_validation);
911- double error_after_updating_term=calculate_sum_error (calculate_errors (neg_gradient_current,linear_predictor_update,sample_weight_train));
909+ double error_after_updating_term=calculate_sum_error (calculate_errors (neg_gradient_current,linear_predictor_update,sample_weight_train,FAMILY_GAUSSIAN ));
912910 bool no_improvement{std::isgreaterequal (error_after_updating_term,neg_gradient_nullmodel_errors_sum)};
913911 if (no_improvement)
914912 {
@@ -985,7 +983,7 @@ void APLRRegressor::calculate_validation_error(size_t boosting_step, const Vecto
985983 if (validation_tuning_metric==" default" )
986984 validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions,sample_weight_validation,family,tweedie_power),sample_weight_validation);
987985 else if (validation_tuning_metric==" mse" )
988- validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions,sample_weight_validation),sample_weight_validation);
986+ validation_error_steps[boosting_step]=calculate_mean_error (calculate_errors (y_validation,predictions,sample_weight_validation,FAMILY_GAUSSIAN ),sample_weight_validation);
989987 else if (validation_tuning_metric==" mae" )
990988 validation_error_steps[boosting_step]=calculate_mean_error (calculate_absolute_errors (y_validation,predictions,sample_weight_validation),sample_weight_validation);
991989 else if (validation_tuning_metric==" rankability" )
@@ -1236,7 +1234,6 @@ void APLRRegressor::cleanup_after_fit()
12361234 predictions_current.resize (0 );
12371235 predictions_current_validation.resize (0 );
12381236 neg_gradient_current.resize (0 );
1239- neg_gradient_nullmodel_errors.resize (0 );
12401237 linear_predictor_update.resize (0 );
12411238 linear_predictor_update_validation.resize (0 );
12421239 distributed_terms.clear ();
0 commit comments