Skip to content

Commit bbcf2ca

Browse files
authored
Merge pull request #95 from NVIDIA/int8_mixed_precision_fix
Enable FP16 mixed precision with Int8
2 parents aa131ac + 3611778 commit bbcf2ca

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
1313
<< "\n Operating Precision: " << s.op_precision \
1414
<< "\n Make Refittable Engine: " << s.refit \
1515
<< "\n Debuggable Engine: " << s.debug \
16-
<< "\n Strict Type: " << s.strict_types \
16+
<< "\n Strict Types: " << s.strict_types \
1717
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
1818
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
1919
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
@@ -51,6 +51,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
5151
case nvinfer1::DataType::kINT8:
5252
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
5353
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
54+
if (!settings.strict_types) {
55+
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
56+
}
5457
input_type = nvinfer1::DataType::kFLOAT;
5558
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
5659
cfg->setInt8Calibrator(settings.calibrator);

0 commit comments

Comments
 (0)