Skip to content

Commit fe08950

Browse files
authored
Merge pull request #1036 from NVIDIA/restructure_runtime_registration
refactor(//core/runtime): Moving dependent static initialization into
2 parents d63a483 + 5d0a605 commit fe08950

File tree

6 files changed: +53 additions, −56 deletions

core/runtime/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ cc_library(
1212
srcs = [
1313
"CudaDevice.cpp",
1414
"DeviceList.cpp",
15+
"execute_engine.cpp",
1516
"TRTEngine.cpp",
16-
"register_trt_op.cpp",
17+
"register_jit_hooks.cpp",
1718
"runtime.cpp"
1819
],
1920
hdrs = [

core/runtime/TRTEngine.cpp

Lines changed: 1 addition & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ namespace torch_tensorrt {
1111
namespace core {
1212
namespace runtime {
1313

14-
typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15-
1614
std::string slugify(std::string s) {
1715
std::replace(s.begin(), s.end(), '.', '_');
1816
return s;
@@ -35,7 +33,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) {
3533
std::string _name = serialized_info[NAME_IDX];
3634
std::string engine_info = serialized_info[ENGINE_IDX];
3735

38-
CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]);
36+
CudaDevice cuda_device(serialized_info[DEVICE_IDX]);
3937
new (this) TRTEngine(_name, engine_info, cuda_device);
4038
}
4139

@@ -124,43 +122,6 @@ std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
124122
return os;
125123
}
126124

127-
// TODO: Implement a call method
128-
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
129-
// auto input_vec = inputs.vec();
130-
// auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
131-
//
132-
// return c10::List<at::Tensor>(output_vec);
133-
// }
134-
135-
namespace {
136-
static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
137-
torch::class_<TRTEngine>("tensorrt", "Engine")
138-
.def(torch::init<std::vector<std::string>>())
139-
// TODO: .def("__call__", &TRTEngine::Run)
140-
// TODO: .def("run", &TRTEngine::Run)
141-
.def("__str__", &TRTEngine::to_str)
142-
.def_pickle(
143-
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
144-
// Serialize TensorRT engine
145-
auto serialized_trt_engine = self->cuda_engine->serialize();
146-
147-
// Adding device info related meta data to the serialized file
148-
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
149-
150-
std::vector<std::string> serialize_info;
151-
serialize_info.resize(ENGINE_IDX + 1);
152-
153-
serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
154-
serialize_info[NAME_IDX] = self->name;
155-
serialize_info[DEVICE_IDX] = serialize_device(self->device_info);
156-
serialize_info[ENGINE_IDX] = trt_engine;
157-
return serialize_info;
158-
},
159-
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
160-
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
161-
});
162-
} // namespace
163-
164125
} // namespace runtime
165126
} // namespace core
166127
} // namespace torch_tensorrt

core/runtime/register_trt_op.cpp renamed to core/runtime/execute_engine.cpp

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -120,10 +120,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
120120
return outputs;
121121
}
122122

123-
TORCH_LIBRARY(tensorrt, m) {
124-
m.def("execute_engine", execute_engine);
125-
}
126-
127123
} // namespace runtime
128124
} // namespace core
129125
} // namespace torch_tensorrt

core/runtime/register_jit_hooks.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "core/runtime/runtime.h"
2+
3+
namespace torch_tensorrt {
4+
namespace core {
5+
namespace runtime {
6+
namespace {
7+
8+
// TODO: Implement a call method
9+
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
10+
// auto input_vec = inputs.vec();
11+
// auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
12+
//
13+
// return c10::List<at::Tensor>(output_vec);
14+
// }
15+
static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
16+
torch::class_<TRTEngine>("tensorrt", "Engine")
17+
.def(torch::init<std::vector<std::string>>())
18+
// TODO: .def("__call__", &TRTEngine::Run)
19+
// TODO: .def("run", &TRTEngine::Run)
20+
.def("__str__", &TRTEngine::to_str)
21+
.def_pickle(
22+
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
23+
// Serialize TensorRT engine
24+
auto serialized_trt_engine = self->cuda_engine->serialize();
25+
26+
// Adding device info related meta data to the serialized file
27+
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
28+
29+
std::vector<std::string> serialize_info;
30+
serialize_info.resize(ENGINE_IDX + 1);
31+
32+
serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
33+
serialize_info[NAME_IDX] = self->name;
34+
serialize_info[DEVICE_IDX] = self->device_info.serialize();
35+
serialize_info[ENGINE_IDX] = trt_engine;
36+
return serialize_info;
37+
},
38+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
39+
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
40+
});
41+
42+
TORCH_LIBRARY(tensorrt, m) {
43+
m.def("execute_engine", execute_engine);
44+
}
45+
46+
} // namespace
47+
} // namespace runtime
48+
} // namespace core
49+
} // namespace torch_tensorrt

core/runtime/runtime.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,6 @@ CudaDevice get_current_device() {
8686
return CudaDevice(device_id, nvinfer1::DeviceType::kGPU);
8787
}
8888

89-
std::string serialize_device(CudaDevice& cuda_device) {
90-
return cuda_device.serialize();
91-
}
92-
93-
CudaDevice deserialize_device(std::string device_info) {
94-
return CudaDevice(device_info);
95-
}
96-
9789
namespace {
9890
static DeviceList cuda_device_list;
9991
}

core/runtime/runtime.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace runtime {
1414

1515
using EngineID = int64_t;
1616
const std::string ABI_VERSION = "3";
17+
typedef enum { ABI_TARGET_IDX = 0, NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
1718

1819
struct CudaDevice {
1920
int64_t id; // CUDA device id
@@ -38,9 +39,6 @@ CudaDevice get_current_device();
3839
c10::optional<CudaDevice> get_most_compatible_device(const CudaDevice& target_device);
3940
std::vector<CudaDevice> find_compatible_devices(const CudaDevice& target_device);
4041

41-
std::string serialize_device(CudaDevice& cuda_device);
42-
CudaDevice deserialize_device(std::string device_info);
43-
4442
struct TRTEngine : torch::CustomClassHolder {
4543
// Each engine needs it's own runtime object
4644
std::shared_ptr<nvinfer1::IRuntime> rt;

Comments (0)