@@ -66,7 +66,7 @@ class APLRRegressor
6666 const std::vector<std::vector<size_t >> &interaction_constraints);
6767 bool check_if_base_term_has_only_one_unique_value (size_t base_term);
6868 void add_term_to_terms_eligible_current (Term &term);
69- VectorXd calculate_neg_gradient_current (const VectorXd &sample_weight_train );
69+ VectorXd calculate_neg_gradient_current ();
7070 void execute_boosting_steps ();
7171 void execute_boosting_step (size_t boosting_step);
7272 std::vector<size_t > find_terms_eligible_current_indexes_for_a_base_term (size_t base_term);
@@ -109,7 +109,7 @@ class APLRRegressor
109109 void throw_error_if_sample_weight_contains_invalid_values (const VectorXd &y, const VectorXd &sample_weight);
110110 void throw_error_if_response_is_not_between_0_and_1 (const VectorXd &y, const std::string &error_message);
111111 void throw_error_if_vector_contains_negative_values (const VectorXd &y, const std::string &error_message);
112- void throw_error_if_response_is_not_greater_than_zero (const VectorXd &y, const std::string &error_message);
112+ void throw_error_if_vector_contains_non_positive_values (const VectorXd &y, const std::string &error_message);
113113 void throw_error_if_dispersion_parameter_is_invalid ();
114114 VectorXd differentiate_predictions_wrt_linear_predictor ();
115115 void scale_response_if_using_log_link_function ();
@@ -366,9 +366,9 @@ void APLRRegressor::validate_input_to_fit(const MatrixXd &X, const VectorXd &y,
366366 throw_error_if_interaction_constraints_has_invalid_indexes (X, interaction_constraints);
367367 throw_error_if_response_contains_invalid_values (y);
368368 throw_error_if_sample_weight_contains_invalid_values (y, sample_weight);
369- bool group_is_of_incorrect_size{loss_function == " group_mse" && group.rows () != y.rows ()};
369+ bool group_is_of_incorrect_size{( loss_function == " group_mse" || validation_tuning_metric == " group_mse" ) && group.rows () != y.rows ()};
370370 if (group_is_of_incorrect_size)
371- throw std::runtime_error (" When loss_function is group_mse then y and group must have the same number of rows." );
371+ throw std::runtime_error (" When loss_function or validation_tuning_metric is group_mse then y and group must have the same number of rows." );
372372 bool other_data_is_provided{other_data.size () > 0 };
373373 if (other_data_is_provided)
374374 {
@@ -439,7 +439,7 @@ void APLRRegressor::throw_error_if_response_contains_invalid_values(const Vector
439439 error_message = " Response values for the " + loss_function + " loss_function when dispersion_parameter>2 must be greater than zero." ;
440440 else
441441 error_message = " Response values for the " + loss_function + " loss_function must be greater than zero." ;
442- throw_error_if_response_is_not_greater_than_zero (y, error_message);
442+ throw_error_if_vector_contains_non_positive_values (y, error_message);
443443 }
444444 else if (link_function == " log" || loss_function == " poisson" || loss_function == " negative_binomial" || loss_function == " weibull" || (loss_function == " tweedie" && std::isless (dispersion_parameter, 2 ) && std::isgreater (dispersion_parameter, 1 )))
445445 {
@@ -471,7 +471,7 @@ void APLRRegressor::throw_error_if_vector_contains_negative_values(const VectorX
471471 throw std::runtime_error (error_message);
472472}
473473
474- void APLRRegressor::throw_error_if_response_is_not_greater_than_zero (const VectorXd &y, const std::string &error_message)
474+ void APLRRegressor::throw_error_if_vector_contains_non_positive_values (const VectorXd &y, const std::string &error_message)
475475{
476476 bool response_is_not_greater_than_zero{(y.array () <= 0.0 ).any ()};
477477 if (response_is_not_greater_than_zero)
@@ -485,10 +485,7 @@ void APLRRegressor::throw_error_if_sample_weight_contains_invalid_values(const V
485485 {
486486 if (sample_weight.size () != y.size ())
487487 throw std::runtime_error (" sample_weight must have 0 or as many rows as X and y." );
488- throw_error_if_vector_contains_negative_values (sample_weight, " sample_weight cannot contain negative values." );
489- bool sum_is_zero{sample_weight.sum () == 0 };
490- if (sum_is_zero)
491- throw std::runtime_error (" sample_weight cannot sum to zero." );
488+ throw_error_if_vector_contains_non_positive_values (sample_weight, " all sample_weight values must be greater than zero." );
492489 }
493490}
494491
@@ -705,7 +702,7 @@ void APLRRegressor::add_term_to_terms_eligible_current(Term &term)
705702 terms_eligible_current.push_back (term);
706703}
707704
708- VectorXd APLRRegressor::calculate_neg_gradient_current (const VectorXd &sample_weight_train )
705+ VectorXd APLRRegressor::calculate_neg_gradient_current ()
709706{
710707 VectorXd output;
711708 if (loss_function == " mse" )
@@ -720,17 +717,29 @@ VectorXd APLRRegressor::calculate_neg_gradient_current(const VectorXd &sample_we
720717 output = (y_train.array () - predictions_current.array ()).array () * predictions_current.array ().pow (-dispersion_parameter);
721718 else if (loss_function == " group_mse" )
722719 {
723- GroupData group_residuals_and_count{calculate_group_errors_and_count (y_train, predictions_current, group_train, unique_groups_train)};
720+ GroupData group_residuals_and_count{calculate_group_errors_and_count (y_train, predictions_current, group_train, unique_groups_train,
721+ sample_weight_train)};
724722
725723 for (int unique_group_value : unique_groups_train)
726724 {
727725 group_residuals_and_count.error [unique_group_value] /= group_residuals_and_count.count [unique_group_value];
728726 }
729727
730728 output = VectorXd (y_train.rows ());
731- for (Eigen::Index i = 0 ; i < y_train.size (); ++i)
729+ bool sample_weight_is_provided{sample_weight_train.size () > 0 };
730+ if (sample_weight_is_provided)
732731 {
733- output[i] = group_residuals_and_count.error [group_train[i]];
732+ for (Eigen::Index i = 0 ; i < y_train.size (); ++i)
733+ {
734+ output[i] = group_residuals_and_count.error [group_train[i]] * sample_weight_train[i];
735+ }
736+ }
737+ else
738+ {
739+ for (Eigen::Index i = 0 ; i < y_train.size (); ++i)
740+ {
741+ output[i] = group_residuals_and_count.error [group_train[i]];
742+ }
734743 }
735744 }
736745 else if (loss_function == " mae" )
@@ -892,7 +901,7 @@ void APLRRegressor::update_linear_predictor_and_predictions()
892901
893902void APLRRegressor::update_gradient_and_errors ()
894903{
895- neg_gradient_current = calculate_neg_gradient_current (sample_weight_train );
904+ neg_gradient_current = calculate_neg_gradient_current ();
896905 neg_gradient_nullmodel_errors_sum = calculate_sum_error (calculate_errors (neg_gradient_current, linear_predictor_null_model, sample_weight_train, MSE_LOSS_FUNCTION));
897906}
898907
0 commit comments