Commit d1ba994

10.16.0
1 parent cc6af13 commit d1ba994

4 files changed (+187, -30 lines)

cpp/APLRRegressor.h

Lines changed: 91 additions & 29 deletions
@@ -10,6 +10,7 @@
 #include "functions.h"
 #include "term.h"
 #include "constants.h"
+#include "ThreadPool.h"
 
 using namespace Eigen;
 
@@ -82,6 +83,7 @@ class APLRRegressor
     bool round_robin_update_of_existing_terms;
     size_t term_to_update_in_this_boosting_step;
     size_t cores_to_use;
+    std::unique_ptr<ThreadPool> thread_pool;
     bool stopped_early;
     std::vector<double> ridge_penalty_weights;
     double min_validation_error_for_current_fold;
@@ -273,6 +275,7 @@ class APLRRegressor
                   size_t num_first_steps_with_linear_effects_only = 0, double penalty_for_non_linearity = 0.0, double penalty_for_interactions = 0.0,
                   size_t max_terms = 0, double ridge_penalty = 0.0001, bool mean_bias_correction = false);
     APLRRegressor(const APLRRegressor &other);
+    APLRRegressor &operator=(const APLRRegressor &other);
     ~APLRRegressor();
     void fit(const MatrixXd &X, const VectorXd &y, const VectorXd &sample_weight = VectorXd(0), const std::vector<std::string> &X_names = {},
              const MatrixXi &cv_observations = MatrixXi(0, 0), const std::vector<size_t> &prioritized_predictors_indexes = {},
@@ -380,6 +383,74 @@ APLRRegressor::APLRRegressor(const APLRRegressor &other)
 {
 }
 
+APLRRegressor &APLRRegressor::operator=(const APLRRegressor &other)
+{
+    if (this == &other)
+    {
+        return *this;
+    }
+
+    intercept = other.intercept;
+    terms = other.terms;
+    m = other.m;
+    v = other.v;
+    loss_function = other.loss_function;
+    link_function = other.link_function;
+    cv_folds = other.cv_folds;
+    n_jobs = other.n_jobs;
+    random_state = other.random_state;
+    bins = other.bins;
+    verbosity = other.verbosity;
+    term_names = other.term_names;
+    term_affiliations = other.term_affiliations;
+    term_coefficients = other.term_coefficients;
+    max_interaction_level = other.max_interaction_level;
+    max_interactions = other.max_interactions;
+    interactions_eligible = other.interactions_eligible;
+    validation_error_steps = other.validation_error_steps;
+    min_observations_in_split = other.min_observations_in_split;
+    ineligible_boosting_steps_added = other.ineligible_boosting_steps_added;
+    max_eligible_terms = other.max_eligible_terms;
+    number_of_base_terms = other.number_of_base_terms;
+    number_of_unique_term_affiliations = other.number_of_unique_term_affiliations;
+    feature_importance = other.feature_importance;
+    term_importance = other.term_importance;
+    dispersion_parameter = other.dispersion_parameter;
+    min_training_prediction_or_response = other.min_training_prediction_or_response;
+    max_training_prediction_or_response = other.max_training_prediction_or_response;
+    validation_tuning_metric = other.validation_tuning_metric;
+    quantile = other.quantile;
+    m_optimal = other.m_optimal;
+    calculate_custom_validation_error_function = other.calculate_custom_validation_error_function;
+    calculate_custom_loss_function = other.calculate_custom_loss_function;
+    calculate_custom_negative_gradient_function = other.calculate_custom_negative_gradient_function;
+    calculate_custom_transform_linear_predictor_to_predictions_function = other.calculate_custom_transform_linear_predictor_to_predictions_function;
+    calculate_custom_differentiate_predictions_wrt_linear_predictor_function = other.calculate_custom_differentiate_predictions_wrt_linear_predictor_function;
+    boosting_steps_before_interactions_are_allowed = other.boosting_steps_before_interactions_are_allowed;
+    monotonic_constraints_ignore_interactions = other.monotonic_constraints_ignore_interactions;
+    group_mse_by_prediction_bins = other.group_mse_by_prediction_bins;
+    group_mse_cycle_min_obs_in_bin = other.group_mse_cycle_min_obs_in_bin;
+    cv_error = other.cv_error;
+    term_main_predictor_indexes = other.term_main_predictor_indexes;
+    term_interaction_levels = other.term_interaction_levels;
+    early_stopping_rounds = other.early_stopping_rounds;
+    num_first_steps_with_linear_effects_only = other.num_first_steps_with_linear_effects_only;
+    penalty_for_non_linearity = other.penalty_for_non_linearity;
+    penalty_for_interactions = other.penalty_for_interactions;
+    max_terms = other.max_terms;
+    min_predictor_values_in_training = other.min_predictor_values_in_training;
+    max_predictor_values_in_training = other.max_predictor_values_in_training;
+    unique_term_affiliations = other.unique_term_affiliations;
+    unique_term_affiliation_map = other.unique_term_affiliation_map;
+    base_predictors_in_each_unique_term_affiliation = other.base_predictors_in_each_unique_term_affiliation;
+    ridge_penalty = other.ridge_penalty;
+    mean_bias_correction = other.mean_bias_correction;
+
+    thread_pool.reset();
+
+    return *this;
+}
+
 APLRRegressor::~APLRRegressor()
 {
 }
@@ -442,6 +513,8 @@ void APLRRegressor::initialize_multithreading()
         cores_to_use = available_cores;
     else
         cores_to_use = std::min(n_jobs, available_cores);
+    if (cores_to_use > 1)
+        thread_pool = std::make_unique<ThreadPool>(cores_to_use);
 }
 
 void APLRRegressor::preprocess_penalties()
@@ -1387,39 +1460,28 @@ std::vector<size_t> APLRRegressor::find_terms_eligible_current_indexes_for_a_bas
 
 void APLRRegressor::estimate_split_point_for_each_term(std::vector<Term> &terms, std::vector<size_t> &terms_indexes)
 {
-    bool multithreading{n_jobs != 1 && terms_indexes.size() > 1};
+    bool multithreading{cores_to_use > 1 && terms_indexes.size() > 1};
 
     if (multithreading)
     {
-        size_t num_threads{std::min(cores_to_use, terms_indexes.size())};
-        std::vector<std::thread> threads;
-        size_t chunk_size{(terms_indexes.size() + num_threads - 1) / num_threads};
-
-        for (size_t t = 0; t < num_threads; ++t)
-        {
-            threads.emplace_back([&, t]()
-                                 {
-                size_t start = t * chunk_size;
-                size_t end = std::min(start + chunk_size, terms_indexes.size());
-                for (size_t i = start; i < end; ++i)
-                {
-                    terms[terms_indexes[i]].estimate_split_point(X_train, neg_gradient_current, sample_weight_train, bins,
-                                                                 predictor_learning_rates[terms[terms_indexes[i]].base_term],
-                                                                 predictor_min_observations_in_split[terms[terms_indexes[i]].base_term],
-                                                                 linear_effects_only_in_this_boosting_step,
-                                                                 predictor_penalties_for_non_linearity[terms[terms_indexes[i]].base_term],
-                                                                 predictor_penalties_for_interactions[terms[terms_indexes[i]].base_term],
-                                                                 ridge_penalty,
-                                                                 ridge_penalty_weights[terms[terms_indexes[i]].base_term]);
-                } });
-        }
-
-        for (auto &thread : threads)
+        std::vector<std::future<void>> results;
+        for (size_t i = 0; i < terms_indexes.size(); ++i)
         {
-            if (thread.joinable())
-            {
-                thread.join();
-            }
+            results.emplace_back(
+                thread_pool->enqueue([&terms, &terms_indexes, i, this]
+                                     { terms[terms_indexes[i]].estimate_split_point(
+                                           this->X_train, this->neg_gradient_current, this->sample_weight_train, this->bins,
+                                           this->predictor_learning_rates[terms[terms_indexes[i]].base_term],
+                                           this->predictor_min_observations_in_split[terms[terms_indexes[i]].base_term],
+                                           this->linear_effects_only_in_this_boosting_step,
+                                           this->predictor_penalties_for_non_linearity[terms[terms_indexes[i]].base_term],
+                                           this->predictor_penalties_for_interactions[terms[terms_indexes[i]].base_term],
+                                           this->ridge_penalty,
+                                           this->ridge_penalty_weights[terms[terms_indexes[i]].base_term]); }));
+        }
+        for (auto &&result : results)
+        {
+            result.get();
         }
     }
    else
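
Note on the new assignment operator above: the pool owns running std::thread workers (see cpp/ThreadPool.h below), so it cannot be meaningfully copied; model state is copied member by member and the target's pool is simply dropped, to be recreated the next time initialize_multithreading() runs with cores_to_use > 1. A minimal sketch of that idiom, using a hypothetical Model class rather than the real APLRRegressor:

#include <memory>
#include "ThreadPool.h"

// Hypothetical reduced class showing the same copy-assignment pattern as APLRRegressor above.
class Model
{
public:
    Model() = default;

    Model &operator=(const Model &other)
    {
        if (this == &other)
            return *this;
        intercept = other.intercept; // value-like state is copied member by member
        thread_pool.reset();         // a pool of running threads cannot be copied; drop it here
        return *this;
    }

    void initialize_multithreading(size_t cores_to_use)
    {
        if (cores_to_use > 1) // recreate the pool lazily, as initialize_multithreading() does above
            thread_pool = std::make_unique<ThreadPool>(cores_to_use);
    }

private:
    double intercept = 0.0;
    std::unique_ptr<ThreadPool> thread_pool;
};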

cpp/ThreadPool.h

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+#pragma once
+
+#include <vector>
+#include <queue>
+#include <thread>
+#include <mutex>
+#include <condition_variable>
+#include <functional>
+#include <future>
+
+class ThreadPool
+{
+public:
+    ThreadPool(size_t);
+    template <class F, class... Args>
+    auto enqueue(F &&f, Args &&...args)
+        -> std::future<typename std::result_of<F(Args...)>::type>;
+    ~ThreadPool();
+
+private:
+    // need to keep track of threads so we can join them
+    std::vector<std::thread> workers;
+    // the task queue
+    std::queue<std::function<void()>> tasks;
+
+    // synchronization
+    std::mutex queue_mutex;
+    std::condition_variable condition;
+    bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+    : stop(false)
+{
+    for (size_t i = 0; i < threads; ++i)
+        workers.emplace_back(
+            [this]
+            {
+                for (;;)
+                {
+                    std::function<void()> task;
+
+                    {
+                        std::unique_lock<std::mutex> lock(this->queue_mutex);
+                        this->condition.wait(lock,
+                                             [this]
+                                             { return this->stop || !this->tasks.empty(); });
+                        if (this->stop && this->tasks.empty())
+                            return;
+                        task = std::move(this->tasks.front());
+                        this->tasks.pop();
+                    }
+
+                    task();
+                }
+            });
+}
+
+// add new work item to the pool
+template <class F, class... Args>
+auto ThreadPool::enqueue(F &&f, Args &&...args)
+    -> std::future<typename std::result_of<F(Args...)>::type>
+{
+    using return_type = typename std::result_of<F(Args...)>::type;
+
+    auto task = std::make_shared<std::packaged_task<return_type()>>(
+        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
+
+    std::future<return_type> res = task->get_future();
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+
+        // don't allow enqueueing after stopping the pool
+        if (stop)
+            throw std::runtime_error("enqueue on stopped ThreadPool");
+
+        tasks.emplace([task]()
+                      { (*task)(); });
+    }
+    condition.notify_one();
+    return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+    {
+        std::unique_lock<std::mutex> lock(queue_mutex);
+        stop = true;
+    }
+    condition.notify_all();
+    for (std::thread &worker : workers)
+        worker.join();
+}
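
A short usage sketch for this ThreadPool (hypothetical example, not part of the commit): enqueue() accepts any callable plus its arguments, hands back an std::future for the result, and the destructor stops and joins the workers when the pool goes out of scope. Because enqueue() relies on std::result_of, which was removed in C++20, the sketch assumes a C++14 or C++17 build.

#include "ThreadPool.h"
#include <future>
#include <iostream>
#include <vector>

int main()
{
    ThreadPool pool(4); // four worker threads, created once and reused

    // Each enqueue() wraps the callable in a packaged_task and returns its future.
    std::vector<std::future<int>> results;
    for (int i = 0; i < 8; ++i)
        results.emplace_back(pool.enqueue([i] { return i * i; }));

    // get() blocks until the corresponding task has run (and rethrows any exception it threw).
    for (auto &result : results)
        std::cout << result.get() << ' ';
    std::cout << '\n';
} // ~ThreadPool() sets stop, wakes all workers, and joins them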
Binary file not shown.

setup.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 
 setuptools.setup(
     name="aplr",
-    version="10.15.0",
+    version="10.16.0",
     description="Automatic Piecewise Linear Regression",
    ext_modules=[sfc_module],
    author="Mathias von Ottenbreit",

0 commit comments
