Skip to content

Commit 3db448f

Browse files
authored
Implement and use thread pool to run models concurrently (#178)
* Implement and use thread pool to run models concurrently
* Add comments
1 parent 0447574 commit 3db448f

File tree

6 files changed

+117
-7
lines changed

6 files changed

+117
-7
lines changed

Testing/WinMLRunnerTest/WinMLRunnerTest.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,36 @@ namespace WinMLRunnerTest
658658
}
659659
};
660660

661+
TEST_CLASS(ConcurrencyTest)
{
public:
    TEST_CLASS_INITIALIZE(SetupClass)
    {
        // Create the folder that the concurrency tests load models from.
        std::string mkFolderCommand = "mkdir " + std::string(INPUT_FOLDER_PATH.begin(), INPUT_FOLDER_PATH.end());
        system(mkFolderCommand.c_str());

        const std::vector<std::string> models = { "SqueezeNet.onnx", "keras_Add_ImageNet_small.onnx" };

        // Stage each model file into the input folder so -folder can pick it up.
        for (const auto& model : models)
        {
            std::string copyCommand = "Copy ";
            copyCommand += model;
            copyCommand += ' ' + std::string(INPUT_FOLDER_PATH.begin(), INPUT_FOLDER_PATH.end());
            system(copyCommand.c_str());
        }
    }

    TEST_METHOD(RunFolder)
    {
        // Exercise concurrent model loading over the staged folder with 5 threads.
        const std::wstring command = BuildCommand({
            EXE_PATH, L"-folder", INPUT_FOLDER_PATH, L"-ConcurrentLoad", L"-NumThreads", L"5"
        });
        Assert::AreEqual(S_OK, RunProc((wchar_t *)command.c_str()));
    }
};
690+
661691
TEST_CLASS(OtherTests)
662692
{
663693
public:

Tools/WinMLRunner/WinMLRunnerScenarios.vcxproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@
2020
</ItemGroup>
2121
<ItemGroup>
2222
<ClInclude Include="src/Scenarios.h" />
23+
<ClInclude Include="src\ThreadPool.h" />
2324
</ItemGroup>
2425
<ItemGroup>
2526
<ClCompile Include="src/Concurrency.cpp" />
27+
<ClCompile Include="src\ThreadPool.cpp" />
2628
</ItemGroup>
2729
<PropertyGroup Label="Globals">
2830
<VCProjectVersion>15.0</VCProjectVersion>

Tools/WinMLRunner/WinMLRunnerScenarios.vcxproj.filters

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,16 @@
1818
<ClInclude Include="src/Scenarios.h">
1919
<Filter>Header Files</Filter>
2020
</ClInclude>
21+
<ClInclude Include="src\ThreadPool.h">
22+
<Filter>Header Files</Filter>
23+
</ClInclude>
2124
</ItemGroup>
2225
<ItemGroup>
2326
<ClCompile Include="src/Concurrency.cpp">
2427
<Filter>Source Files</Filter>
2528
</ClCompile>
29+
<ClCompile Include="src\ThreadPool.cpp">
30+
<Filter>Source Files</Filter>
31+
</ClCompile>
2632
</ItemGroup>
2733
</Project>

Tools/WinMLRunner/src/Concurrency.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
#include "Windows.h"
2-
#include "common.h"
31
#include <iostream>
42
#include <thread>
53
#include <regex>
64

5+
#include "Windows.h"
6+
#include "common.h"
7+
#include "ThreadPool.h"
8+
79
using namespace winrt;
810
using namespace winrt::Windows::AI::MachineLearning;
911

@@ -27,12 +29,16 @@ void load_model(const std::wstring &path, bool print_info)
2729
void ConcurrentLoadModel(const std::vector<std::wstring> &paths, unsigned num_threads,
2830
unsigned interval_milliseconds, bool print_info)
2931
{
30-
std::vector<std::thread> threads;
31-
unsigned threads_size = paths.size() > num_threads ? paths.size() : num_threads;
32-
for (unsigned i = 0; i < threads_size; i++)
32+
33+
ThreadPool pool(num_threads);
34+
// Creating enough threads to load all the models specified
35+
// If there is more than enough threads, some threads will concurrently load same models
36+
size_t threads_size = paths.size() > num_threads ? paths.size() : num_threads;
37+
std::vector<std::future<void>> output_futures;
38+
for (size_t i = 0; i < threads_size; i++)
3339
{
34-
threads.emplace_back(std::thread(load_model, std::ref(paths[i % paths.size()]), print_info));
3540
Sleep(interval_milliseconds);
41+
output_futures.push_back(pool.SubmitWork(load_model, std::ref(paths[i % paths.size()]), true));
3642
}
37-
std::for_each(threads.begin(), threads.end(), [](std::thread &th) { th.join(); });
43+
// TODO: read output values from load_model
3844
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include "ThreadPool.h"
2+
#include <ctime>
3+
4+
// Spin up a fixed number of worker threads; each one sleeps on the condition
// variable until either a task is queued or the pool is being destroyed.
ThreadPool::ThreadPool(unsigned int initial_pool_size): m_threads(), m_destruct_pool(false) {
    for (unsigned int i = 0; i < initial_pool_size; i++) {
        m_threads.emplace_back([this]() {
            for (;;) {
                std::unique_lock<std::mutex> lock(m_mutex);
                // Sleep until there is work to run or the pool is shutting down.
                m_cond_var.wait(lock, [this] { return m_destruct_pool || !m_work_queue.empty(); });
                if (m_work_queue.empty()) {
                    // Woken with an empty queue: destruction was requested.
                    return;
                }
                auto work = std::move(m_work_queue.front());
                m_work_queue.pop();
                lock.unlock(); // run the task without holding the queue lock
                work();
            }
        });
    }
}
26+
27+
// Drain the pool: request shutdown, wake every worker, and join them.
// Queued work still runs to completion before the workers exit, because the
// worker loop only breaks when the queue is empty.
ThreadPool::~ThreadPool() {
    {
        // Bug fix: set the flag while holding the mutex. Writing it unlocked
        // raced with workers reading it in the wait predicate — a worker
        // could evaluate the predicate, see "no work, not destructing", and
        // then miss this notification entirely, deadlocking the join below.
        std::lock_guard<std::mutex> lock(m_mutex);
        m_destruct_pool = true;
    }
    m_cond_var.notify_all(); // notify destruction to threads
    for (auto &thread : m_threads) {
        thread.join();
    }
}

Tools/WinMLRunner/src/ThreadPool.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#pragma once

#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>
8+
9+
// Fixed-size pool of worker threads that execute submitted callables
// asynchronously. Workers are created in the constructor and joined in the
// destructor; queued work is drained before shutdown completes.
class ThreadPool {
private:
    std::condition_variable m_cond_var;             // wakes workers when work arrives or on shutdown
    bool m_destruct_pool = false;                   // shutdown flag; read/written under m_mutex
    std::mutex m_mutex;                             // guards m_work_queue and m_destruct_pool
    std::vector<std::thread> m_threads;             // worker threads, joined in the destructor
    std::queue<std::function<void()>> m_work_queue; // pending tasks, type-erased to void()

public:
    // Starts initial_pool_size worker threads that block until work arrives.
    explicit ThreadPool(unsigned int initial_pool_size);
    // Wakes all workers, lets queued work drain, and joins every thread.
    ~ThreadPool();

    // The pool owns threads and a mutex, so copying it has no sensible meaning.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    // Enqueues f(args...) for execution on a worker thread. Returns a future
    // that yields f's result — or rethrows its exception — once it has run.
    template <typename F, typename...Args>
    inline auto SubmitWork(F &&f, Args&&... args) -> std::future<decltype(f(args...))> {
        auto func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
        // packaged_task is move-only; hold it in a shared_ptr so the copyable
        // void() lambda below can be stored in the std::function queue.
        auto task = std::make_shared<std::packaged_task<decltype(f(args...))()>>(std::move(func));
        {
            std::lock_guard<std::mutex> lock(m_mutex);
            m_work_queue.push([task]() { (*task)(); });
        }

        m_cond_var.notify_one(); // unblocks one of the waiting threads
        return task->get_future();
    }
};

0 commit comments

Comments
 (0)