|
10 | 10 | #include "functions.h" |
11 | 11 | #include "term.h" |
12 | 12 | #include "constants.h" |
| 13 | +#include "ThreadPool.h" |
13 | 14 |
|
14 | 15 | using namespace Eigen; |
15 | 16 |
|
@@ -82,6 +83,7 @@ class APLRRegressor |
82 | 83 | bool round_robin_update_of_existing_terms; |
83 | 84 | size_t term_to_update_in_this_boosting_step; |
84 | 85 | size_t cores_to_use; |
| 86 | + std::unique_ptr<ThreadPool> thread_pool; |
85 | 87 | bool stopped_early; |
86 | 88 | std::vector<double> ridge_penalty_weights; |
87 | 89 | double min_validation_error_for_current_fold; |
@@ -273,6 +275,7 @@ class APLRRegressor |
273 | 275 | size_t num_first_steps_with_linear_effects_only = 0, double penalty_for_non_linearity = 0.0, double penalty_for_interactions = 0.0, |
274 | 276 | size_t max_terms = 0, double ridge_penalty = 0.0001, bool mean_bias_correction = false); |
275 | 277 | APLRRegressor(const APLRRegressor &other); |
| 278 | + APLRRegressor &operator=(const APLRRegressor &other); |
276 | 279 | ~APLRRegressor(); |
277 | 280 | void fit(const MatrixXd &X, const VectorXd &y, const VectorXd &sample_weight = VectorXd(0), const std::vector<std::string> &X_names = {}, |
278 | 281 | const MatrixXi &cv_observations = MatrixXi(0, 0), const std::vector<size_t> &prioritized_predictors_indexes = {}, |
@@ -380,6 +383,74 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other) |
380 | 383 | { |
381 | 384 | } |
382 | 385 |
|
| 386 | +APLRRegressor &APLRRegressor::operator=(const APLRRegressor &other) |
| 387 | +{ |
| 388 | + if (this == &other) |
| 389 | + { |
| 390 | + return *this; |
| 391 | + } |
| 392 | + |
| 393 | + intercept = other.intercept; |
| 394 | + terms = other.terms; |
| 395 | + m = other.m; |
| 396 | + v = other.v; |
| 397 | + loss_function = other.loss_function; |
| 398 | + link_function = other.link_function; |
| 399 | + cv_folds = other.cv_folds; |
| 400 | + n_jobs = other.n_jobs; |
| 401 | + random_state = other.random_state; |
| 402 | + bins = other.bins; |
| 403 | + verbosity = other.verbosity; |
| 404 | + term_names = other.term_names; |
| 405 | + term_affiliations = other.term_affiliations; |
| 406 | + term_coefficients = other.term_coefficients; |
| 407 | + max_interaction_level = other.max_interaction_level; |
| 408 | + max_interactions = other.max_interactions; |
| 409 | + interactions_eligible = other.interactions_eligible; |
| 410 | + validation_error_steps = other.validation_error_steps; |
| 411 | + min_observations_in_split = other.min_observations_in_split; |
| 412 | + ineligible_boosting_steps_added = other.ineligible_boosting_steps_added; |
| 413 | + max_eligible_terms = other.max_eligible_terms; |
| 414 | + number_of_base_terms = other.number_of_base_terms; |
| 415 | + number_of_unique_term_affiliations = other.number_of_unique_term_affiliations; |
| 416 | + feature_importance = other.feature_importance; |
| 417 | + term_importance = other.term_importance; |
| 418 | + dispersion_parameter = other.dispersion_parameter; |
| 419 | + min_training_prediction_or_response = other.min_training_prediction_or_response; |
| 420 | + max_training_prediction_or_response = other.max_training_prediction_or_response; |
| 421 | + validation_tuning_metric = other.validation_tuning_metric; |
| 422 | + quantile = other.quantile; |
| 423 | + m_optimal = other.m_optimal; |
| 424 | + calculate_custom_validation_error_function = other.calculate_custom_validation_error_function; |
| 425 | + calculate_custom_loss_function = other.calculate_custom_loss_function; |
| 426 | + calculate_custom_negative_gradient_function = other.calculate_custom_negative_gradient_function; |
| 427 | + calculate_custom_transform_linear_predictor_to_predictions_function = other.calculate_custom_transform_linear_predictor_to_predictions_function; |
| 428 | + calculate_custom_differentiate_predictions_wrt_linear_predictor_function = other.calculate_custom_differentiate_predictions_wrt_linear_predictor_function; |
| 429 | + boosting_steps_before_interactions_are_allowed = other.boosting_steps_before_interactions_are_allowed; |
| 430 | + monotonic_constraints_ignore_interactions = other.monotonic_constraints_ignore_interactions; |
| 431 | + group_mse_by_prediction_bins = other.group_mse_by_prediction_bins; |
| 432 | + group_mse_cycle_min_obs_in_bin = other.group_mse_cycle_min_obs_in_bin; |
| 433 | + cv_error = other.cv_error; |
| 434 | + term_main_predictor_indexes = other.term_main_predictor_indexes; |
| 435 | + term_interaction_levels = other.term_interaction_levels; |
| 436 | + early_stopping_rounds = other.early_stopping_rounds; |
| 437 | + num_first_steps_with_linear_effects_only = other.num_first_steps_with_linear_effects_only; |
| 438 | + penalty_for_non_linearity = other.penalty_for_non_linearity; |
| 439 | + penalty_for_interactions = other.penalty_for_interactions; |
| 440 | + max_terms = other.max_terms; |
| 441 | + min_predictor_values_in_training = other.min_predictor_values_in_training; |
| 442 | + max_predictor_values_in_training = other.max_predictor_values_in_training; |
| 443 | + unique_term_affiliations = other.unique_term_affiliations; |
| 444 | + unique_term_affiliation_map = other.unique_term_affiliation_map; |
| 445 | + base_predictors_in_each_unique_term_affiliation = other.base_predictors_in_each_unique_term_affiliation; |
| 446 | + ridge_penalty = other.ridge_penalty; |
| 447 | + mean_bias_correction = other.mean_bias_correction; |
| 448 | + |
| 449 | + thread_pool.reset(); |
| 450 | + |
| 451 | + return *this; |
| 452 | +} |
| 453 | + |
383 | 454 | APLRRegressor::~APLRRegressor() |
384 | 455 | { |
385 | 456 | } |
@@ -442,6 +513,8 @@ void APLRRegressor::initialize_multithreading() |
442 | 513 | cores_to_use = available_cores; |
443 | 514 | else |
444 | 515 | cores_to_use = std::min(n_jobs, available_cores); |
| 516 | + if (cores_to_use > 1) |
| 517 | + thread_pool = std::make_unique<ThreadPool>(cores_to_use); |
445 | 518 | } |
446 | 519 |
|
447 | 520 | void APLRRegressor::preprocess_penalties() |
@@ -1387,39 +1460,28 @@ std::vector<size_t> APLRRegressor::find_terms_eligible_current_indexes_for_a_bas |
1387 | 1460 |
|
1388 | 1461 | void APLRRegressor::estimate_split_point_for_each_term(std::vector<Term> &terms, std::vector<size_t> &terms_indexes) |
1389 | 1462 | { |
1390 | | - bool multithreading{n_jobs != 1 && terms_indexes.size() > 1}; |
| 1463 | + bool multithreading{cores_to_use > 1 && terms_indexes.size() > 1}; |
1391 | 1464 |
|
1392 | 1465 | if (multithreading) |
1393 | 1466 | { |
1394 | | - size_t num_threads{std::min(cores_to_use, terms_indexes.size())}; |
1395 | | - std::vector<std::thread> threads; |
1396 | | - size_t chunk_size{(terms_indexes.size() + num_threads - 1) / num_threads}; |
1397 | | - |
1398 | | - for (size_t t = 0; t < num_threads; ++t) |
1399 | | - { |
1400 | | - threads.emplace_back([&, t]() |
1401 | | - { |
1402 | | - size_t start = t * chunk_size; |
1403 | | - size_t end = std::min(start + chunk_size, terms_indexes.size()); |
1404 | | - for (size_t i = start; i < end; ++i) |
1405 | | - { |
1406 | | - terms[terms_indexes[i]].estimate_split_point(X_train, neg_gradient_current, sample_weight_train, bins, |
1407 | | - predictor_learning_rates[terms[terms_indexes[i]].base_term], |
1408 | | - predictor_min_observations_in_split[terms[terms_indexes[i]].base_term], |
1409 | | - linear_effects_only_in_this_boosting_step, |
1410 | | - predictor_penalties_for_non_linearity[terms[terms_indexes[i]].base_term], |
1411 | | - predictor_penalties_for_interactions[terms[terms_indexes[i]].base_term], |
1412 | | - ridge_penalty, |
1413 | | - ridge_penalty_weights[terms[terms_indexes[i]].base_term]); |
1414 | | - } }); |
1415 | | - } |
1416 | | - |
1417 | | - for (auto &thread : threads) |
| 1467 | + std::vector<std::future<void>> results; |
| 1468 | + for (size_t i = 0; i < terms_indexes.size(); ++i) |
1418 | 1469 | { |
1419 | | - if (thread.joinable()) |
1420 | | - { |
1421 | | - thread.join(); |
1422 | | - } |
| 1470 | + results.emplace_back( |
| 1471 | + thread_pool->enqueue([&terms, &terms_indexes, i, this] |
| 1472 | + { terms[terms_indexes[i]].estimate_split_point( |
| 1473 | + this->X_train, this->neg_gradient_current, this->sample_weight_train, this->bins, |
| 1474 | + this->predictor_learning_rates[terms[terms_indexes[i]].base_term], |
| 1475 | + this->predictor_min_observations_in_split[terms[terms_indexes[i]].base_term], |
| 1476 | + this->linear_effects_only_in_this_boosting_step, |
| 1477 | + this->predictor_penalties_for_non_linearity[terms[terms_indexes[i]].base_term], |
| 1478 | + this->predictor_penalties_for_interactions[terms[terms_indexes[i]].base_term], |
| 1479 | + this->ridge_penalty, |
| 1480 | + this->ridge_penalty_weights[terms[terms_indexes[i]].base_term]); })); |
| 1481 | + } |
| 1482 | + for (auto &&result : results) |
| 1483 | + { |
| 1484 | + result.get(); |
1423 | 1485 | } |
1424 | 1486 | } |
1425 | 1487 | else |
|
0 commit comments