@@ -2730,7 +2730,7 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
27302730 {
27312731 split_points_in_each_predictor[i] = compute_split_points (base_predictors_in_each_unique_term_affiliation[unique_term_affiliation_index][i], relevant_term_indexes);
27322732
2733- if (num_predictors_used_in_the_affiliation > 1 && additional_points > 0 )
2733+ if (num_predictors_used_in_the_affiliation > 1 && additional_points > 0 && !split_points_in_each_predictor[i]. empty () )
27342734 {
27352735 double min_val = *std::min_element (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
27362736 double max_val = *std::max_element (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
@@ -2741,9 +2741,9 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
27412741 double val = min_val + (max_val - min_val) * j / (additional_points + 1 );
27422742 interpolated.push_back (val);
27432743 }
2744+ split_points_in_each_predictor[i].reserve (split_points_in_each_predictor[i].size () + additional_points);
27442745 split_points_in_each_predictor[i].insert (split_points_in_each_predictor[i].end (), interpolated.begin (), interpolated.end ());
2745- std::sort (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ());
2746- split_points_in_each_predictor[i].erase (std::unique (split_points_in_each_predictor[i].begin (), split_points_in_each_predictor[i].end ()), split_points_in_each_predictor[i].end ());
2746+ split_points_in_each_predictor[i] = remove_duplicate_elements_from_vector (split_points_in_each_predictor[i]);
27472747 }
27482748 }
27492749
@@ -2762,6 +2762,7 @@ MatrixXd APLRRegressor::get_unique_term_affiliation_shape(const std::string &uni
27622762 {
27632763 size_t current_num_observations = split_points.size ();
27642764 size_t num_observations_to_keep = std::round (factor * std::sqrt (current_num_observations));
2765+ num_observations_to_keep = std::max<size_t >(1 , num_observations_to_keep);
27652766 if (current_num_observations > num_observations_to_keep)
27662767 {
27672768 std::shuffle (split_points.begin (), split_points.end (), seed);
0 commit comments