Skip to content

Commit f1b9bd1

Browse files
lhamesPriyanshu3820
authored andcommitted
[orc-rt] Introduce Task and TaskDispatcher APIs and implementations. (llvm#168514)
Introduces the Task and TaskDispatcher interfaces (TaskDispatcher.h), ThreadPoolTaskDispatcher implementation (ThreadPoolTaskDispatch.h), and updates Session to include a TaskDispatcher instance that can be used to run tasks. TaskDispatcher's introduction is motivated by the need to handle calls to JIT'd code initiated from the controller process: Incoming calls will be wrapped in Tasks and dispatched. Session shutdown will wait on TaskDispatcher shutdown, ensuring that all Tasks are run or destroyed prior to the Session being destroyed.
1 parent 8ef8fb5 commit f1b9bd1

File tree

11 files changed

+467
-25
lines changed

11 files changed

+467
-25
lines changed

orc-rt/include/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ set(ORC_RT_HEADERS
2222
orc-rt/SPSMemoryFlags.h
2323
orc-rt/SPSWrapperFunction.h
2424
orc-rt/SPSWrapperFunctionBuffer.h
25+
orc-rt/TaskDispatcher.h
26+
orc-rt/ThreadPoolTaskDispatcher.h
2527
orc-rt/WrapperFunction.h
2628
orc-rt/bind.h
2729
orc-rt/bit.h

orc-rt/include/orc-rt/Session.h

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@
1515

1616
#include "orc-rt/Error.h"
1717
#include "orc-rt/ResourceManager.h"
18+
#include "orc-rt/TaskDispatcher.h"
1819
#include "orc-rt/move_only_function.h"
1920

2021
#include "orc-rt-c/CoreTypes.h"
2122

23+
#include <condition_variable>
2224
#include <memory>
2325
#include <mutex>
2426
#include <vector>
@@ -39,7 +41,10 @@ class Session {
3941
///
4042
/// Note that entry into the reporter is not synchronized: it may be
4143
/// called from multiple threads concurrently.
42-
Session(ErrorReporterFn ReportError) : ReportError(std::move(ReportError)) {}
44+
Session(std::unique_ptr<TaskDispatcher> Dispatcher,
45+
ErrorReporterFn ReportError)
46+
: Dispatcher(std::move(Dispatcher)), ReportError(std::move(ReportError)) {
47+
}
4348

4449
// Sessions are not copyable or moveable.
4550
Session(const Session &) = delete;
@@ -49,6 +54,9 @@ class Session {
4954

5055
~Session();
5156

57+
/// Dispatch a task using the Session's TaskDispatcher.
58+
void dispatch(std::unique_ptr<Task> T) { Dispatcher->dispatch(std::move(T)); }
59+
5260
/// Report an error via the ErrorReporter function.
5361
void reportError(Error Err) { ReportError(std::move(Err)); }
5462

@@ -67,12 +75,21 @@ class Session {
6775
}
6876

6977
private:
70-
void shutdownNext(OnShutdownCompleteFn OnShutdownComplete, Error Err,
78+
void shutdownNext(Error Err,
7179
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs);
7280

73-
std::mutex M;
81+
void shutdownComplete();
82+
83+
std::unique_ptr<TaskDispatcher> Dispatcher;
7484
ErrorReporterFn ReportError;
85+
86+
enum class SessionState { Running, ShuttingDown, Shutdown };
87+
88+
std::mutex M;
89+
SessionState State = SessionState::Running;
90+
std::condition_variable StateCV;
7591
std::vector<std::unique_ptr<ResourceManager>> ResourceMgrs;
92+
std::vector<OnShutdownCompleteFn> ShutdownCallbacks;
7693
};
7794

7895
inline orc_rt_SessionRef wrap(Session *S) noexcept {
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===----------- TaskDispatcher.h - Task dispatch utils ---------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Task and TaskDispatcher classes.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef ORC_RT_TASKDISPATCHER_H
14+
#define ORC_RT_TASKDISPATCHER_H
15+
16+
#include "orc-rt/RTTI.h"
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
namespace orc_rt {
22+
23+
/// Represents an abstract task to be run.
24+
class Task : public RTTIExtends<Task, RTTIRoot> {
25+
public:
26+
virtual ~Task();
27+
virtual void run() = 0;
28+
};
29+
30+
/// Base class for generic tasks.
31+
class GenericTask : public RTTIExtends<GenericTask, Task> {};
32+
33+
/// Generic task implementation.
34+
template <typename FnT> class GenericTaskImpl : public GenericTask {
35+
public:
36+
GenericTaskImpl(FnT &&Fn) : Fn(std::forward<FnT>(Fn)) {}
37+
void run() override { Fn(); }
38+
39+
private:
40+
FnT Fn;
41+
};
42+
43+
/// Create a generic task from a function object.
44+
template <typename FnT> std::unique_ptr<GenericTask> makeGenericTask(FnT &&Fn) {
45+
return std::make_unique<GenericTaskImpl<std::decay_t<FnT>>>(
46+
std::forward<FnT>(Fn));
47+
}
48+
49+
/// Abstract base for classes that dispatch Tasks.
50+
class TaskDispatcher {
51+
public:
52+
virtual ~TaskDispatcher();
53+
54+
/// Run the given task.
55+
virtual void dispatch(std::unique_ptr<Task> T) = 0;
56+
57+
/// Called by Session. Should cause further dispatches to be rejected, and
58+
/// wait until all previously dispatched tasks have completed.
59+
virtual void shutdown() = 0;
60+
};
61+
62+
} // End namespace orc_rt
63+
64+
#endif // ORC_RT_TASKDISPATCHER_H
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===--- ThreadPoolTaskDispatcher.h - Run tasks in thread pool --*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// ThreadPoolTaskDispatcher implementation.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef ORC_RT_THREADPOOLTASKDISPATCHER_H
14+
#define ORC_RT_THREADPOOLTASKDISPATCHER_H
15+
16+
#include "orc-rt/TaskDispatcher.h"
17+
18+
#include <condition_variable>
19+
#include <mutex>
20+
#include <thread>
21+
#include <vector>
22+
23+
namespace orc_rt {
24+
25+
/// Thread-pool based TaskDispatcher.
26+
///
27+
/// Will spawn NumThreads threads to run dispatched Tasks.
28+
class ThreadPoolTaskDispatcher : public TaskDispatcher {
29+
public:
30+
ThreadPoolTaskDispatcher(size_t NumThreads);
31+
~ThreadPoolTaskDispatcher() override;
32+
void dispatch(std::unique_ptr<Task> T) override;
33+
void shutdown() override;
34+
35+
private:
36+
void taskLoop();
37+
38+
std::vector<std::thread> Threads;
39+
40+
std::mutex M;
41+
bool AcceptingTasks = true;
42+
std::condition_variable CV;
43+
std::vector<std::unique_ptr<Task>> PendingTasks;
44+
};
45+
46+
} // End namespace orc_rt
47+
48+
#endif // ORC_RT_THREADPOOLTASKDISPATCHER_H

orc-rt/lib/executor/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ set(files
44
RTTI.cpp
55
Session.cpp
66
SimpleNativeMemoryMap.cpp
7+
TaskDispatcher.cpp
8+
ThreadPoolTaskDispatcher.cpp
79
)
810

911
add_library(orc-rt-executor STATIC ${files})

orc-rt/lib/executor/Session.cpp

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212

1313
#include "orc-rt/Session.h"
1414

15-
#include <future>
16-
1715
namespace orc_rt {
1816

1917
Session::~Session() { waitForShutdown(); }
@@ -23,38 +21,62 @@ void Session::shutdown(OnShutdownCompleteFn OnShutdownComplete) {
2321

2422
{
2523
std::scoped_lock<std::mutex> Lock(M);
24+
ShutdownCallbacks.push_back(std::move(OnShutdownComplete));
25+
26+
// If somebody else has already called shutdown then there's nothing further
27+
// for us to do here.
28+
if (State >= SessionState::ShuttingDown)
29+
return;
30+
31+
State = SessionState::ShuttingDown;
2632
std::swap(ResourceMgrs, ToShutdown);
2733
}
2834

29-
shutdownNext(std::move(OnShutdownComplete), Error::success(),
30-
std::move(ToShutdown));
35+
shutdownNext(Error::success(), std::move(ToShutdown));
3136
}
3237

3338
void Session::waitForShutdown() {
34-
std::promise<void> P;
35-
auto F = P.get_future();
36-
37-
shutdown([P = std::move(P)]() mutable { P.set_value(); });
38-
39-
F.wait();
39+
shutdown([]() {});
40+
std::unique_lock<std::mutex> Lock(M);
41+
StateCV.wait(Lock, [&]() { return State == SessionState::Shutdown; });
4042
}
4143

4244
void Session::shutdownNext(
43-
OnShutdownCompleteFn OnComplete, Error Err,
44-
std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
45+
Error Err, std::vector<std::unique_ptr<ResourceManager>> RemainingRMs) {
4546
if (Err)
4647
reportError(std::move(Err));
4748

4849
if (RemainingRMs.empty())
49-
return OnComplete();
50+
return shutdownComplete();
5051

5152
auto NextRM = std::move(RemainingRMs.back());
5253
RemainingRMs.pop_back();
53-
NextRM->shutdown([this, RemainingRMs = std::move(RemainingRMs),
54-
OnComplete = std::move(OnComplete)](Error Err) mutable {
55-
shutdownNext(std::move(OnComplete), std::move(Err),
56-
std::move(RemainingRMs));
57-
});
54+
NextRM->shutdown(
55+
[this, RemainingRMs = std::move(RemainingRMs)](Error Err) mutable {
56+
shutdownNext(std::move(Err), std::move(RemainingRMs));
57+
});
58+
}
59+
60+
void Session::shutdownComplete() {
61+
62+
std::unique_ptr<TaskDispatcher> TmpDispatcher;
63+
std::vector<OnShutdownCompleteFn> TmpShutdownCallbacks;
64+
{
65+
std::lock_guard<std::mutex> Lock(M);
66+
TmpDispatcher = std::move(Dispatcher);
67+
TmpShutdownCallbacks = std::move(ShutdownCallbacks);
68+
}
69+
70+
TmpDispatcher->shutdown();
71+
72+
for (auto &OnShutdownComplete : TmpShutdownCallbacks)
73+
OnShutdownComplete();
74+
75+
{
76+
std::lock_guard<std::mutex> Lock(M);
77+
State = SessionState::Shutdown;
78+
}
79+
StateCV.notify_all();
5880
}
5981

6082
} // namespace orc_rt
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- TaskDispatch.cpp ---------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Contains the implementation of APIs in the orc-rt/TaskDispatch.h header.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "orc-rt/TaskDispatcher.h"
14+
15+
namespace orc_rt {
16+
17+
Task::~Task() = default;
18+
TaskDispatcher::~TaskDispatcher() = default;
19+
20+
} // namespace orc_rt
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
//===- ThreadPoolTaskDispatch.cpp -----------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Contains the implementation of APIs in the orc-rt/ThreadPoolTaskDispatch.h
10+
// header.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "orc-rt/ThreadPoolTaskDispatcher.h"
15+
16+
#include <cassert>
17+
18+
namespace orc_rt {
19+
20+
ThreadPoolTaskDispatcher::~ThreadPoolTaskDispatcher() {
21+
assert(!AcceptingTasks && "shutdown was not run");
22+
}
23+
24+
ThreadPoolTaskDispatcher::ThreadPoolTaskDispatcher(size_t NumThreads) {
25+
Threads.reserve(NumThreads);
26+
for (size_t I = 0; I < NumThreads; ++I)
27+
Threads.emplace_back([this]() { taskLoop(); });
28+
}
29+
30+
void ThreadPoolTaskDispatcher::dispatch(std::unique_ptr<Task> T) {
31+
{
32+
std::scoped_lock<std::mutex> Lock(M);
33+
if (!AcceptingTasks)
34+
return;
35+
PendingTasks.push_back(std::move(T));
36+
}
37+
CV.notify_one();
38+
}
39+
40+
void ThreadPoolTaskDispatcher::shutdown() {
41+
{
42+
std::scoped_lock<std::mutex> Lock(M);
43+
assert(AcceptingTasks && "ThreadPoolTaskDispatcher already shut down?");
44+
AcceptingTasks = false;
45+
}
46+
CV.notify_all();
47+
for (auto &Thread : Threads)
48+
Thread.join();
49+
}
50+
51+
void ThreadPoolTaskDispatcher::taskLoop() {
52+
while (true) {
53+
std::unique_ptr<Task> T;
54+
{
55+
std::unique_lock<std::mutex> Lock(M);
56+
CV.wait(Lock,
57+
[this]() { return !PendingTasks.empty() || !AcceptingTasks; });
58+
59+
if (!AcceptingTasks && PendingTasks.empty())
60+
return;
61+
62+
T = std::move(PendingTasks.back());
63+
PendingTasks.pop_back();
64+
}
65+
66+
T->run();
67+
}
68+
}
69+
70+
} // namespace orc_rt

orc-rt/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_orc_rt_unittest(CoreTests
3131
SPSMemoryFlagsTest.cpp
3232
SPSWrapperFunctionTest.cpp
3333
SPSWrapperFunctionBufferTest.cpp
34+
ThreadPoolTaskDispatcherTest.cpp
3435
WrapperFunctionBufferTest.cpp
3536
bind-test.cpp
3637
bit-test.cpp

0 commit comments

Comments
 (0)