Skip to content

Commit 3421ec1

Browse files
authored
Add Threadpool::TrySimpleParallelFor (microsoft#3759)
* Add TrySimpleParallelFor so that there's a path with OpenMP awareness for SimpleParallelFor. Makes it consistent with [Try]BatchParallelFor and [Try]ParallelFor. Update TopK to check for the number of threads better, and to use TrySimpleParallelFor. * Update documentation to mention TrySimpleParallelFor
1 parent b9a5ed1 commit 3421ec1

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed

docs/NotesOnThreading.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
This document is intended for ORT developers.
44

55
ORT allows the usage of either OpenMP or non-OpenMP (ORT) threads for execution. Threadpool management
6-
is abstracted behind: (1) ThreadPool class in threadpool.h and (2) functions in thread_utils.h.
6+
is abstracted behind: (1) ThreadPool class in [threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) and (2) functions in [thread_utils.h](https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/core/util/thread_utils.h).
77

88
When developing an op, please use these abstractions to parallelize your code. These abstractions centralize 2 things.
99
When OpenMP is enabled, they resort to using OpenMP. When OpenMP is disabled they resort to sequential execution if the threadpool ptr is NULL or schedule the tasks on the threadpool otherwise.
1010

11-
Examples of these abstractions are: (threadpool.h has more documentation for these)
11+
Examples of these abstractions are: ([threadpool.h](https://github.com/microsoft/onnxruntime/blob/master/include/onnxruntime/core/platform/threadpool.h) has more documentation for these)
1212
* TryBatchParallelFor
1313
* TryParallelFor
14+
* TrySimpleParallelFor
1415
* static version of NumThreads
1516

1617
**Please do not write #ifdef pragma omp in operator code**.

include/onnxruntime/core/platform/threadpool.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,11 +190,11 @@ class ThreadPool {
190190

191191
// Similar to ParallelFor above, but takes the specified scheduling strategy
192192
// into account.
193-
void
194-
ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
195-
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
193+
void ParallelFor(std::ptrdiff_t total, const SchedulingParams& scheduling_params,
194+
const std::function<void(std::ptrdiff_t, std::ptrdiff_t)>& fn);
196195

197-
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total, const SchedulingParams& scheduling_params,
196+
static void TryParallelFor(concurrency::ThreadPool* tp, std::ptrdiff_t total,
197+
const SchedulingParams& scheduling_params,
198198
const std::function<void(std::ptrdiff_t first, std::ptrdiff_t last)>& fn) {
199199
#ifdef _OPENMP
200200
ORT_UNUSED_PARAMETER(scheduling_params);
@@ -216,7 +216,7 @@ class ThreadPool {
216216
}
217217
tp->ParallelFor(total, scheduling_params, fn);
218218
#endif
219-
} // namespace concurrency
219+
}
220220

221221
// Prefer using this API to get the number of threads unless you know what you're doing.
222222
// This API takes into account if openmp is enabled/disabled and if the thread pool ptr is nullptr.
@@ -236,7 +236,27 @@ class ThreadPool {
236236

237237
// Directly schedule the 'total' tasks to the underlying threadpool, without
238238
// cutting them by halves
239-
void SimpleParallelFor(std::ptrdiff_t total, std::function<void(std::ptrdiff_t)> fn);
239+
void SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn);
240+
241+
inline static void TrySimpleParallelFor(ThreadPool* tp, std::ptrdiff_t total,
242+
const std::function<void(std::ptrdiff_t)>& fn) {
243+
#ifdef _OPENMP
244+
ORT_UNUSED_PARAMETER(tp);
245+
#pragma omp parallel for
246+
for (std::ptrdiff_t i = 0; i < total; ++i) {
247+
fn(i);
248+
}
249+
#else
250+
if (tp != nullptr) {
251+
tp->SimpleParallelFor(total, fn);
252+
} else {
253+
for (std::ptrdiff_t i = 0; i < total; ++i) {
254+
// In many cases, fn can be inlined here.
255+
fn(i);
256+
}
257+
}
258+
#endif
259+
}
240260

241261
/**
242262
* Tries to call the given function in parallel, with calls split into (num_batches) batches.

onnxruntime/core/common/threadpool.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ ThreadPool::ThreadPool(Eigen::ThreadPoolInterface* user_threadpool, Eigen::Alloc
109109
}
110110

111111
ThreadPool::~ThreadPool() = default;
112-
void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, std::function<void(std::ptrdiff_t)> fn) {
112+
113+
void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function<void(std::ptrdiff_t)>& fn) {
113114
if (total <= 0)
114115
return;
115116

@@ -320,4 +321,4 @@ Eigen::ThreadPoolInterface* ThreadPool::AsEigenThreadPool() const {
320321
return underlying_threadpool_;
321322
}
322323
} // namespace concurrency
323-
} // namespace onnxruntime
324+
} // namespace onnxruntime

onnxruntime/core/providers/cpu/math/top_k.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape
164164
const int64_t num_blocks = input_shape[axis_parsed];
165165
const int64_t block_slice = reduced_cols / k;
166166

167-
int64_t tp_threads = threadpool != nullptr ? threadpool->NumThreads() : 1;
167+
int64_t tp_threads = concurrency::ThreadPool::NumThreads(threadpool);
168168
int64_t num_threads = std::min(tp_threads, rows); // split on rows so can't have more threads than rows
169169

170170
// rough attempt to make sure there's enough work for each thread. if there's insufficient work the usage of
@@ -326,7 +326,8 @@ static void FindTopKElements(const Tensor* input, const TensorShape& input_shape
326326
// we want to re-use the storage variables in each lambda as much as possible to minimize allocations
327327
// on each iteration, so the lambda does multiple rows. e.g. the data_holder and indices_data vectors.
328328
// the alternative would be to use TryBatchParallelFor with the lambda doing one row.
329-
threadpool->SimpleParallelFor(num_threads, find_top_k);
329+
// Use TrySimpleParallelFor so openmp is supported correctly
330+
concurrency::ThreadPool::TrySimpleParallelFor(threadpool, num_threads, find_top_k);
330331
}
331332
}
332333

0 commit comments

Comments (0)