Skip to content

Commit f5fb752

Browse files
authored
Merge pull request #398 from Chujingjun/master
Expose the batch size of INT8 calibration as a parameter
2 parents 48eb79e + 3ddd70d commit f5fb752

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torch2trt/torch2trt.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ def torch2trt(module,
457457
int8_mode=False,
458458
int8_calib_dataset=None,
459459
int8_calib_algorithm=DEFAULT_CALIBRATION_ALGORITHM,
460+
int8_calib_batch_size=1,
460461
use_onnx=False):
461462

462463
inputs_in = inputs
@@ -519,7 +520,7 @@ def torch2trt(module,
519520

520521
# @TODO(jwelsh): Should we set batch_size=max_batch_size? Need to investigate memory consumption
521522
builder.int8_calibrator = DatasetCalibrator(
522-
inputs, int8_calib_dataset, batch_size=1, algorithm=int8_calib_algorithm
523+
inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm
523524
)
524525

525526
engine = builder.build_cuda_engine(network)

0 commit comments

Comments (0)