@@ -46,33 +46,21 @@ std::vector<at::Tensor> RunCudaEngine(nvinfer1::IExecutionContext* ctx, std::pai
46
46
return outputs;
47
47
}
48
48
49
- namespace {
50
- c10::AliasAnalysisKind aliasAnalysisFromSchema () {
51
- return c10::AliasAnalysisKind::FROM_SCHEMA;
52
- }
53
-
54
- // Switched to a global operator because op implementations need to be non-capturing lambdas in PYT 1.5.0+
55
- torch::jit::RegisterOperators jit_registry ({
56
- torch::jit::Operator (
57
- " trt::execute_engine(Tensor[] inputs, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]" ,
58
- [](torch::jit::Stack& stack) -> int {
59
- // Verify calling convention (right to left or left to right)
60
- auto engine = torch::jit::pop (stack).toCustomClass <TRTEngine>();
61
- LOG_DEBUG (" Attempting to run engine (ID: " << std::hex << engine->name << " )" );
49
+ std::vector<at::Tensor> execute_engine (std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> engine) {
50
+ // Verify calling convention (right to left or left to right)
51
+ LOG_DEBUG (" Attempting to run engine (ID: " << std::hex << engine->name << " )" );
62
52
63
- auto inputs = torch::jit::pop (stack).toTensorVector ();
53
+ auto io = engine->num_io ;
54
+ auto ctx = engine->exec_ctx ;
55
+ auto outputs = RunCudaEngine (ctx, io, inputs);
64
56
65
- auto io = engine->num_io ;
57
+ return outputs;
58
+ }
66
59
67
- auto ctx = engine->exec_ctx ;
68
- auto outputs = RunCudaEngine (ctx, io, inputs);
69
- torch::jit::push (stack, std::move (outputs));
70
- return 0 ;
71
- },
72
- aliasAnalysisFromSchema ())
73
- });
60
+ TORCH_LIBRARY (tensorrt, m) {
61
+ m.def (" execute_engine" , execute_engine);
62
+ }
74
63
75
- } // namespace
76
64
} // namespace execution
77
65
} // namespace core
78
66
} // namespace trtorch
0 commit comments