@@ -45,28 +45,28 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) {
     unsigned int maxBatchSize{1};
     int memory_limit = 1U << 30; // 1G
 
-    auto builder = TrtUniquePtr<IBuilder>(nvinfer1::createInferBuilder(gLogger));
+    auto builder = TrtUnqPtr<IBuilder>(nvinfer1::createInferBuilder(gLogger));
     if (!builder) {
         cout << "create builder failed\n";
         std::abort();
     }
 
     const auto explicitBatch = 1U << static_cast<uint32_t>(
         nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
-    auto network = TrtUniquePtr<INetworkDefinition>(
+    auto network = TrtUnqPtr<INetworkDefinition>(
         builder->createNetworkV2(explicitBatch));
     if (!network) {
         cout << "create network failed\n";
         std::abort();
     }
 
-    auto config = TrtUniquePtr<IBuilderConfig>(builder->createBuilderConfig());
+    auto config = TrtUnqPtr<IBuilderConfig>(builder->createBuilderConfig());
     if (!config) {
         cout << "create builder config failed\n";
         std::abort();
     }
 
-    auto parser = TrtUniquePtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger));
+    auto parser = TrtUnqPtr<nvonnxparser::IParser>(nvonnxparser::createParser(*network, gLogger));
     if (!parser) {
         cout << "create parser failed\n";
         std::abort();
@@ -84,25 +84,49 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) {
     if (use_fp16 && builder->platformHasFastFp16()) {
         config->setFlag(nvinfer1::BuilderFlag::kFP16); // fp16
     }
-    // TODO: see if use dla or int8
 
     auto output = network->getOutput(0);
     output->setType(nvinfer1::DataType::kINT32);
 
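+    // TensorRT 8 deprecates buildEngineWithConfig(): build a serialized plan
+    // first, then deserialize it into an engine with a runtime (below)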
+    cout << "start to build\n";
+    CudaStreamUnqPtr stream(new cudaStream_t);
+    if (cudaStreamCreate(stream.get())) {
+        cout << "create stream failed\n";
+        std::abort();
+    }
+    config->setProfileStream(*stream);
+
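+    // buildSerializedNetwork() returns the optimized engine as an IHostMemory blob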
+    auto plan = TrtUnqPtr<IHostMemory>(builder->buildSerializedNetwork(*network, *config));
+    if (!plan) {
+        cout << "serialization failed\n";
+        std::abort();
+    }
+
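+    // the runtime turns the serialized plan back into an executable ICudaEngine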
+    auto runtime = TrtUnqPtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
+    if (!runtime) {
+        cout << "create runtime failed\n";
+        std::abort();
+    }
+
     TrtSharedEnginePtr engine = shared_engine_ptr(
-        builder->buildEngineWithConfig(*network, *config));
+        runtime->deserializeCudaEngine(plan->data(), plan->size()));
     if (!engine) {
         cout << "create engine failed\n";
         std::abort();
     }
+    cout << "done build engine\n";
 
     return engine;
 }
 
 
 void serialize(TrtSharedEnginePtr engine, string save_path) {
 
-    auto trt_stream = TrtUniquePtr<IHostMemory>(engine->serialize());
+    auto trt_stream = TrtUnqPtr<IHostMemory>(engine->serialize());
     if (!trt_stream) {
         cout << "serialize engine failed\n";
         std::abort();
@@ -132,7 +152,7 @@ TrtSharedEnginePtr deserialize(string serpth) {
     ifile.close();
     cout << "model size: " << mdsize << endl;
 
-    auto runtime = TrtUniquePtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
+    auto runtime = TrtUnqPtr<IRuntime>(nvinfer1::createInferRuntime(gLogger));
     TrtSharedEnginePtr engine = shared_engine_ptr(
         runtime->deserializeCudaEngine((void*)&buf[0], mdsize, nullptr));
     return engine;
@@ -149,7 +169,7 @@ vector<int> infer_with_engine(TrtSharedEnginePtr engine, vector<float>& data) {
     vector<void*> buffs(2);
     vector<int> res(out_size);
 
-    auto context = TrtUniquePtr<IExecutionContext>(engine->createExecutionContext());
+    auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
     if (!context) {
         cout << "create execution context failed\n";
         std::abort();
@@ -166,34 +186,36 @@ vector<int> infer_with_engine(TrtSharedEnginePtr engine, vector<float>& data) {
         cout << "allocate memory failed\n";
         std::abort();
     }
-    cudaStream_t stream;
-    state = cudaStreamCreate(&stream);
-    if (state) {
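+    // RAII stream handle; CudaStreamUnqPtr's deleter is assumed to call cudaStreamDestroy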
+    CudaStreamUnqPtr stream(new cudaStream_t);
+    if (cudaStreamCreate(stream.get())) {
         cout << "create stream failed\n";
         std::abort();
     }
 
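+    // the H2D copy, inference, and D2H copy share one stream, so they execute
+    // in issue order; cudaStreamSynchronize below waits for the last copy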
     state = cudaMemcpyAsync(
         buffs[0], &data[0], in_size * sizeof(float),
-        cudaMemcpyHostToDevice, stream);
+        cudaMemcpyHostToDevice, *stream);
     if (state) {
         cout << "transmit to device failed\n";
         std::abort();
     }
-    context->enqueueV2(&buffs[0], stream, nullptr);
+    context->enqueueV2(&buffs[0], *stream, nullptr);
     // context->enqueue(1, &buffs[0], stream, nullptr);
     state = cudaMemcpyAsync(
         &res[0], buffs[1], out_size * sizeof(int),
-        cudaMemcpyDeviceToHost, stream);
+        cudaMemcpyDeviceToHost, *stream);
     if (state) {
         cout << "transmit to host failed\n";
         std::abort();
     }
-    cudaStreamSynchronize(stream);
+    cudaStreamSynchronize(*stream);
 
     cudaFree(buffs[0]);
     cudaFree(buffs[1]);
-    cudaStreamDestroy(stream);
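+    // no manual cudaStreamDestroy: the stream's unique_ptr deleter releases it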
 
     return res;
 }
@@ -210,7 +228,7 @@ void test_fps_with_engine(TrtSharedEnginePtr engine) {
     const int in_size{batchsize * 3 * iH * iW};
     const int out_size{batchsize * oH * oW};
 
-    auto context = TrtUniquePtr<IExecutionContext>(engine->createExecutionContext());
+    auto context = TrtUnqPtr<IExecutionContext>(engine->createExecutionContext());
     if (!context) {
         cout << "create execution context failed\n";
         std::abort();
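
Note: the TrtUnqPtr, CudaStreamUnqPtr, and TrtSharedEnginePtr aliases and the
shared_engine_ptr helper used throughout this diff are defined elsewhere in the
repo and do not appear in these hunks. A minimal sketch of plausible
definitions, assuming TensorRT 8 semantics (the TrtDeleter name and the exact
signatures are assumptions, not the repo's actual code):

#include <memory>
#include <cuda_runtime.h>
#include <NvInfer.h>

// generic deleter for TensorRT objects; since TensorRT 8, plain delete
// replaces the deprecated obj->destroy()
struct TrtDeleter {
    template <typename T>
    void operator()(T* obj) const { delete obj; }
};

template <typename T>
using TrtUnqPtr = std::unique_ptr<T, TrtDeleter>;

// destroys the CUDA stream before freeing the heap-allocated handle
struct CudaStreamDeleter {
    void operator()(cudaStream_t* p) const {
        if (p) {
            cudaStreamDestroy(*p);
            delete p;
        }
    }
};
using CudaStreamUnqPtr = std::unique_ptr<cudaStream_t, CudaStreamDeleter>;

using TrtSharedEnginePtr = std::shared_ptr<nvinfer1::ICudaEngine>;

inline TrtSharedEnginePtr shared_engine_ptr(nvinfer1::ICudaEngine* eng) {
    return TrtSharedEnginePtr(eng, TrtDeleter{});
}

Wrapping cudaStream_t this way is what lets the commit drop the explicit
cudaStreamDestroy call at the end of infer_with_engine.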