|
| 1 | +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | +http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/inference/tensorrt/engine.h" |
| 16 | + |
| 17 | +#include <NvInfer.h> |
| 18 | +#include <cuda.h> |
| 19 | +#include <glog/logging.h> |
| 20 | +#include "paddle/fluid/inference/tensorrt/helper.h" |
| 21 | +#include "paddle/fluid/platform/enforce.h" |
| 22 | + |
| 23 | +namespace paddle { |
| 24 | +namespace inference { |
| 25 | +namespace tensorrt { |
| 26 | + |
| 27 | +void TensorRTEngine::Build(const DescType& paddle_model) { |
| 28 | + PADDLE_ENFORCE(false, "not implemented"); |
| 29 | +} |
| 30 | + |
| 31 | +void TensorRTEngine::Execute(int batch_size) { |
| 32 | + infer_context_->enqueue(batch_size, buffers_.data(), *stream_, nullptr); |
| 33 | + cudaStreamSynchronize(*stream_); |
| 34 | +} |
| 35 | + |
| 36 | +TensorRTEngine::~TensorRTEngine() { |
| 37 | + // clean buffer |
| 38 | + for (auto& buffer : buffers_) { |
| 39 | + if (buffer != nullptr) { |
| 40 | + PADDLE_ENFORCE_EQ(0, cudaFree(buffer)); |
| 41 | + buffer = nullptr; |
| 42 | + } |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +void TensorRTEngine::FreezeNetwork() { |
| 47 | + PADDLE_ENFORCE(infer_builder_ != nullptr, |
| 48 | + "Call InitNetwork first to initialize network."); |
| 49 | + PADDLE_ENFORCE(infer_network_ != nullptr, |
| 50 | + "Call InitNetwork first to initialize network."); |
| 51 | + // build engine. |
| 52 | + infer_builder_->setMaxBatchSize(max_batch_); |
| 53 | + infer_builder_->setMaxWorkspaceSize(max_workspace_); |
| 54 | + |
| 55 | + infer_engine_.reset(infer_builder_->buildCudaEngine(*infer_network_)); |
| 56 | + PADDLE_ENFORCE(infer_engine_ != nullptr, "build cuda engine failed!"); |
| 57 | + |
| 58 | + infer_context_.reset(infer_engine_->createExecutionContext()); |
| 59 | + |
| 60 | + // allocate GPU buffers. |
| 61 | + buffers_.resize(buffer_sizes_.size(), nullptr); |
| 62 | + for (auto& item : buffer_sizes_) { |
| 63 | + if (item.second == 0) { |
| 64 | + auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); |
| 65 | + item.second = kDataTypeSize[static_cast<int>( |
| 66 | + infer_engine_->getBindingDataType(slot_offset))] * |
| 67 | + AccumDims(infer_engine_->getBindingDimensions(slot_offset)); |
| 68 | + } |
| 69 | + PADDLE_ENFORCE_EQ(0, cudaMalloc(&buffer(item.first), item.second)); |
| 70 | + } |
| 71 | +} |
| 72 | + |
| 73 | +nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, |
| 74 | + nvinfer1::DataType dtype, |
| 75 | + const nvinfer1::Dims& dim) { |
| 76 | + PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", |
| 77 | + name); |
| 78 | + |
| 79 | + PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first"); |
| 80 | + auto* input = infer_network_->addInput(name.c_str(), dtype, dim); |
| 81 | + PADDLE_ENFORCE(input, "infer network add input %s failed", name); |
| 82 | + |
| 83 | + buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim); |
| 84 | + return input; |
| 85 | +} |
| 86 | + |
| 87 | +void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer* layer, int offset, |
| 88 | + const std::string& name) { |
| 89 | + PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate output name %s", |
| 90 | + name); |
| 91 | + |
| 92 | + auto* output = layer->getOutput(offset); |
| 93 | + PADDLE_ENFORCE(output != nullptr); |
| 94 | + output->setName(name.c_str()); |
| 95 | + infer_network_->markOutput(*output); |
| 96 | + // output buffers' size can only be decided latter, set zero here to mark this |
| 97 | + // and will reset latter. |
| 98 | + buffer_sizes_[name] = 0; |
| 99 | +} |
| 100 | + |
| 101 | +void* TensorRTEngine::GetOutputInGPU(const std::string& name) { |
| 102 | + return buffer(name); |
| 103 | +} |
| 104 | + |
| 105 | +void TensorRTEngine::GetOutputInCPU(const std::string& name, void* dst, |
| 106 | + size_t max_size) { |
| 107 | + // determine data size |
| 108 | + auto it = buffer_sizes_.find(name); |
| 109 | + PADDLE_ENFORCE(it != buffer_sizes_.end()); |
| 110 | + PADDLE_ENFORCE_GT(it->second, 0); |
| 111 | + PADDLE_ENFORCE_GE(max_size, it->second); |
| 112 | + |
| 113 | + PADDLE_ENFORCE_EQ(0, cudaMemcpyAsync(dst, buffer(name), it->second, |
| 114 | + cudaMemcpyDeviceToHost, *stream_)); |
| 115 | +} |
| 116 | + |
| 117 | +void*& TensorRTEngine::buffer(const std::string& name) { |
| 118 | + PADDLE_ENFORCE(infer_engine_ != nullptr, "call FreezeNetwork first."); |
| 119 | + auto it = buffer_sizes_.find(name); |
| 120 | + PADDLE_ENFORCE(it != buffer_sizes_.end()); |
| 121 | + auto slot_offset = infer_engine_->getBindingIndex(name.c_str()); |
| 122 | + return buffers_[slot_offset]; |
| 123 | +} |
| 124 | + |
| 125 | +void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data, |
| 126 | + size_t size) { |
| 127 | + void* buf = buffer(name); |
| 128 | + PADDLE_ENFORCE_EQ( |
| 129 | + 0, cudaMemcpyAsync(buf, data, size, cudaMemcpyHostToDevice, *stream_)); |
| 130 | +} |
| 131 | + |
| 132 | +} // namespace tensorrt |
| 133 | +} // namespace inference |
| 134 | +} // namespace paddle |
0 commit comments