Skip to content

Commit 210c160

Browse files
authored
Merge pull request #565 from NVIDIA/tensorrt_8_update
TensorRT 8.0 update (rebased for 1.9 master)
2 parents 3d9832e + 93c2a21 commit 210c160

29 files changed

+229 / -226 lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,7 @@ These are the following dependencies used to verify the testcases. TRTorch can w
8181
- Libtorch 1.9.0 (built with CUDA 11.1)
8282
- CUDA 11.1 (10.2 on Jetson)
8383
- cuDNN 8.1
84-
- TensorRT 7.2.3
84+
- TensorRT 8.0.1.6
8585

8686
## Prebuilt Binaries and Wheel files
8787

WORKSPACE

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -70,20 +70,20 @@ http_archive(
7070
http_archive(
7171
name = "cudnn",
7272
build_file = "@//third_party/cudnn/archive:BUILD",
73-
sha256 = "98a8784e92862f20018d20c281b30d4a0cd951f93694f6433ccf4ae9c502ba6a",
73+
sha256 = "39412acd9ef5dd27954b6b9f5df75bd381c5d7ceb7979af6c743a7f4521f9c77",
7474
strip_prefix = "cuda",
7575
urls = [
76-
"https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.1.1.33/11.2_20210301/cudnn-11.2-linux-x64-v8.1.1.33.tgz",
76+
"https://developer.nvidia.com/compute/machine-learning/cudnn/secure/8.2.1.32/11.3_06072021/cudnn-11.3-linux-x64-v8.2.1.32.tgz",
7777
],
7878
)
7979

8080
http_archive(
8181
name = "tensorrt",
8282
build_file = "@//third_party/tensorrt/archive:BUILD",
83-
sha256 = "d3a1f478e304b48878604fac70ce7920fece71f9cac62f925c9c59c197f5d087",
84-
strip_prefix = "TensorRT-7.2.3.4",
83+
sha256 = "def6a5ee50bed25a68a9c9e22ec671a8f29ee5414bde47c5767bd279e5596f88",
84+
strip_prefix = "TensorRT-8.0.1.6",
8585
urls = [
86-
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/7.2.3/tars/TensorRT-7.2.3.4.Ubuntu-18.04.x86_64-gnu.cuda-11.1.cudnn8.1.tar.gz",
86+
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.0.1/tars/tensorrt-8.0.1.6.linux.x86_64-gnu.cuda-11.3.cudnn8.2.tar.gz",
8787
],
8888
)
8989

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -69,10 +69,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
6969
case nvinfer1::DataType::kINT8:
7070
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
7171
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
72-
TRTORCH_CHECK(
73-
settings.calibrator != nullptr,
74-
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
75-
cfg->setInt8Calibrator(settings.calibrator);
72+
if (settings.calibrator == nullptr) {
73+
LOG_INFO(
74+
"INT8 kernels are enabled but not calibrator was provided, assuming source model was trained quantization aware");
75+
}
7676
break;
7777
case nvinfer1::DataType::kFLOAT:
7878
break;
@@ -90,6 +90,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
9090
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);
9191
}
9292

93+
if (settings.sparse_weights) {
94+
cfg->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
95+
}
96+
9397
if (settings.refit) {
9498
cfg->setFlag(nvinfer1::BuilderFlag::kREFIT);
9599
}
@@ -130,9 +134,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
130134
}
131135

132136
ConversionCtx::~ConversionCtx() {
133-
builder->destroy();
134-
net->destroy();
135-
cfg->destroy();
137+
delete builder;
138+
delete net;
139+
delete cfg;
136140
for (auto ptr : builder_resources) {
137141
free(ptr);
138142
}
@@ -150,14 +154,11 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
150154
}
151155

152156
std::string ConversionCtx::SerializeEngine() {
153-
auto engine = builder->buildEngineWithConfig(*net, *cfg);
154-
if (!engine) {
155-
TRTORCH_THROW_ERROR("Building TensorRT engine failed");
157+
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);
158+
if (!serialized_network) {
159+
TRTORCH_THROW_ERROR("Building serialized network failed in TensorRT");
156160
}
157-
auto serialized_engine = engine->serialize();
158-
engine->destroy();
159-
auto engine_str = std::string((const char*)serialized_engine->data(), serialized_engine->size());
160-
serialized_engine->destroy();
161+
auto engine_str = std::string((const char*)serialized_network->data(), serialized_network->size());
161162
return engine_str;
162163
}
163164

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,14 +25,14 @@ struct Device {
2525

2626
struct BuilderSettings {
2727
std::set<nvinfer1::DataType> enabled_precisions = {nvinfer1::DataType::kFLOAT};
28-
std::vector<nvinfer1::DataType> input_dtypes;
28+
bool sparse_weights = false;
2929
bool disable_tf32 = false;
3030
bool refit = false;
3131
bool debug = false;
3232
bool strict_types = false;
3333
bool truncate_long_and_double = false;
3434
Device device;
35-
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
35+
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kSTANDARD;
3636
nvinfer1::IInt8Calibrator* calibrator = nullptr;
3737
uint64_t num_min_timing_iters = 2;
3838
uint64_t num_avg_timing_iters = 1;

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ auto batch_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
4040
LOG_DEBUG("momentum disregarded");
4141
LOG_DEBUG("training disregarded");
4242
LOG_DEBUG("cudnn disregarded");
43+
TRTORCH_CHECK(orig_shape.nbDims > 2, "Unable to create batch normalization layer from node: " << *n);
4344

4445
// Expand spatial dims from 1D to 2D if needed
4546
bool expandDims = (orig_shape.nbDims < 4);

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
3030
LOG_DEBUG("out_padding: " << out_padding);
3131
LOG_DEBUG("groups: " << groups);
3232

33-
// Expand spatial dims from 1D to 2D if needed
33+
TRTORCH_CHECK(orig_dims.nbDims > 2, "Unable to create convolution layer from node: " << *n);
34+
3435
bool expandDims = (orig_dims.nbDims < 4);
3536
if (expandDims) {
3637
in = addPadding(ctx, n, in, 4);

core/conversion/converters/impl/interpolate.cpp

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -109,17 +109,9 @@ void resize_layer_size(
109109
resize_layer->setResizeMode(mode);
110110
resize_layer->setName(util::node_info(n).c_str());
111111

112-
// if interpolation mode is linear, align corners must have been set to true.
113-
// else, don't use align corners.
114-
if (mode == nvinfer1::ResizeMode::kLINEAR) {
115-
#if NV_TENSORRT_MAJOR < 7 || (NV_TENSORRT_MAJOR == 7 && NV_TENSORRT_MINOR < 1) // IF TRT VERSION <= 7.0
116-
TRTORCH_CHECK(align_corners, "resize layer (linear) only supports align_corners=True in TensorRT <= 7.0");
117-
resize_layer->setAlignCorners(true);
118-
#else
119-
resize_layer->setAlignCorners(align_corners);
120-
#endif
112+
if (align_corners) {
113+
resize_layer->setCoordinateTransformation(nvinfer1::ResizeCoordinateTransformation::kALIGN_CORNERS);
121114
}
122-
123115
auto layer_output = ctx->AssociateValueAndTensor(n->outputs()[0], resize_layer->getOutput(0));
124116

125117
LOG_DEBUG("Output tensor shape: " << layer_output->getDimensions());

core/conversion/converters/impl/pooling.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ bool AdaptivePoolingConverter(
4848

4949
auto orig_dims = in->getDimensions();
5050
bool expandDims = (orig_dims.nbDims < 4);
51-
51+
TRTORCH_CHECK(orig_dims.nbDims > 2, "Unable to create pooling layer from node: " << *n);
5252
if (expandDims) {
5353
in = addPadding(ctx, n, in, 4, false, false);
5454
}
@@ -122,6 +122,7 @@ bool PoolingConverter(ConversionCtx* ctx, const torch::jit::Node* n, args& args,
122122

123123
// Max Pool needs at least 4D input
124124
auto orig_dims = in->getDimensions();
125+
TRTORCH_CHECK(orig_dims.nbDims > 2, "Unable to create pooling layer from node: " << *n);
125126
bool expandDims = (orig_dims.nbDims < 4);
126127

127128
if (expandDims) {

core/plugins/impl/interpolate_plugin.cpp

Lines changed: 36 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -105,35 +105,35 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
105105
return size_;
106106
}
107107

108-
int InterpolatePlugin::getNbOutputs() const {
108+
int InterpolatePlugin::getNbOutputs() const noexcept {
109109
if (mode_ == "adaptive_max_pool2d") {
110110
return 2;
111111
} else {
112112
return 1;
113113
}
114114
}
115115

116-
const char* InterpolatePlugin::getPluginType() const {
116+
const char* InterpolatePlugin::getPluginType() const noexcept {
117117
return "Interpolate";
118118
}
119119

120-
const char* InterpolatePlugin::getPluginVersion() const {
120+
const char* InterpolatePlugin::getPluginVersion() const noexcept {
121121
return "1";
122122
}
123123

124-
const char* InterpolatePlugin::getPluginNamespace() const {
124+
const char* InterpolatePlugin::getPluginNamespace() const noexcept {
125125
return "trtorch";
126126
}
127127

128-
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const {
128+
nvinfer1::IPluginV2DynamicExt* InterpolatePlugin::clone() const noexcept {
129129
return new InterpolatePlugin(in_shape_, out_shape_, size_, scales_, mode_, align_corners_, use_scales_);
130130
}
131131

132132
nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
133133
int outputIndex,
134134
const nvinfer1::DimsExprs* inputs,
135135
int nbInputs,
136-
nvinfer1::IExprBuilder& exprBuilder) {
136+
nvinfer1::IExprBuilder& exprBuilder) noexcept {
137137
nvinfer1::DimsExprs output(inputs[0]);
138138

139139
// TODO: This should enable the case of using this plugin with dynamic shape, scale factor and align corners == true
@@ -165,15 +165,15 @@ nvinfer1::DimsExprs InterpolatePlugin::getOutputDimensions(
165165
}
166166

167167
nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs)
168-
const {
168+
const noexcept {
169169
return nvinfer1::DataType::kFLOAT;
170170
}
171171

172-
int InterpolatePlugin::initialize() {
172+
int InterpolatePlugin::initialize() noexcept {
173173
return 0;
174174
}
175175

176-
void InterpolatePlugin::serialize(void* buffer) const {
176+
void InterpolatePlugin::serialize(void* buffer) const noexcept {
177177
std::string data = serializeToString();
178178
size_t size = getSerializationSize();
179179

@@ -197,23 +197,32 @@ std::string InterpolatePlugin::serializeToString() const {
197197
return data_str.str();
198198
}
199199

200-
size_t InterpolatePlugin::getSerializationSize() const {
200+
size_t InterpolatePlugin::getSerializationSize() const noexcept {
201201
return serializeToString().size();
202202
}
203203

204204
bool InterpolatePlugin::supportsFormatCombination(
205205
int pos,
206206
const nvinfer1::PluginTensorDesc* inOut,
207207
int nbInputs,
208-
int nbOutputs) {
209-
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
210-
208+
int nbOutputs) noexcept {
209+
if (nbInputs != 1) {
210+
LOG_ERROR("Expected a single tensor as input to interpolate plugin");
211+
}
211212
if (mode_ == "adaptive_max_pool2d") {
212-
TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin");
213-
TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output");
213+
if (nbOutputs != 2) {
214+
LOG_ERROR("Expected 2 tensors as output to interpolate plugin");
215+
}
216+
if (pos < 0 || pos > 2) {
217+
LOG_ERROR("There should be exactly 3 connections to the plugin - 1 input, 2 output");
218+
}
214219
} else {
215-
TRTORCH_ASSERT(nbOutputs == 1, "Expected a single tensor as output to interpolate plugin");
216-
TRTORCH_ASSERT(0 <= pos && pos <= 1, "There should be exactly 2 connections to the plugin - 1 input, 1 output");
220+
if (nbOutputs != 1) {
221+
LOG_ERROR("Expected a single tensor as output to interpolate plugin");
222+
}
223+
if (pos < 0 || pos > 1) {
224+
LOG_ERROR("There should be exactly 2 connections to the plugin - 1 input, 1 output");
225+
}
217226
}
218227

219228
const nvinfer1::PluginTensorDesc& in = inOut[0];
@@ -232,15 +241,15 @@ void InterpolatePlugin::configurePlugin(
232241
const nvinfer1::DynamicPluginTensorDesc* in,
233242
int nbInputs,
234243
const nvinfer1::DynamicPluginTensorDesc* out,
235-
int nbOutputs) {
244+
int nbOutputs) noexcept {
236245
dtype_ = nvinfer1::DataType::kFLOAT;
237246
}
238247

239248
size_t InterpolatePlugin::getWorkspaceSize(
240249
const nvinfer1::PluginTensorDesc* inputs,
241250
int nbInputs,
242251
const nvinfer1::PluginTensorDesc* outputs,
243-
int nbOutputs) const {
252+
int nbOutputs) const noexcept {
244253
return 0;
245254
}
246255

@@ -250,7 +259,7 @@ int InterpolatePlugin::enqueue(
250259
const void* const* inputs,
251260
void* const* outputs,
252261
void* workspace,
253-
cudaStream_t stream) {
262+
cudaStream_t stream) noexcept {
254263
at::Tensor input =
255264
at::from_blob((void*)inputs[0], util::toVec(inputDesc->dims), [](void*) {}, {at::kCUDA}).to(torch::kFloat);
256265
at::Tensor output =
@@ -317,21 +326,21 @@ InterpolatePluginCreator::InterpolatePluginCreator() {
317326
mFC.fields = mPluginAttributes.data();
318327
}
319328

320-
const char* InterpolatePluginCreator::getPluginNamespace() const {
329+
const char* InterpolatePluginCreator::getPluginNamespace() const noexcept {
321330
return "trtorch";
322331
}
323332

324-
const char* InterpolatePluginCreator::getPluginName() const {
333+
const char* InterpolatePluginCreator::getPluginName() const noexcept {
325334
return "Interpolate";
326335
}
327336

328-
const char* InterpolatePluginCreator::getPluginVersion() const {
337+
const char* InterpolatePluginCreator::getPluginVersion() const noexcept {
329338
return "1";
330339
}
331340

332341
nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(
333342
const char* name,
334-
const nvinfer1::PluginFieldCollection* fc) {
343+
const nvinfer1::PluginFieldCollection* fc) noexcept {
335344
std::vector<int64_t> in_shape;
336345
std::vector<int64_t> out_shape;
337346
std::vector<int64_t> out_size;
@@ -370,12 +379,12 @@ nvinfer1::IPluginV2* InterpolatePluginCreator::createPlugin(
370379
nvinfer1::IPluginV2* InterpolatePluginCreator::deserializePlugin(
371380
const char* name,
372381
const void* serialData,
373-
size_t serialLength) {
382+
size_t serialLength) noexcept {
374383
name_ = name;
375384
return new InterpolatePlugin((const char*)serialData, serialLength);
376385
}
377386

378-
const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames() {
387+
const nvinfer1::PluginFieldCollection* InterpolatePluginCreator::getFieldNames() noexcept {
379388
return nullptr;
380389
}
381390

@@ -384,4 +393,4 @@ REGISTER_TRTORCH_PLUGIN(InterpolatePluginCreator);
384393
} // namespace impl
385394
} // namespace plugins
386395
} // namespace core
387-
} // namespace trtorch
396+
} // namespace trtorch

0 commit comments

Comments
 (0)