diff --git a/example/example.cpp b/example/example.cpp index ea90694..9ed49ce 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -1,7 +1,11 @@ +// Concorrency.cpp : 此文件包含 "main" 函数。程序执行将在此处开始并结束。 +// + +#include "pch.h" #include #include -#include "../include/ThreadPool.h" +#include "ThreadPool.h" std::random_device rd; std::mt19937 mt(rd()); @@ -10,60 +14,61 @@ auto rnd = std::bind(dist, mt); void simulate_hard_computation() { - std::this_thread::sleep_for(std::chrono::milliseconds(2000 + rnd())); + std::this_thread::sleep_for(std::chrono::milliseconds(2000 + rnd())); } // Simple function that adds multiplies two numbers and prints the result void multiply(const int a, const int b) { - simulate_hard_computation(); - const int res = a * b; - std::cout << a << " * " << b << " = " << res << std::endl; + simulate_hard_computation(); + const int res = a * b; + std::cout << a << " * " << b << " = " << res << std::endl; } // Same as before but now we have an output parameter void multiply_output(int & out, const int a, const int b) { - simulate_hard_computation(); - out = a * b; - std::cout << a << " * " << b << " = " << out << std::endl; + //simulate_hard_computation(); + out = a * b; + std::cout << a << " * " << b << " = " << out << std::endl; } // Same as before but now we have an output parameter int multiply_return(const int a, const int b) { - simulate_hard_computation(); - const int res = a * b; - std::cout << a << " * " << b << " = " << res << std::endl; - return res; + //simulate_hard_computation(); + const int res = a * b; + std::cout << a << " * " << b << " = " << res << std::endl; + return res; } -void example() { - // Create pool with 3 threads - ThreadPool pool(3); +int main() { + // Create pool with 3 threads + ThreadPool pool(3,100); - // Initialize pool - pool.init(); + // Initialize pool + pool.init(); - // Submit (partial) multiplication table - for (int i = 1; i < 3; ++i) { - for (int j = 1; j < 10; ++j) { - pool.submit(multiply, i, j); - } - } + // Submit (partial) multiplication table + for (int i = 1; i < 10; ++i) { + for (int j = 1; j < 10; ++j) { + pool.submit(multiply, i, j); + } + } - // Submit function with output parameter passed by ref - int output_ref; - auto future1 = pool.submit(multiply_output, std::ref(output_ref), 5, 6); + // Submit function with output parameter passed by ref + int output_ref; + auto future1 = pool.submit(multiply_output, std::ref(output_ref), 5, 6); - // Wait for multiplication output to finish - future1.get(); - std::cout << "Last operation result is equals to " << output_ref << std::endl; + // Wait for multiplication output to finish + future1.get(); + std::cout << "Last operation result is equals to " << output_ref << std::endl; - // Submit function with return parameter - auto future2 = pool.submit(multiply_return, 5, 3); + // Submit function with return parameter + auto future2 = pool.submit(multiply_return, 5, 3); - // Wait for multiplication output to finish - int res = future2.get(); - std::cout << "Last operation result is equals to " << res << std::endl; - - pool.shutdown(); + // Wait for multiplication output to finish + int res = future2.get(); + std::cout << "Last operation result is equals to " << res << std::endl; + //std::this_thread::sleep_for(std::chrono::milliseconds(10000 + rnd())); + getchar(); + pool.shutdown(); } diff --git a/include/ThreadPool.h b/include/ThreadPool.h index b7bf84d..6c72716 100644 --- a/include/ThreadPool.h +++ b/include/ThreadPool.h @@ -7,93 +7,171 @@ #include #include #include +#include #include "SafeQueue.h" class ThreadPool { private: - class ThreadWorker { - private: - int m_id; - ThreadPool * m_pool; - public: - ThreadWorker(ThreadPool * pool, const int id) - : m_pool(pool), m_id(id) { - } - - void operator()() { - std::function func; - bool dequeued; - while (!m_pool->m_shutdown) { - { - std::unique_lock lock(m_pool->m_conditional_mutex); - if (m_pool->m_queue.empty()) { - m_pool->m_conditional_lock.wait(lock); - } - dequeued = m_pool->m_queue.dequeue(func); - } - if (dequeued) { - func(); - } - } - } - }; - - bool m_shutdown; - SafeQueue> m_queue; - std::vector m_threads; - std::mutex m_conditional_mutex; - std::condition_variable m_conditional_lock; + class ThreadWorker { + private: + int m_id; + ThreadPool * m_pool; + public: + ThreadWorker(ThreadPool * pool, const int id) + : m_pool(pool), m_id(id) { + } + + void operator()() { + std::function func; + bool dequeued; + + while (!m_pool->m_shutdown) { + { + std::unique_lock lock(m_pool->m_conditional_mutex); + + while (m_pool->m_queue.empty() && !m_pool->m_shutdown) + { + + if (m_pool->wait_exit_thr_num > 0) + { + m_pool->wait_exit_thr_num--; + + if (m_pool->live_thr_num > m_pool->min_thr_num) + { + m_pool->live_thr_num--; + return; + } + } + else + { + m_pool->m_conditional_lock.wait(lock); + } + } + } + + dequeued = m_pool->m_queue.dequeue(func); + + if (dequeued) { + m_pool->busy_thr_num++; + func(); + m_pool->busy_thr_num--; + } + } + + } + }; + + // + void adjust_thread(void) + { + while (!m_shutdown) + { + std::this_thread::sleep_for(period); + std::unique_lock lock(m_conditional_mutex); + int queue_size = m_queue.size(); + int live_thr_num = this->live_thr_num; + lock.unlock(); + int busy_thr_num = this->busy_thr_num; + + if (queue_size >= min_wait_task_num && live_thr_num < max_thr_num) { + lock.lock(); + int add = 0; + for (int i = 0; i < max_thr_num && add < default_thread_vary + && live_thr_num < max_thr_num; i++) { + if (m_threads[i].get_id() == std::thread::id()) { + m_threads[i] = std::thread(ThreadWorker(this, i)); + add++; + this->live_thr_num++; + } + } + lock.unlock(); + } + + if ((busy_thr_num * 2) < live_thr_num && live_thr_num > min_thr_num) { + lock.lock(); + wait_exit_thr_num = default_thread_vary; + lock.unlock(); + m_conditional_lock.notify_all(); + } + + } + } + + const int min_wait_task_num = 10; + const int default_thread_vary = 10; + const int min_thr_num; + //Maximum number of threads + const int max_thr_num; + int live_thr_num; + int wait_exit_thr_num; + //Number of busy threads + std::atomic busy_thr_num; + bool m_shutdown; + SafeQueue> m_queue; + std::vector m_threads; + std::mutex m_conditional_mutex; + std::condition_variable m_conditional_lock; + std::chrono::seconds period{5}; + std::thread adjustthread; public: - ThreadPool(const int n_threads) - : m_threads(std::vector(n_threads)), m_shutdown(false) { - } - - ThreadPool(const ThreadPool &) = delete; - ThreadPool(ThreadPool &&) = delete; - - ThreadPool & operator=(const ThreadPool &) = delete; - ThreadPool & operator=(ThreadPool &&) = delete; - - // Inits thread pool - void init() { - for (int i = 0; i < m_threads.size(); ++i) { - m_threads[i] = std::thread(ThreadWorker(this, i)); - } - } - - // Waits until threads finish their current task and shutdowns the pool - void shutdown() { - m_shutdown = true; - m_conditional_lock.notify_all(); - - for (int i = 0; i < m_threads.size(); ++i) { - if(m_threads[i].joinable()) { - m_threads[i].join(); - } - } - } - - // Submit a function to be executed asynchronously by the pool - template - auto submit(F&& f, Args&&... args) -> std::future { - // Create a function with bounded parameters ready to execute - std::function func = std::bind(std::forward(f), std::forward(args)...); - // Encapsulate it into a shared ptr in order to be able to copy construct / assign - auto task_ptr = std::make_shared>(func); - - // Wrap packaged task into void function - std::function wrapper_func = [task_ptr]() { - (*task_ptr)(); - }; - - // Enqueue generic wrapper function - m_queue.enqueue(wrapper_func); - - // Wake up one thread if its waiting - m_conditional_lock.notify_one(); - - // Return future from promise - return task_ptr->get_future(); - } + ThreadPool(const int min_thr_num,const int max_thr_num) + : m_threads(std::vector(max_thr_num)), min_thr_num(min_thr_num), + max_thr_num(max_thr_num), live_thr_num(min_thr_num), m_shutdown(false), wait_exit_thr_num(0){ + } + + ThreadPool(const ThreadPool &) = delete; + ThreadPool(ThreadPool &&) = delete; + + ThreadPool & operator=(const ThreadPool &) = delete; + ThreadPool & operator=(ThreadPool &&) = delete; + + // Inits thread pool + void init() { + for (int i = 0; i < min_thr_num; ++i) { + m_threads[i] = std::thread(ThreadWorker(this, i)); + } + adjustthread = std::move(std::thread(&ThreadPool::adjust_thread, this)); + } + + // Waits until threads finish their current task and shutdowns the pool + void shutdown() { + m_shutdown = true; + if (adjustthread.joinable()) { + adjustthread.join(); + } + + m_conditional_lock.notify_all(); + + for (size_t i = 0; i < m_threads.size(); ++i) { + if (m_threads[i].joinable()) { + m_threads[i].join(); + } + } + } + + // Submit a function to be executed asynchronously by the pool + template + auto submit(F&& f, Args&&... args) -> std::future { + // Create a function with bounded parameters ready to execute + std::function func = std::bind(std::forward(f), std::forward(args)...); + // Encapsulate it into a shared ptr in order to be able to copy construct / assign + auto task_ptr = std::make_shared>(func); + + // Wrap packaged task into void function + std::function wrapper_func = [task_ptr]() { + (*task_ptr)(); + }; + + // Enqueue generic wrapper function + m_queue.enqueue(wrapper_func); + + // Wake up one thread if its waiting + m_conditional_lock.notify_one(); + + // Return future from promise + return task_ptr->get_future(); + } + }; +