Commit 12a1739

Update ptq.rst
Signed-off-by: Nick Comly <[email protected]>
1 parent d0e471f commit 12a1739

1 file changed: 13 additions, 20 deletions

docsrc/tutorials/ptq.rst

Lines changed: 13 additions & 20 deletions
@@ -167,19 +167,16 @@ a TensorRT calibrator by providing desired configuration. The following code dem
         algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
         device=torch.device('cuda:0'))
 
-    compile_spec = {
-        "inputs": [torch_tensorrt.Input((1, 3, 32, 32))],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-        "device": {
-            "device_type": torch_tensorrt.DeviceType.GPU,
-            "gpu_id": 0,
-            "dla_core": 0,
-            "allow_gpu_fallback": False,
-            "disable_tf32": False
-        }
-    }
-    trt_mod = torch_tensorrt.compile(model, compile_spec)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 32, 32))],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator,
+                                     device={
+                                         "device_type": torch_tensorrt.DeviceType.GPU,
+                                         "gpu_id": 0,
+                                         "dla_core": 0,
+                                         "allow_gpu_fallback": False,
+                                         "disable_tf32": False
+                                     })
 
 In the cases where there is a pre-existing calibration cache file that users want to use, ``CacheCalibrator`` can be used without any dataloaders. The following example demonstrates how
 to use ``CacheCalibrator`` to use in INT8 mode.
@@ -188,13 +185,9 @@ to use ``CacheCalibrator`` to use in INT8 mode.
 
     calibrator = torch_tensorrt.ptq.CacheCalibrator("./calibration.cache")
 
-    compile_settings = {
-        "inputs": [torch_tensorrt.Input([1, 3, 32, 32])],
-        "enabled_precisions": {torch.float, torch.half, torch.int8},
-        "calibrator": calibrator,
-    }
-
-    trt_mod = torch_tensorrt.compile(model, compile_settings)
+    trt_mod = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input([1, 3, 32, 32])],
+                                     enabled_precisions={torch.float, torch.half, torch.int8},
+                                     calibrator=calibrator)
 
 If you already have an existing calibrator class (implemented directly using TensorRT API), you can directly set the calibrator field to your class which can be very convenient.
 For a demo on how PTQ can be performed on a VGG network using Torch-TensorRT API, you can refer to https://github.com/pytorch/TensorRT/blob/master/tests/py/test_ptq_dataloader_calibrator.py
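The commit replaces a settings dictionary with explicit keyword arguments to `torch_tensorrt.compile`. A minimal pure-Python sketch of why the two call styles carry the same information — `compile_stub` is a hypothetical stand-in for the real `torch_tensorrt.compile`, and the placeholder values are illustrative only:

```python
# Hypothetical stand-in for torch_tensorrt.compile: it only records the
# settings it receives, so we can compare the two call styles.
def compile_stub(model, inputs=None, enabled_precisions=None,
                 calibrator=None, device=None):
    return {"model": model, "inputs": inputs,
            "enabled_precisions": enabled_precisions,
            "calibrator": calibrator, "device": device}

# Old style: build a spec dict, then unpack it into the call.
compile_spec = {
    "inputs": ["input_placeholder"],
    "enabled_precisions": {"float", "half", "int8"},
    "calibrator": "calibrator_placeholder",
    "device": {"device_type": "GPU", "gpu_id": 0},
}
old = compile_stub("model", **compile_spec)

# New style (as in the updated docs): pass the same settings directly
# as keyword arguments.
new = compile_stub("model",
                   inputs=["input_placeholder"],
                   enabled_precisions={"float", "half", "int8"},
                   calibrator="calibrator_placeholder",
                   device={"device_type": "GPU", "gpu_id": 0})

print(old == new)  # True
```

The keyword-argument form surfaces the parameter names at the call site, so editors and `help(torch_tensorrt.compile)` can show what each setting means without consulting a separate spec-dict schema.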
