Skip to content

Commit dd3c242

Browse files
authored
fix multi-thread exec of trt, test=develop (#19379)
1 parent 9048229 commit dd3c242

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

paddle/fluid/inference/tensorrt/engine.cc

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,15 @@ void TensorRTEngine::Build(const DescType &paddle_model) {
3535
void TensorRTEngine::Execute(int batch_size, std::vector<void *> *buffers,
3636
cudaStream_t stream) {
3737
freshDeviceId();
38+
const std::thread::id tid = std::this_thread::get_id();
3839
batch_size_ = batch_size;
39-
infer_context_->enqueue(batch_size, buffers->data(), stream, nullptr);
40+
if (infer_context_.find(tid) == infer_context_.end()) {
41+
PADDLE_ENFORCE_NOT_NULL(
42+
infer_engine_,
43+
"You should build engine first and then set the context.");
44+
infer_context_[tid].reset(infer_engine_->createExecutionContext());
45+
}
46+
infer_context_[tid]->enqueue(batch_size, buffers->data(), stream, nullptr);
4047
cudaStreamSynchronize(stream);
4148
SetRuntimeBatch(batch_size);
4249
}
@@ -111,8 +118,6 @@ void TensorRTEngine::FreezeNetwork() {
111118

112119
infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_));
113120
PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!");
114-
115-
infer_context_.reset(infer_engine_->createExecutionContext());
116121
}
117122

118123
nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name,

paddle/fluid/inference/tensorrt/engine.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,6 @@ class TensorRTEngine {
128128
&inference::Singleton<plugin::PluginFactoryTensorRT>::Global()));
129129
PADDLE_ENFORCE(infer_engine_ != nullptr,
130130
"build cuda engine failed when deserialize engine info.!");
131-
infer_context_.reset(infer_engine_->createExecutionContext());
132131
}
133132

134133
void SetRuntimeBatch(size_t batch_size);
@@ -200,7 +199,8 @@ class TensorRTEngine {
200199
infer_ptr<nvinfer1::IBuilder> infer_builder_;
201200
infer_ptr<nvinfer1::INetworkDefinition> infer_network_;
202201
infer_ptr<nvinfer1::ICudaEngine> infer_engine_;
203-
infer_ptr<nvinfer1::IExecutionContext> infer_context_;
202+
std::unordered_map<std::thread::id, infer_ptr<nvinfer1::IExecutionContext>>
203+
infer_context_;
204204
infer_ptr<nvinfer1::IHostMemory> ihost_memory_;
205205
std::unordered_map<nvinfer1::ITensor*, float> quant_dynamic_range_;
206206
}; // class TensorRTEngine

0 commit comments

Comments (0)