
Commit ff2f6c8

Merge pull request #178 from NVIDIA/public_op_registration
refactor(//core/execution): Move to public APIs for operator registration
2 parents 1f5c702 + 3eb1c63 commit ff2f6c8

2 files changed: 12 additions & 24 deletions

core/compiler.cpp

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
     execute_node_inputs.push_back(engine_node->outputs()[0]);
 
     // Create the actual execution node trt::execute_engine using the assembled inputs
-    auto execute_node = g->create(c10::Symbol::fromQualString("trt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
+    auto execute_node = g->create(c10::Symbol::fromQualString("tensorrt::execute_engine"), torch::jit::ArrayRef<torch::jit::Value*>(execute_node_inputs), 1);
     g->block()->appendNode(execute_node);
     execute_node->outputs()[0]->setType(c10::ListType::ofTensors());
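
The only functional change here is the namespace portion of the qualified operator name: the emitted graph node must now resolve to an operator registered under the tensorrt library namespace (see the next file). A minimal sketch of the same node-construction pattern, with a hypothetical helper name and a header path that may vary across PyTorch versions:

#include <torch/csrc/jit/ir/ir.h>  // Graph, Node, Value (path may differ by PyTorch version)

// Hypothetical helper, not part of the commit: appends a call node for the
// registered op. The namespace in the qualified name ("tensorrt") must match
// TORCH_LIBRARY(tensorrt, m), and the op name ("execute_engine") must match
// m.def("execute_engine", ...); otherwise the JIT cannot resolve this node
// to a registered operator at runtime.
void AppendExecuteEngineNode(std::shared_ptr<torch::jit::Graph> g,
                             std::vector<torch::jit::Value*> node_inputs) {
  auto sym = c10::Symbol::fromQualString("tensorrt::execute_engine");
  auto node = g->create(sym, torch::jit::ArrayRef<torch::jit::Value*>(node_inputs), /*num_outputs=*/1);
  g->block()->appendNode(node);
  // The op returns Tensor[], so type the single output accordingly.
  node->outputs()[0]->setType(c10::ListType::ofTensors());
}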

core/execution/register_trt_op.cpp

Lines changed: 11 additions & 23 deletions
@@ -46,33 +46,21 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
     return outputs;
 }
 
-namespace {
-c10::AliasAnalysisKind aliasAnalysisFromSchema() {
-    return c10::AliasAnalysisKind::FROM_SCHEMA;
-}
-
-// Switched to a global operator because op implementations need to be non-capturing lambdas in PYT 1.5.0+
-torch::jit::RegisterOperators jit_registry({
-    torch::jit::Operator(
-        "trt::execute_engine(Tensor[] inputs, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]",
-        [](torch::jit::Stack& stack) -> int {
-            // Verify calling convention (right to left or left to right)
-            auto engine = torch::jit::pop(stack).toCustomClass<TRTEngine>();
-            LOG_DEBUG("Attempting to run engine (ID: " << std::hex << engine->name << ")");
+std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> engine) {
+    // Verify calling convention (right to left or left to right)
+    LOG_DEBUG("Attempting to run engine (ID: " << std::hex << engine->name << ")");
 
-            auto inputs = torch::jit::pop(stack).toTensorVector();
+    auto io = engine->num_io;
+    auto ctx = engine->exec_ctx;
+    auto outputs = RunCudaEngine(ctx, io, inputs);
 
-            auto io = engine->num_io;
+    return outputs;
+}
 
-            auto ctx = engine->exec_ctx;
-            auto outputs = RunCudaEngine(ctx, io, inputs);
-            torch::jit::push(stack, std::move(outputs));
-            return 0;
-        },
-        aliasAnalysisFromSchema())
-});
+TORCH_LIBRARY(tensorrt, m) {
+    m.def("execute_engine", execute_engine);
+}
 
-} // namespace
 } // namespace execution
 } // namespace core
 } // namespace trtorch
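
This is the substance of the refactor: the torch::jit::RegisterOperators path, which required a hand-written schema string and manual Stack push/pop inside a non-capturing lambda, is replaced by the public TORCH_LIBRARY registration API, where the op is an ordinary function and the schema is inferred from its C++ signature. A minimal, self-contained sketch of that pattern with placeholder names (demo::scale is illustrative only, not part of this commit):

#include <ATen/ATen.h>      // at::Tensor
#include <torch/library.h>  // TORCH_LIBRARY
#include <vector>

// With the public API the op body is a plain function; the schema
// "demo::scale(Tensor[] inputs, float factor) -> Tensor[]" is inferred from
// the C++ signature, so there is no manual Stack handling as in the old form.
std::vector<at::Tensor> scale(std::vector<at::Tensor> inputs, double factor) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(inputs.size());
  for (const auto& t : inputs) {
    outputs.push_back(t * factor);
  }
  return outputs;
}

// Registers the op under the "demo" namespace at static-initialization time.
TORCH_LIBRARY(demo, m) {
  m.def("scale", scale);
}

Once registered this way, the op appears in TorchScript graphs as a node with the qualified symbol demo::scale, which is exactly why compiler.cpp above switches the node it emits to the tensorrt:: namespace used in TORCH_LIBRARY(tensorrt, m).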
