@@ -11,8 +11,6 @@ namespace torch_tensorrt {
11
11
namespace core {
12
12
namespace runtime {
13
13
14
- typedef enum { ABI_TARGET_IDX = 0 , NAME_IDX, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
15
-
16
14
std::string slugify (std::string s) {
17
15
std::replace (s.begin (), s.end (), ' .' , ' _' );
18
16
return s;
@@ -35,7 +33,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info) {
35
33
std::string _name = serialized_info[NAME_IDX];
36
34
std::string engine_info = serialized_info[ENGINE_IDX];
37
35
38
- CudaDevice cuda_device = deserialize_device (serialized_info[DEVICE_IDX]);
36
+ CudaDevice cuda_device (serialized_info[DEVICE_IDX]);
39
37
new (this ) TRTEngine (_name, engine_info, cuda_device);
40
38
}
41
39
@@ -124,43 +122,6 @@ std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
124
122
return os;
125
123
}
126
124
127
- // TODO: Implement a call method
128
- // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
129
- // auto input_vec = inputs.vec();
130
- // auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
131
- //
132
- // return c10::List<at::Tensor>(output_vec);
133
- // }
134
-
135
- namespace {
136
- static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
137
- torch::class_<TRTEngine>(" tensorrt" , " Engine" )
138
- .def(torch::init<std::vector<std::string>>())
139
- // TODO: .def("__call__", &TRTEngine::Run)
140
- // TODO: .def("run", &TRTEngine::Run)
141
- .def(" __str__" , &TRTEngine::to_str)
142
- .def_pickle(
143
- [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
144
- // Serialize TensorRT engine
145
- auto serialized_trt_engine = self->cuda_engine ->serialize ();
146
-
147
- // Adding device info related meta data to the serialized file
148
- auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
149
-
150
- std::vector<std::string> serialize_info;
151
- serialize_info.resize (ENGINE_IDX + 1 );
152
-
153
- serialize_info[ABI_TARGET_IDX] = ABI_VERSION;
154
- serialize_info[NAME_IDX] = self->name ;
155
- serialize_info[DEVICE_IDX] = serialize_device (self->device_info );
156
- serialize_info[ENGINE_IDX] = trt_engine;
157
- return serialize_info;
158
- },
159
- [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
160
- return c10::make_intrusive<TRTEngine>(std::move (seralized_info));
161
- });
162
- } // namespace
163
-
164
125
} // namespace runtime
165
126
} // namespace core
166
127
} // namespace torch_tensorrt
0 commit comments