Skip to content

Commit 3ddd70d

Browse files
committed
expose the batch size of INT8 calibration as a parameter, since different
batch sizes may yield different accuracy loss.
1 parent 63895f0 commit 3ddd70d

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch2trt/torch2trt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ def torch2trt(module,
392392
int8_mode=False,
393393
int8_calib_dataset=None,
394394
int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
395+
int8_calib_batch_size=1,
395396
use_onnx=False):
396397

397398
inputs_in = inputs
@@ -454,7 +455,7 @@ def torch2trt(module,
454455

455456
# @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
456457
builder.int8_calibrator = DatasetCalibrator(
457-
inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm
458+
inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
458459
)
459460

460461
engine = builder.build_cuda_engine(network)

0 commit comments

Comments
 (0)