Skip to content

Commit b91c268

Browse files
authored
Merge pull request #578 from borisfom/trt7_back
Backwards compatibility for TRT 7.x
2 parents 429bcc1 + 6ebc1fb commit b91c268

File tree

15 files changed

+111
-107
lines changed

15 files changed

+111
-107
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,12 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5454
cudaSetDevice(settings.device.gpu_id) == cudaSuccess, "Unable to set gpu id: " << settings.device.gpu_id);
5555
}
5656

57-
builder = nvinfer1::createInferBuilder(logger);
58-
net = builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));
57+
builder = make_trt(nvinfer1::createInferBuilder(logger));
58+
net = make_trt(
59+
builder->createNetworkV2(1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)));
5960

6061
LOG_DEBUG(build_settings);
61-
cfg = builder->createBuilderConfig();
62+
cfg = make_trt(builder->createBuilderConfig());
6263

6364
for (auto p = settings.enabled_precisions.begin(); p != settings.enabled_precisions.end(); ++p) {
6465
switch (*p) {
@@ -91,11 +92,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
9192
if (settings.disable_tf32) {
9293
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
9394
}
94-
95+
#if NV_TENSORRT_MAJOR > 7
9596
if (settings.sparse_weights) {
9697
cfg->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
9798
}
98-
99+
#endif
99100
if (settings.refit) {
100101
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
101102
}
@@ -136,9 +137,6 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
136137
}
137138

138139
ConversionCtx::~ConversionCtx() {
139-
delete builder;
140-
delete net;
141-
delete cfg;
142140
for (auto ptr : builder_resources) {
143141
free(ptr);
144142
}
@@ -156,10 +154,19 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
156154
}
157155

158156
std::string ConversionCtx::SerializeEngine() {
157+
#if NV_TENSORRT_MAJOR > 7
159158
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);
160159
if (!serialized_network) {
161160
TRTORCH_THROW_ERROR("Building serialized network failed in TensorRT");
162161
}
162+
#else
163+
auto engine = builder->buildEngineWithConfig(*net, *cfg);
164+
if (!engine) {
165+
TRTORCH_THROW_ERROR("Building TensorRT engine failed");
166+
}
167+
auto serialized_network = engine->serialize();
168+
engine->destroy();
169+
#endif
163170
auto engine_str = std::string((const char*)serialized_network->data(), serialized_network->size());
164171
return engine_str;
165172
}

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ struct BuilderSettings {
3232
bool strict_types = false;
3333
bool truncate_long_and_double = false;
3434
Device device;
35-
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kSTANDARD;
35+
nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
3636
nvinfer1::IInt8Calibrator* calibrator = nullptr;
3737
uint64_t num_min_timing_iters = 2;
3838
uint64_t num_avg_timing_iters = 1;
@@ -56,9 +56,9 @@ struct ConversionCtx {
5656
uint64_t num_inputs = 0;
5757
uint64_t num_outputs = 0;
5858
bool input_is_dynamic = false;
59-
nvinfer1::IBuilder* builder;
60-
nvinfer1::INetworkDefinition* net;
61-
nvinfer1::IBuilderConfig* cfg;
59+
std::shared_ptr<nvinfer1::IBuilder> builder;
60+
std::shared_ptr<nvinfer1::INetworkDefinition> net;
61+
std::shared_ptr<nvinfer1::IBuilderConfig> cfg;
6262
std::set<nvinfer1::DataType> enabled_precisions;
6363
BuilderSettings settings;
6464
util::logging::TRTorchLogger logger;

core/conversion/converters/Weights.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ struct Weights {
2525
} // namespace converters
2626
} // namespace conversion
2727
} // namespace core
28-
} // namespace trtorch
28+
} // namespace trtorch

core/conversion/converters/impl/interpolate.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,13 @@ void resize_layer_size(
108108

109109
resize_layer->setResizeMode(mode);
110110
resize_layer->setName(util::node_info(n).c_str());
111-
111+
#if NV_TENSORRT_MAJOR < 8
112+
resize_layer->setAlignCorners(align_corners);
113+
#else
112114
if (align_corners) {
113115
resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS);
114116
}
117+
#endif
115118
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
116119

117120
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());

core/conversion/converters/impl/quantization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace converters {
99
namespace impl {
1010
namespace {
1111

12+
#if NV_TENSORRT_MAJOR > 7
1213
// clang-format off
1314
auto quantization_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
1415
.pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
@@ -53,6 +54,7 @@ auto quantization_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns(
5354
return true;
5455
}});
5556
// clang-format on
57+
#endif
5658
} // namespace
5759
} // namespace impl
5860
} // namespace converters

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,14 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
4343
device_info = most_compatible_device.value();
4444
set_cuda_device(device_info);
4545

46-
rt = std::shared_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(util::logging::get_logger()));
46+
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
4747

4848
name = slugify(mod_name);
4949

50-
cuda_engine = std::shared_ptr<nvinfer1::ICudaEngine>(
51-
rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
52-
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
50+
cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
51+
TRTORCH_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");
5352

54-
exec_ctx = std::shared_ptr<nvinfer1::IExecutionContext>(cuda_engine->createExecutionContext());
53+
exec_ctx = make_trt(cuda_engine->createExecutionContext());
5554

5655
uint64_t inputs = 0;
5756
uint64_t outputs = 0;

core/util/trt_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ std::vector<int64_t> toVec(nvinfer1::Dims d) {
201201
for (int i = 0; i < d.nbDims; i++) {
202202
dims.push_back(d.d[i]);
203203
}
204-
return std::move(dims);
204+
return dims;
205205
}
206206

207207
std::string toStr(nvinfer1::Dims d) {

core/util/trt_util.h

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,30 @@
88

99
namespace nvinfer1 {
1010

11+
#if NV_TENSORRT_MAJOR < 8
12+
13+
#define TRT_ENGINE_CAPABILITY_STANDARD nvinfer1::EngineCapability::kDEFAULT
14+
#define TRT_ENGINE_CAPABILITY_SAFETY nvinfer1::EngineCapability::kSAFE_GPU
15+
#define TRT_ENGINE_CAPABILITY_DLA_STANDALONE nvinfer1::EngineCapability::kSAFE_DLA
16+
17+
template <class T>
18+
std::shared_ptr<T> make_trt(T* p) {
19+
return std::shared_ptr<T>(p, [](T* p) { p->destroy(); });
20+
}
21+
22+
#else
23+
24+
#define TRT_ENGINE_CAPABILITY_STANDARD nvinfer1::EngineCapability::kSTANDARD
25+
#define TRT_ENGINE_CAPABILITY_SAFETY nvinfer1::EngineCapability::kSAFETY
26+
#define TRT_ENGINE_CAPABILITY_DLA_STANDALONE nvinfer1::EngineCapability::kDLA_STANDALONE
27+
28+
template <class T>
29+
std::shared_ptr<T> make_trt(T* p) {
30+
return std::shared_ptr<T>(p);
31+
}
32+
33+
#endif
34+
1135
inline std::ostream& operator<<(std::ostream& os, const nvinfer1::TensorFormat& format) {
1236
switch (format) {
1337
case nvinfer1::TensorFormat::kLINEAR:
@@ -87,11 +111,11 @@ inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::DeviceType
87111

88112
inline std::ostream& operator<<(std::ostream& stream, const nvinfer1::EngineCapability& cap) {
89113
switch (cap) {
90-
case nvinfer1::EngineCapability::kSTANDARD:
114+
case TRT_ENGINE_CAPABILITY_STANDARD:
91115
return stream << "standard";
92-
case nvinfer1::EngineCapability::kSAFETY:
116+
case TRT_ENGINE_CAPABILITY_SAFETY:
93117
return stream << "safety";
94-
case nvinfer1::EngineCapability::kDLA_STANDALONE:
118+
case TRT_ENGINE_CAPABILITY_DLA_STANDALONE:
95119
return stream << "DLA standalone";
96120
default:
97121
return stream << "Unknown Engine Capability Setting";

cpp/api/src/compile_spec.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -388,14 +388,14 @@ core::CompileSpec to_internal_compile_spec(CompileSpec external) {
388388

389389
switch (external.capability) {
390390
case CompileSpec::EngineCapability::kSAFETY:
391-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSAFETY;
391+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_SAFETY;
392392
break;
393393
case CompileSpec::EngineCapability::kDLA_STANDALONE:
394-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kDLA_STANDALONE;
394+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_DLA_STANDALONE;
395395
break;
396396
case CompileSpec::EngineCapability::kSTANDARD:
397397
default:
398-
internal.convert_info.engine_settings.capability = nvinfer1::EngineCapability::kSTANDARD;
398+
internal.convert_info.engine_settings.capability = TRT_ENGINE_CAPABILITY_STANDARD;
399399
}
400400

401401
internal.convert_info.engine_settings.device.gpu_id = external.device.gpu_id;

docker/Dockerfile.20.06

Lines changed: 0 additions & 8 deletions
This file was deleted.

0 commit comments

Comments
 (0)