Skip to content

Commit 40f0cca

Browse files
cccclaifacebook-github-bot
authored andcommitted
Add backend runtime context to backend.execute (#211)
Summary: Pull Request resolved: #211 We're adding backend runtime context which comes with mainly two extra functions - Temp allocator (life span is per execute) - Event tracer (profiling inside delegate) Reviewed By: tarun292, dbort Differential Revision: D48872100 fbshipit-source-id: 8e8019a4e1fa280ea54831d77a54e879c4f572d1
1 parent 2c9057a commit 40f0cca

File tree

9 files changed

+40
-11
lines changed

9 files changed

+40
-11
lines changed

backends/qnnpack/QNNPackBackend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,10 @@ class QnnpackBackend final : public PyTorchBackendInterface {
195195
return executor;
196196
}
197197

198-
Error execute(DelegateHandle* handle, EValue** args) const override {
198+
Error execute(
199+
__ET_UNUSED BackendExecutionContext& context,
200+
DelegateHandle* handle,
201+
EValue** args) const override {
199202
static constexpr size_t kMaxDims = 16;
200203

201204
QNNExecutor* etor = static_cast<QNNExecutor*>(handle);

backends/vulkan/VulkanBackend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,10 @@ class VulkanBackend final : public PyTorchBackendInterface {
270270
return compute_graph;
271271
}
272272

273-
Error execute(DelegateHandle* handle, EValue** args) const override {
273+
Error execute(
274+
__ET_UNUSED BackendExecutionContext& context,
275+
DelegateHandle* handle,
276+
EValue** args) const override {
274277
EXECUTORCH_SCOPE_PROF("VulkanBackend::execute");
275278

276279
at::native::vulkan::ComputeGraph* compute_graph =

backends/xnnpack/runtime/XNNPACKBackend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ class XnnpackBackend final : public PyTorchBackendInterface {
4949
return executor;
5050
}
5151

52-
Error execute(DelegateHandle* handle, EValue** args) const override {
52+
Error execute(
53+
__ET_UNUSED BackendExecutionContext& context,
54+
DelegateHandle* handle,
55+
EValue** args) const override {
5356
auto executor = static_cast<xnnpack::delegate::XNNExecutor*>(handle);
5457

5558
std::vector<Tensor*> input_pointers;

exir/backend/test/demos/rpc/ExecutorBackend.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,10 @@ class ExecutorBackend final : public PyTorchBackendInterface {
140140
return client_method;
141141
}
142142

143-
Error execute(DelegateHandle* handle, EValue** args) const override {
143+
Error execute(
144+
__ET_UNUSED BackendExecutionContext& context,
145+
DelegateHandle* handle,
146+
EValue** args) const override {
144147
Method* client_method = static_cast<Method*>(handle);
145148
auto num_inputs = client_method->inputs_size();
146149
Error status = Error::Ok;

runtime/backend/backend_registry.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
#include <cstring>
1212

13+
#include <executorch/runtime/backend/backend_execution_context.h>
1314
#include <executorch/runtime/core/array_ref.h>
1415
#include <executorch/runtime/core/error.h>
1516
#include <executorch/runtime/core/evalue.h>
@@ -88,8 +89,10 @@ class PyTorchBackendInterface {
8889
* @param[in] args The method’s inputs and outputs.
8990
* @retval Error::Ok if successful.
9091
*/
91-
__ET_NODISCARD virtual Error execute(DelegateHandle* handle, EValue** args)
92-
const = 0;
92+
__ET_NODISCARD virtual Error execute(
93+
BackendExecutionContext& context,
94+
DelegateHandle* handle,
95+
EValue** args) const = 0;
9396

9497
/**
9598
* Responsible for destroying a handle, if it's required for some backend.

runtime/executor/method.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,11 @@ class BackendDelegate final {
115115
}
116116
}
117117

118-
Error Execute(EValue** args) const {
118+
Error Execute(
119+
BackendExecutionContext& backend_execution_context,
120+
EValue** args) const {
119121
EXECUTORCH_SCOPE_PROF("delegate_execute");
120-
return backend_->execute(handle_, args);
122+
return backend_->execute(backend_execution_context, handle_, args);
121123
}
122124

123125
private:
@@ -939,7 +941,9 @@ Error Method::execute_instruction() {
939941
delegate_idx,
940942
n_delegate_,
941943
step_state_.instr_idx);
944+
BackendExecutionContext backend_execution_context;
942945
Error err = delegates_[delegate_idx].Execute(
946+
backend_execution_context,
943947
chain.argument_lists_[step_state_.instr_idx].data());
944948
ET_CHECK_MSG(
945949
err == Error::Ok,

runtime/executor/test/backend_integration_test.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
using namespace ::testing;
3131
using exec_aten::ArrayRef;
32+
using torch::executor::BackendExecutionContext;
3233
using torch::executor::CompileSpec;
3334
using torch::executor::DataLoader;
3435
using torch::executor::DelegateHandle;
@@ -91,7 +92,10 @@ class StubBackend final : public PyTorchBackendInterface {
9192
execute_fn_ = fn;
9293
}
9394

94-
Error execute(DelegateHandle* handle, EValue** args) const override {
95+
Error execute(
96+
__ET_UNUSED BackendExecutionContext& context,
97+
DelegateHandle* handle,
98+
EValue** args) const override {
9599
if (execute_fn_) {
96100
return execute_fn_.value()(handle, args);
97101
}

runtime/executor/test/test_backend_compiler_lib.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ class BackendWithCompiler final : public PyTorchBackendInterface {
136136
// execute and it only supports add, subtract, and constant. In a non toy
137137
// backend you can imagine how this function could be used to actually
138138
// dispatch the inputs to the relevant backend/device.
139-
Error execute(DelegateHandle* handle, EValue** args) const override {
139+
Error execute(
140+
__ET_UNUSED BackendExecutionContext& context,
141+
DelegateHandle* handle,
142+
EValue** args) const override {
140143
EXECUTORCH_SCOPE_PROF("BackendWithCompiler::execute");
141144

142145
// example: [('prim::Constant#1', 14), ('aten::add', 15)]

runtime/executor/test/test_backend_with_delegate_mapping.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ class BackendWithDelegateMapping final : public PyTorchBackendInterface {
110110

111111
// This function doesn't actually execute the op but just prints out the op
112112
// name and the corresponding delegate debug identifier.
113-
Error execute(DelegateHandle* handle, EValue** args) const override {
113+
Error execute(
114+
__ET_UNUSED BackendExecutionContext& context,
115+
DelegateHandle* handle,
116+
EValue** args) const override {
114117
(void)args;
115118
// example: [('prim::Constant#1', 14), ('aten::add', 15)]
116119
auto op_list = static_cast<const DemoOpList*>(handle);

0 commit comments

Comments
 (0)