
Commit dc4d966

docs(ptq): Adding a tutorial on how to use PTQ

Signed-off-by: Naren Dasan <[email protected]>

1 parent 06e796e commit dc4d966

File changed: docsrc/tutorials/ptq.rst (70 additions & 5 deletions)
Excerpt of the updated ``docsrc/tutorials/ptq.rst`` (around lines 14-30):

…the TensorRT calibrator. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing
calibrators easier.

LibTorch provides ``DataLoader`` and ``Dataset`` APIs which streamline preprocessing and batching of input data.
These APIs are exposed via both the C++ and Python interfaces, which makes them easier for the end user to adopt.
For the C++ interface, the ``torch::Dataset`` and ``torch::data::make_data_loader`` objects are used to construct and perform pre-processing on datasets.
The equivalent functionality in the Python interface uses ``torch.utils.data.Dataset`` and ``torch.utils.data.DataLoader``.
This section of the PyTorch documentation has more information: https://pytorch.org/tutorials/advanced/cpp_frontend.html#loading-data and https://pytorch.org/tutorials/recipes/recipes/loading_data_recipe.html.
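To make the Python side concrete, here is a minimal sketch of the ``torch.utils.data.Dataset``/``DataLoader`` pattern. The ``RandomImageDataset`` class below is a hypothetical stand-in for a real dataset such as CIFAR10; only the ``torch.utils.data`` API itself comes from the tutorial.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomImageDataset(Dataset):
    """Hypothetical toy dataset yielding random CHW tensors with integer labels."""

    def __init__(self, n=8):
        # Each sample mimics a 3x32x32 image (CIFAR10-shaped) plus a class label
        self.samples = [(torch.rand(3, 32, 32), i % 10) for i in range(n)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

# The DataLoader handles batching: 8 samples / batch_size 4 = 2 batches
loader = DataLoader(RandomImageDataset(), batch_size=4, shuffle=False)

for images, labels in loader:
    # default collation stacks samples into [batch, C, H, W]
    print(images.shape)
```

A real PTQ dataset would load and normalize actual calibration images in ``__getitem__``; the loader is then handed to the calibrator as shown later in this tutorial.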
TRTorch uses DataLoaders as the base of a generic calibrator implementation, so you will be able to reuse or quickly
implement a ``torch::Dataset`` for your target domain, place it in a DataLoader and create an INT8 calibrator
which you can provide to TRTorch to run INT8 calibration during compilation of your module.

.. _writing_ptq_cpp:

How to create your own PTQ application in C++
----------------------------------------------

Here is an example interface of a ``torch::Dataset`` class for CIFAR10:
Excerpt of the updated ``docsrc/tutorials/ptq.rst`` (around lines 135-210):

Then all that's required to setup the module for INT8 calibration is to set the following compile settings:

.. code-block:: c++

    auto trt_mod = trtorch::CompileGraph(mod, compile_spec);

If you have an existing Calibrator implementation for TensorRT you may directly set the ``ptq_calibrator`` field with a pointer to your calibrator and it will work as well.

From here not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain
in FP32 precision when it's passed into ``trt_mod.forward``. There exists an example application in the TRTorch demo that takes you from training a VGG16 network on
CIFAR10 to deploying in INT8 with TRTorch here: https://github.com/NVIDIA/TRTorch/tree/master/cpp/ptq
139141

142+
.. _writing_ptq_python:
143+
144+
How to create your own PTQ application in Python
145+
----------------------------------------
146+
147+
TRTorch Python API provides an easy and convenient way to use pytorch dataloaders with TensorRT calibrators. ``DataLoaderCalibrator`` class can be used to create
148+
a TensorRT calibrator by providing desired configuration. The following code demonstrates an example on how to use it
149+
150+
.. code-block:: python
151+
152+
testing_dataset = torchvision.datasets.CIFAR10(root='./data',
153+
train=False,
154+
download=True,
155+
transform=transforms.Compose([
156+
transforms.ToTensor(),
157+
transforms.Normalize((0.4914, 0.4822, 0.4465),
158+
(0.2023, 0.1994, 0.2010))
159+
]))
160+
161+
testing_dataloader = torch.utils.data.DataLoader(testing_dataset,
162+
batch_size=1,
163+
shuffle=False,
164+
num_workers=1)
165+
calibrator = trtorch.ptq.DataLoaderCalibrator(testing_dataloader,
166+
cache_file='./calibration.cache',
167+
use_cache=False,
168+
algo_type=trtorch.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
169+
device=torch.device('cuda:0'))
170+
171+
compile_spec = {
172+
"input_shapes": [[1, 3, 32, 32]],
173+
"op_precision": torch.int8,
174+
"calibrator": calibrator,
175+
"device": {
176+
"device_type": trtorch.DeviceType.GPU,
177+
"gpu_id": 0,
178+
"dla_core": 0,
179+
"allow_gpu_fallback": False,
180+
"disable_tf32": False
181+
}
182+
}
183+
trt_mod = trtorch.compile(model, compile_spec)
184+
185+
In the cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
186+
to use ``CacheCalibrator`` to use in INT8 mode.
187+
188+
.. code-block:: python
189+
190+
calibrator = trtorch.ptq.CacheCalibrator("./calibration.cache")
191+
192+
compile_settings = {
193+
"input_shapes": [[1, 3, 32, 32]],
194+
"op_precision": torch.int8,
195+
"calibrator": calibrator,
196+
"max_batch_size": 32,
197+
}
198+
199+
trt_mod = trtorch.compile(model, compile_settings)
200+
201+
If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
202+
For a demo on how PTQ can be performed on a VGG network using TRTorch API, you can refer to https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_dataloader_calibrator.py
203+
and https://github.com/NVIDIA/TRTorch/blob/master/tests/py/test_ptq_trt_calibrator.py
Citations
^^^^^^^^^^^

Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images.

Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
