@@ -16,73 +16,62 @@ using ModelInfo = std::pair<std::string, int64_t>;
 // based on https://github.com/NVIDIA/tensorrt-inference-server/blob/master/src/clients/c++/examples/simple_callback_client.cc
 
 template <typename Client>
-TRTClient<Client>::TRTClient(const edm::ParameterSet &params) : Client(),
-  url_(params.getParameter<std::string>("address") + ":" + std::to_string(params.getParameter<unsigned>("port"))),
-  timeout_(params.getParameter<unsigned>("timeout")),
-  modelName_(params.getParameter<std::string>("modelName")),
-  batchSize_(params.getParameter<unsigned>("batchSize")),
-  ninput_(params.getParameter<unsigned>("ninput")),
-  noutput_(params.getParameter<unsigned>("noutput"))
+TRTClient<Client>::TRTClient(const edm::ParameterSet& params) :
+  Client(),
+  url_(params.getParameter<std::string>("address")+":"+std::to_string(params.getParameter<unsigned>("port"))),
+  timeout_(params.getParameter<unsigned>("timeout")),
+  modelName_(params.getParameter<std::string>("modelName")),
+  batchSize_(params.getParameter<unsigned>("batchSize")),
+  ninput_(params.getParameter<unsigned>("ninput")),
+  noutput_(params.getParameter<unsigned>("noutput"))
 {
 }
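Note: the parameters read in this constructor would typically also be declared for validation in the owning module. A minimal sketch of matching declarations, assuming a standard CMSSW fillDescriptions hook (the producer that wraps this client is not part of this diff):

// sketch only: hypothetical parameter declarations mirroring the
// getParameter calls above; the owning producer is not shown here
edm::ParameterSetDescription desc;
desc.add<std::string>("address");
desc.add<unsigned>("port");
desc.add<unsigned>("timeout");
desc.add<std::string>("modelName");
desc.add<unsigned>("batchSize");
desc.add<unsigned>("ninput");
desc.add<unsigned>("noutput");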
 
 template <typename Client>
-void TRTClient<Client>::setup()
-{
+void TRTClient<Client>::setup() {
   auto err = nic::InferGrpcContext::Create(&context_, url_, modelName_, -1, false);
-  if (!err.IsOk())
-    throw cms::Exception("BadGrpc") << "unable to create inference context: " << err;
-
-  // nic::ServerStatusGrpcContext::Create(&server_ctx_, url_, false);
-  // if (!err.IsOk())
-  //   throw cms::Exception("BadServer") << "unable to create server inference context: " << err;
+  if (!err.IsOk()) throw cms::Exception("BadGrpc") << "unable to create inference context: " << err;
 
   std::unique_ptr<nic::InferContext::Options> options;
   nic::InferContext::Options::Create(&options);
 
   options->SetBatchSize(batchSize_);
-  for (const auto &output : context_->Outputs())
-  {
+  for (const auto& output : context_->Outputs()) {
     options->AddRawResult(output);
   }
   context_->SetRunOptions(*options);
 
-  const std::vector<std::shared_ptr<nic::InferContext::Input>> &nicinputs = context_->Inputs();
+  const std::vector<std::shared_ptr<nic::InferContext::Input>>& nicinputs = context_->Inputs();
   nicinput_ = nicinputs[0];
   nicinput_->Reset();
 
   auto t2 = std::chrono::high_resolution_clock::now();
   std::vector<int64_t> input_shape;
-  for (unsigned i0 = 0; i0 < batchSize_; i0++)
-  {
+  for (unsigned i0 = 0; i0 < batchSize_; i0++) {
     float *arr = &(this->input_.data()[i0 * ninput_]);
     nic::Error err1 = nicinput_->SetRaw(reinterpret_cast<const uint8_t *>(arr), ninput_ * sizeof(float));
   }
   auto t3 = std::chrono::high_resolution_clock::now();
-  edm::LogInfo("TRTClient") << "Image array time: " << std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count();
+  edm::LogInfo("TRTClient") << "Image array time: " << std::chrono::duration_cast<std::chrono::microseconds>(t3-t2).count();
 }
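Note: the err1 status returned by SetRaw is assigned but never checked in setup(). A minimal sketch of handling it, mirroring the BadGrpc check above ("BadInput" is a hypothetical exception category, not part of this diff):

// sketch only: check the SetRaw status instead of discarding it
nic::Error err1 = nicinput_->SetRaw(reinterpret_cast<const uint8_t *>(arr), ninput_ * sizeof(float));
if (!err1.IsOk())
  throw cms::Exception("BadInput") << "unable to set input for batch entry " << i0 << ": " << err1;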
 
 template <typename Client>
-void TRTClient<Client>::getResults(const std::unique_ptr<nic::InferContext::Result> &result)
-{
+void TRTClient<Client>::getResults(const std::unique_ptr<nic::InferContext::Result>& result) {
   auto t2 = std::chrono::high_resolution_clock::now();
-  this->output_.resize(noutput_ * batchSize_, 0.f);
-  for (unsigned i0 = 0; i0 < batchSize_; i0++)
-  {
+  this->output_.resize(noutput_*batchSize_, 0.f);
+  for (unsigned i0 = 0; i0 < batchSize_; i0++) {
     const uint8_t *r0;
     size_t content_byte_size;
     result->GetRaw(i0, &r0, &content_byte_size);
-    const float *lVal = reinterpret_cast<const float *>(r0);
-    for (unsigned i1 = 0; i1 < noutput_; i1++)
-      this->output_[i0 * noutput_ + i1] = lVal[i1]; // This should be replaced with a memcpy
+    const float *lVal = reinterpret_cast<const float *>(r0);
+    for (unsigned i1 = 0; i1 < noutput_; i1++) this->output_[i0*noutput_+i1] = lVal[i1]; // This should be replaced with a memcpy
   }
   auto t3 = std::chrono::high_resolution_clock::now();
-  edm::LogInfo("TRTClient") << "Output time: " << std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count();
+  edm::LogInfo("TRTClient") << "Output time: " << std::chrono::duration_cast<std::chrono::microseconds>(t3-t2).count();
 }
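As the in-line comment says, the element-wise copy could become a single memcpy per batch entry. A minimal sketch of the loop body, assuming the server returns exactly noutput_ floats per entry (needs #include <cstring>):

// sketch only: bulk copy replacing the per-element loop above,
// assuming content_byte_size == noutput_ * sizeof(float)
result->GetRaw(i0, &r0, &content_byte_size);
std::memcpy(&this->output_[i0 * noutput_], r0, noutput_ * sizeof(float));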
 
 template <typename Client>
-void TRTClient<Client>::predictImpl()
-{
+void TRTClient<Client>::predictImpl() {
   // common operations first
   setup();
 
@@ -91,7 +80,7 @@ void TRTClient<Client>::predictImpl()
   std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
   nic::Error err = context_->Run(&results);
   if (!err.IsOk()) {
-    std::cout << "Could not read the result" << ": " << err << std::endl;
+    edm::LogWarning("TRTClient") << "Could not read the result" << ": " << err;
     this->output_.resize(noutput_ * batchSize_, 0.f);
   } else {
     auto t3 = std::chrono::high_resolution_clock::now();
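Note: when the output result is consumed, looking it up by output name is more explicit than relying on map order. A sketch, assuming a single output was registered in setup() and that the tensorrt-inference-server client's Output::Name() accessor is available:

// sketch only: fetch the result by output name rather than map order
const auto& oname = context_->Outputs()[0]->Name();
auto rit = results.find(oname);
if (rit != results.end()) this->getResults(rit->second);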
@@ -102,28 +91,21 @@ void TRTClient<Client>::predictImpl()
 
 // specialization for true async
 template <>
-void TRTClientAsync::predictImpl()
-{
+void TRTClientAsync::predictImpl() {
   // common operations first
-  try
-  {
+  try {
     setup();
   }
-  catch (...)
-  {
+  catch (...) {
     finish(std::current_exception());
     return;
   }
 
   // non-blocking call
 
-  // Get the status of the server prior to the request being made.
-  // std::map<std::string, ni::ModelStatus> start_status;
-  // GetServerSideStatus(&start_status);
-
   auto t2 = std::chrono::high_resolution_clock::now();
   nic::Error erro0 = context_->AsyncRun(
-    [t2, this](nic::InferContext *ctx, const std::shared_ptr<nic::InferContext::Request> &request) {
+    [t2,this](nic::InferContext* ctx, const std::shared_ptr<nic::InferContext::Request>& request) {
       // get results
       std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
       // this function interface will change in the next tensorrtis version
@@ -142,16 +124,17 @@ void TRTClientAsync::predictImpl()
 
       edm::LogInfo("TRTClient") << "Remote time: " << std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count();
 
-      // check result
-      this->getResults(results.begin()->second);
-
       // ServerSideStats stats;
       // SummarizeServerStats(std::make_pair(modelName_, -1), start_status, end_status, &stats);
       // ReportServerSideState(stats);
+
+      // check result
+      this->getResults(results.begin()->second);
       }
       // finish
       this->finish();
-    });
+    }
+  );
 }
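Note: the AsyncRun status erro0 is not checked anywhere in the code shown. A minimal sketch of reporting a failed submission, mirroring the synchronous Run() handling above (since the callback may never fire if submission fails):

// sketch only: surface a failed async submission the same way as a
// failed synchronous Run()
if (!erro0.IsOk()) {
  edm::LogWarning("TRTClient") << "Could not submit async request: " << erro0;
  this->finish();
}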
 
 template <typename Client>
@@ -162,7 +145,7 @@ TRTClient<Client>::ReportServerSideState(const ServerSideStats& stats)
   const uint64_t cnt = stats.request_count;
   if (cnt == 0)
   {
-    std::cout << "Request count: " << cnt << std::endl;
+    edm::LogInfo("TRTClient") << "Request count: " << cnt;
     return;
   }
 
@@ -178,13 +161,12 @@ TRTClient<Client>::ReportServerSideState(const ServerSideStats& stats)
   const uint64_t overhead = (cumm_avg_us > queue_avg_us + compute_avg_us)
                               ? (cumm_avg_us - queue_avg_us - compute_avg_us)
                               : 0;
-  std::cout << "Request count: " << cnt << std::endl
+  edm::LogInfo("TRTClient") << "Request count: " << cnt << std::endl
             << "Avg request latency: " << cumm_avg_us << " usec";
 
-  std::cout << "(overhead " << overhead << " usec + "
+  edm::LogInfo("TRTClient") << "(overhead " << overhead << " usec + "
             << "queue " << queue_avg_us << " usec + "
-            << "compute " << compute_avg_us << " usec)" << std::endl
-            << std::endl;
+            << "compute " << compute_avg_us << " usec)" << std::endl;
 }
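Note: each edm::LogInfo statement emits its own log record, so this summary is split across two records. A sketch of a single-record variant, using "\n" rather than std::endl (which carries stream-flush semantics):

// sketch only: one LogInfo call keeps the whole summary in a single record
edm::LogInfo("TRTClient") << "Request count: " << cnt
                          << "\nAvg request latency: " << cumm_avg_us << " usec"
                          << "\n(overhead " << overhead << " usec + "
                          << "queue " << queue_avg_us << " usec + "
                          << "compute " << compute_avg_us << " usec)";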
 
 template <typename Client>
@@ -199,24 +181,6 @@ TRTClient<Client>::SummarizeServerStats(
     model_info.first, model_info.second,
     start_status.find(model_info.first)->second,
     end_status.find(model_info.first)->second, server_stats);
-
-  // // Summarize the composing models, if any.
-  // for (const auto& composing_model_info : composing_models_map_[model_info]) {
-  //   auto it = server_stats->composing_models_stat
-  //                 .emplace(composing_model_info, ServerSideStats())
-  //                 .first;
-  //   if (composing_models_map_.find(composing_model_info) !=
-  //       composing_models_map_.end()) {
-  //     RETURN_IF_ERROR(SummarizeServerStats(
-  //         composing_model_info, start_status, end_status, &(it->second)));
-  //   } else {
-  //     RETURN_IF_ERROR(SummarizeServerModelStats(
-  //         composing_model_info.first, composing_model_info.second,
-  //         start_status.find(composing_model_info.first)->second,
-  //         end_status.find(composing_model_info.first)->second, &(it->second)));
-  //   }
-
-  // return nic::Error::Success;
 }
 
 template <typename Client>
@@ -302,25 +266,6 @@ TRTClient<Client>::GetServerSideStatus(
   } else {
     model_status->emplace(model_info.first, itr->second);
   }
-
-  // // Also get status for composing models if any
-  // for (const auto& composing_model_info : composing_models_map_[model_info]) {
-  //   if (composing_models_map_.find(composing_model_info) !=
-  //       composing_models_map_.end()) {
-  //     GetServerSideStatus(
-  //         server_status, composing_model_info, model_status);
-  //   } else {
-  //     const auto& itr =
-  //         server_status.model_status().find(composing_model_info.first);
-  //     if (itr == server_status.model_status().end()) {
-  //       return nic::Error(
-  //           ni::RequestStatusCode::INTERNAL,
-  //           "unable to find status for composing model" +
-  //               composing_model_info.first);
-  //     } else {
-  //       model_status->emplace(composing_model_info.first, itr->second);
-  //     }
-  //   }
 }
 
 // explicit template instantiations