
Commit f0aab14

Merge pull request #531 from NVIDIA/qat
Enable QAT functionality of TRT 8.0 in TRTorch
2 parents 6408389 + 15f9205 commit f0aab14


53 files changed · +1344 −129 lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
@@ -44,3 +44,7 @@ tests/py/data
 examples/**/deps/**/*
 !examples/**/deps/.gitkeep
 examples/trtorchrt_example/trtorchrt_example
+examples/int8/ptq/ptq
+examples/int8/qat/qat
+examples/int8/training/vgg16/data/*
+examples/int8/datasets/data/*

core/compiler.cpp

Lines changed: 4 additions & 4 deletions
@@ -119,8 +119,8 @@ void AddEngineToGraph(
 }

 bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name) {
-  // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  // Go through Lowering to simplify graph
+  auto graph_and_parameters = lowering::Lower(mod, method_name, lowering::LowerInfo());

   auto g = graph_and_parameters.first;
   LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");
@@ -130,7 +130,7 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::stri

 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
   // Go through Lowering to simplify graph and extract weight parameters
-  auto graph_and_parameters = lowering::Lower(mod, method_name);
+  auto graph_and_parameters = lowering::Lower(mod, method_name, cfg.lower_info);

   auto convert_cfg = std::move(cfg.convert_info);
   auto g = graph_and_parameters.first;
@@ -309,7 +309,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
     // Compile only forward methods. forward method contains the entire graph.
     if (method.name().compare("forward") == 0) {
       auto new_g = std::make_shared<torch::jit::Graph>();
-      auto graph_and_parameters = lowering::Lower(mod, method.name());
+      auto graph_and_parameters = lowering::Lower(mod, method.name(), cfg.lower_info);

       auto g = graph_and_parameters.first;
       auto params = graph_and_parameters.second;

core/compiler.h

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include <vector>
 #include "core/conversion/conversion.h"
 #include "core/ir/ir.h"
+#include "core/lowering/lowering.h"
 #include "core/partitioning/partitioning.h"
 #include "core/runtime/runtime.h"
 #include "torch/csrc/jit/api/module.h"
@@ -14,6 +15,7 @@ namespace core {
 struct CompileSpec {
   CompileSpec(std::vector<ir::Input> inputs) : convert_info(std::move(inputs)) {}
   conversion::ConversionInfo convert_info;
+  lowering::LowerInfo lower_info;
   partitioning::PartitionInfo partition_info;
 };

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 4 additions & 2 deletions
@@ -69,9 +69,11 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
     case nvinfer1::DataType::kINT8:
       TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
       cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
-      if (settings.calibrator == nullptr) {
+      if (!settings.calibrator) {
         LOG_INFO(
-            "INT8 kernels are enabled but not calibrator was provided, assuming source model was trained quantization aware");
+            "Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
+      } else {
+        cfg->setInt8Calibrator(settings.calibrator);
       }
       break;
     case nvinfer1::DataType::kFLOAT:

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
@@ -69,6 +69,7 @@ cc_library(
         "impl/matrix_multiply.cpp",
         "impl/normalize.cpp",
         "impl/pooling.cpp",
+        "impl/quantization.cpp",
         "impl/reduce.cpp",
         "impl/replication_pad.cpp",
         "impl/select.cpp",

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 84 additions & 32 deletions
@@ -11,15 +11,97 @@ namespace impl {
 namespace {

 bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args) {
-  auto in = args[0].ITensor(); // assumes non-static input Tensor
-  auto w = Weights(ctx, args[1].unwrapToTensor());
+  // Input to conv/deconv
+  auto in = args[0].ITensor();
+
+  // Conv /deconv parameters
   auto stride = util::toDims(args[3].unwrapToIntList());
   auto padding = util::toDims(args[4].unwrapToIntList());
   auto dilation = util::toDims(args[5].unwrapToIntList());
   bool transposed = args[6].unwrapToBool();
   auto out_padding = util::toDims(args[7].unwrapToIntList());
   int64_t groups = args[8].unwrapToInt();

+  // Reshape the parameters to 2D if needed
+  if (stride.nbDims == 1) {
+    stride = util::unsqueezeDims(stride, 1, 1);
+    LOG_DEBUG("Reshaped stride: " << stride);
+  }
+  if (dilation.nbDims == 1) {
+    dilation = util::unsqueezeDims(dilation, 1, 1);
+    LOG_DEBUG("Reshaped dilation: " << dilation);
+  }
+  if (padding.nbDims == 1) {
+    padding = util::unsqueezeDims(padding, 1, 0);
+    LOG_DEBUG("Reshaped padding: " << padding);
+  }
+  if (out_padding.nbDims == 1) {
+    out_padding = util::unsqueezeDims(out_padding, 1, 0);
+    LOG_DEBUG("Reshaped out_padding: " << out_padding);
+  }
+
+  // Get bias tensor or initialize it to zeros.
+  Weights bias;
+  if (args[2].IValue()->isTensor()) {
+    bias = Weights(ctx, args[2].unwrapToTensor());
+  } else {
+    bias = Weights();
+  }
+
+  // Handle case when weights of conv/deconv is an ITensor. This case happens for QAT networks where
+  // conv_weights -> Quantize -> Dequantize -> new_conv_weights -> conv <- input
+  // new_conv_weights will be an ITensor because it is an output of Dequantize layer defined in impl/quantization.cpp
+  if (args[1].isITensor()) {
+    // Get the kernel tensor
+    auto kernel = args[1].ITensor();
+    auto kernel_dims = kernel->getDimensions();
+
+    // Make a new Dims with only the spatial dimensions.
+    nvinfer1::Dims filter_dim;
+    int64_t nbSpatialDims = in->getDimensions().nbDims - 2;
+    TRTORCH_CHECK(
+        nbSpatialDims = kernel_dims.nbDims - 2,
+        "Number of input spatial dimensions should match the kernel spatial dimensions");
+    filter_dim.nbDims = nbSpatialDims;
+    filter_dim.d[0] = kernel_dims.d[2];
+    filter_dim.d[1] = kernel_dims.d[3];
+
+    // Initialize a dummy constant kernel to pass it to INetwork->addConvolutionNd/addDeconvolutionNd API.
+    auto kernel_weights = nvinfer1::Weights{nvinfer1::DataType::kFLOAT, nullptr, 0};
+
+    nvinfer1::ILayer* layer = nullptr;
+    if (transposed) {
+      nvinfer1::IDeconvolutionLayer* deconvLayer =
+          ctx->net->addDeconvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      deconvLayer->setStrideNd(stride);
+      deconvLayer->setDilationNd(dilation);
+      deconvLayer->setNbGroups(groups);
+      deconvLayer->setPaddingNd(padding);
+      // Set deconv kernel weights
+      deconvLayer->setInput(1, *kernel);
+      TRTORCH_CHECK(deconvLayer, "Unable to create deconv layer with non-const weights from node: " << *n);
+      layer = deconvLayer;
+    } else {
+      nvinfer1::IConvolutionLayer* convLayer =
+          ctx->net->addConvolutionNd(*in, kernel_dims.d[0], filter_dim, kernel_weights, bias.data);
+      convLayer->setStrideNd(stride);
+      convLayer->setPaddingMode(nvinfer1::PaddingMode::kCAFFE_ROUND_DOWN);
+      convLayer->setPaddingNd(padding);
+      convLayer->setPostPadding(out_padding);
+      convLayer->setDilationNd(dilation);
+      convLayer->setNbGroups(groups);
+
+      // Set conv kernel weights
+      convLayer->setInput(1, *kernel);
+      layer = convLayer;
+    }
+
+    ctx->AssociateValueAndTensor(n->outputs()[0], layer->getOutput(0));
+    LOG_DEBUG("Output tensor shape: " << layer->getOutput(0)->getDimensions());
+    return true;
+  }
+
+  auto w = Weights(ctx, args[1].unwrapToTensor());
   auto dims = in->getDimensions();
   auto orig_dims = dims;
   LOG_DEBUG("Input dims: " << orig_dims);
@@ -47,32 +129,9 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
     w.kernel_shape.d[1] = 1;
     LOG_DEBUG("Reshaped Weights: " << w);
   }
-  if (stride.nbDims == 1) {
-    stride = util::unsqueezeDims(stride, 1, 1);
-    LOG_DEBUG("Reshaped stride: " << stride);
-  }
-  if (dilation.nbDims == 1) {
-    dilation = util::unsqueezeDims(dilation, 1, 1);
-    LOG_DEBUG("Reshaped dilation: " << dilation);
-  }
-  if (padding.nbDims == 1) {
-    padding = util::unsqueezeDims(padding, 1, 0);
-    LOG_DEBUG("Reshaped padding: " << padding);
-  }
-  if (out_padding.nbDims == 1) {
-    out_padding = util::unsqueezeDims(out_padding, 1, 0);
-    LOG_DEBUG("Reshaped out_padding: " << out_padding);
-  }

   nvinfer1::ILayer* new_layer;
   if (transposed) {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[1] * groups));
-    }
-
     // shape of deconvolution's weight: [in, out/groups, ...]
     auto deconv = ctx->net->addDeconvolutionNd(*in, w.shape.d[1] * groups, w.kernel_shape, w.data, bias.data);
     TRTORCH_CHECK(deconv, "Unable to create deconvolution layer from node: " << *n);
@@ -90,13 +149,6 @@ bool add_conv_deconv(ConversionCtx* ctx, const torch::jit::Node* n, args& args)
 #endif
     new_layer = deconv;
   } else {
-    Weights bias;
-    if (args[2].IValue()->isTensor()) {
-      bias = Weights(ctx, args[2].unwrapToTensor());
-    } else {
-      bias = Weights(ctx, torch::zeros(w.shape.d[0]));
-    }
-
     // shape of convolution's weight: [out, in/groups, ...]
     auto conv = ctx->net->addConvolutionNd(*in, w.shape.d[0], w.kernel_shape, w.data, bias.data);
     TRTORCH_CHECK(conv, "Unable to create convolution layer from node: " << *n);
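Note: the new ITensor branch above handles QAT networks whose convolution weights arrive through a Quantize/Dequantize pair rather than as frozen constants. A hedged, standalone libtorch sketch (not code from this commit; the shapes and the 0.05 scale are arbitrary) of how such a weight pattern is produced on the PyTorch side:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // A convolution weight passed through a fake Q/DQ pair, as a QAT-trained
  // model would do during training and export.
  auto w = torch::randn({32, 3, 3, 3});
  auto w_qdq = torch::fake_quantize_per_tensor_affine(
      w, /*scale=*/0.05, /*zero_point=*/0, /*quant_min=*/-128, /*quant_max=*/127);

  // The convolution consumes the dequantized weight, so in the traced graph the
  // weight input of the convolution is the output of a Q/DQ pair rather than a
  // frozen constant - the ITensor case the converter branch above handles.
  auto x = torch::randn({1, 3, 8, 8});
  auto y = torch::conv2d(x, w_qdq);
  std::cout << y.sizes() << std::endl;  // [1, 32, 6, 6]
  return 0;
}
```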

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ auto mm_registrations TRTORCH_UNUSED =

        auto mm_layer = ctx->net->addMatrixMultiply(
            *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);
+
        TRTORCH_CHECK(mm_layer, "Unable to create matrix multiplication node: " << *n);
        mm_layer->setName(util::node_info(n).c_str());
        auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], mm_layer->getOutput(0));

core/conversion/converters/impl/quantization.cpp

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+#include <torch/torch.h>
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+// clang-format off
+auto quantization_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns()
+  .pattern({"aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      // This aten operator is generated from torch.fake_quantize_per_tensor_affine op in Pytorch python API.
+      // Example usage: https://github.com/pytorch/pytorch/blob/master/torch/quantization/fake_quantize.py#L145
+      auto input = args[0].ITensorOrFreeze(ctx);
+      auto scale = args[1].unwrapToScalar().to<float>();
+      auto scaleTensor = tensor_to_const(ctx, torch::tensor({scale}));
+      // Add and configure a QuantizeLayer.
+      nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scaleTensor);
+      quantize_layer->setAxis(0);
+
+      // Add and configure DequantizeLayer following a QuantizeLayer
+      nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scaleTensor);
+      dequantize_layer->setAxis(0);
+
+      auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
+      LOG_DEBUG("[fake_quantize_per_tensor_affine] Output tensor shape: " << qdq_out->getDimensions());
+
+      return true;
+    }})
+  .pattern({"aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)",
+    [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+      // This aten operator is generated from torch.fake_quantize_per_channel_affine op in Pytorch python API.
+      // Example usage: https://github.com/pytorch/pytorch/blob/master/torch/quantization/fake_quantize.py#L141
+      auto input = args[0].ITensorOrFreeze(ctx);
+      auto scale = args[1].ITensorOrFreeze(ctx);
+      int64_t axis = args[3].unwrapToScalar().to<int64_t>();
+      // Add and configure a QuantizeLayer.
+      nvinfer1::IQuantizeLayer* quantize_layer = ctx->net->addQuantize(*input, *scale);
+      // Set a channel axis which represents output channels
+      quantize_layer->setAxis(axis);
+
+      // Add and configure a DequantizeLayer.
+      nvinfer1::IDequantizeLayer* dequantize_layer = ctx->net->addDequantize(*quantize_layer->getOutput(0), *scale);
+      dequantize_layer->setAxis(axis);
+      auto qdq_out = ctx->AssociateValueAndTensor(n->outputs()[0], dequantize_layer->getOutput(0));
+
+      LOG_DEBUG("[fake_quantize_per_channel_affine] Ouput tensor shape: " << qdq_out->getDimensions());
+
+      return true;
+    }});
+// clang-format on
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
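Note: both converters in this new file lower a fake-quantize op to a TensorRT IQuantizeLayer followed by an IDequantizeLayer that share the same scale tensor. A rough standalone sketch of that layer pattern built directly against the TensorRT 8 C++ API (illustrative only, not code from this commit; the input shape and the 0.1 scale are made up):

```cpp
#include "NvInfer.h"
#include <iostream>

// Minimal logger the TensorRT builder requires.
class Logger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) {
      std::cout << msg << std::endl;
    }
  }
};

int main() {
  Logger logger;
  auto builder = nvinfer1::createInferBuilder(logger);
  auto network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

  auto input = network->addInput("x", nvinfer1::DataType::kFLOAT, nvinfer1::Dims4{1, 3, 8, 8});

  // Per-tensor scale constant, shared by the quantize and dequantize layers
  // (mirrors the scaleTensor constant the converter creates).
  static const float kScale = 0.1f;
  nvinfer1::Weights scale_weights{nvinfer1::DataType::kFLOAT, &kScale, 1};
  auto scale = network->addConstant(nvinfer1::Dims{1, {1}}, scale_weights)->getOutput(0);

  // Q -> DQ pair, the same structure built for aten::fake_quantize_per_tensor_affine.
  auto quantize_layer = network->addQuantize(*input, *scale);
  quantize_layer->setAxis(0);
  auto dequantize_layer = network->addDequantize(*quantize_layer->getOutput(0), *scale);
  dequantize_layer->setAxis(0);

  network->markOutput(*dequantize_layer->getOutput(0));
  std::cout << "Layers in network: " << network->getNbLayers() << std::endl;
  // Engine build and cleanup omitted for brevity.
  return 0;
}
```

At build time TensorRT reads the INT8 scales from these Q/DQ nodes, which is why the ConversionCtx change earlier in this commit no longer requires a calibrator when none is supplied.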

core/conversion/evaluators/aten.cpp

Lines changed: 16 additions & 0 deletions
@@ -143,6 +143,22 @@ auto aten_registrations TRTORCH_UNUSED =
                     auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
                     return out_tensor;
                   }})
+        .evaluator({c10::Symbol::fromQualString("aten::full"),
+                    // aten::full(int[] size, Scalar fill_value, *, int? dtype=None, int? layout=None,
+                    // Device? device=None, bool? pin_memory=None) -> (Tensor)
+                    [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
+                      auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
+
+                      // Input 2 here is the dtype
+                      if (!args.at(n->input(2)).isNone() && !args.at(n->input(2)).IValue()->isNone()) {
+                        options = options.dtype(c10::ScalarType(args.at(n->input(2)).unwrapToInt()));
+                      }
+
+                      auto scalar_value = args.at(n->input(1)).unwrapToScalar().to<float>();
+                      auto out_tensor =
+                          torch::full(args.at(n->input(0)).unwrapToIntList().vec(), scalar_value, options);
+                      return out_tensor;
+                    }})
         .evaluator({c10::Symbol::fromQualString("aten::slice"),
                     [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
                       c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
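Note: the new aten::full evaluator materializes the tensor at compile time via torch::full, pinning layout to strided and device to CUDA, and applying a dtype only when input 2 is not None. A minimal standalone sketch of the equivalent call (illustrative only, not from the commit):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Equivalent of evaluating aten::full([2, 3], 1.5) with no dtype argument.
  // The evaluator above additionally pins .device(torch::kCUDA); that option is
  // left out here so the sketch also runs on CPU-only builds.
  auto options = torch::TensorOptions().layout(torch::kStrided);
  auto t = torch::full({2, 3}, /*fill_value=*/1.5, options);
  std::cout << t << std::endl;
  return 0;
}
```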

core/lowering/lowering.cpp

Lines changed: 18 additions & 11 deletions
@@ -24,7 +24,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }

-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
@@ -43,9 +43,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::Conv3DToConvolution(g);
   passes::FuseAddMMBranches(g);
   passes::RemoveBNDimCheck(g);
-  torch::jit::EliminateCommonSubexpression(g);
+  if (!lower_info.disable_cse) {
+    torch::jit::EliminateCommonSubexpression(g);
+  }
   // torch::jit::UnrollLoops(g);
-  torch::jit::EliminateCommonSubexpression(g);
   passes::UnpackAddMM(g);
   // passes::UnpackBatchNorm(g);
   passes::UnpackLogSoftmax(g);
@@ -59,26 +60,32 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
 }

 torch::jit::Module LowerModule(const torch::jit::script::Module& mod) {
+  LOG_DEBUG("Input module is being frozen by torch::jit::freeze_module");
   auto mod_ = torch::jit::freeze_module(mod);
   return mod_;
 }

 std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> Lower(
     const torch::jit::script::Module& mod,
-    std::string method_name) {
-  auto lowered_mod = LowerModule(mod);
+    std::string method_name,
+    LowerInfo lower_info) {
+  auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod);
   auto g = lowered_mod.get_method(method_name).graph();
   LOG_GRAPH(*g);

-  // Go through TRTorch Lowering to reformat graph to be conversion friendly
-  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
-  LOG_GRAPH("TRTorch Graph Lowering");
-  lowering::LowerGraph(g);
-  //=[torch::jit::FoldConvBatchNorm2d(lowered_mod);
   LOG_GRAPH("LibTorch Lowering");
   auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
+
+  // Go through TRTorch Lowering to reformat graph to be conversion friendly
+  // and also segment for accelerators and executors (TRT-DLA, TRT-GPU , PYT)
+  // unfreeze_module is used to not perform constant folding on weights in the network.
+  // In quantization aware trained (QAT) models, weights are passed through quantize and
+  // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
+  LOG_GRAPH("TRTorch Graph Lowering");
+  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+
   // Is this necessary?
-  lowering::LowerBlock(g->block());
+  // lowering::LowerBlock(g->block());

   return graph_and_ivalues;
 }
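Note: the lowering changes read two LowerInfo fields: unfreeze_module, which skips torch::jit::freeze_module so Q/DQ nodes on QAT weights are not constant-folded, and disable_cse, which skips common-subexpression elimination. The struct itself lives in core/lowering/lowering.h, which is not shown in this excerpt; a hypothetical reconstruction based only on the fields referenced here might look like the sketch below (the real definition in the commit may differ):

```cpp
// Hypothetical sketch of LowerInfo, inferred only from the fields this diff
// reads (lower_info.disable_cse, lower_info.unfreeze_module). The actual
// definition added in core/lowering/lowering.h by this commit may differ.
namespace trtorch {
namespace core {
namespace lowering {

struct LowerInfo {
  // When true, skip torch::jit::EliminateCommonSubexpression during lowering.
  bool disable_cse = false;
  // When true, skip torch::jit::freeze_module so QAT weights keep their
  // quantize/dequantize nodes instead of being constant-folded.
  bool unfreeze_module = false;
};

} // namespace lowering
} // namespace core
} // namespace trtorch
```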
