@@ -219,7 +219,7 @@ def quantize(
     log_file: str | None = None,
     trt_plugins: list[str] | None = None,
     trt_plugins_precision: list[str] | None = None,
-    high_precision_dtype: str | None = None,
+    high_precision_dtype: str = "fp16",
     mha_accumulation_dtype: str = "fp16",
     disable_mha_qdq: bool = False,
     dq_only: bool = True,
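With this change, an fp32 input model is down-converted to fp16 by default rather than left untouched. A minimal sketch of the call site, assuming the public `modelopt.onnx.quantization.quantize` entry point; the `onnx_path` and `quantize_mode` parameter names are assumptions from the surrounding API, not shown in this hunk:

```python
from modelopt.onnx.quantization import quantize

# Previously high_precision_dtype defaulted to None; now an fp32 model is
# converted to fp16 unless the caller opts out explicitly.
quantize(
    onnx_path="model.onnx",       # assumed parameter name
    quantize_mode="int8",         # assumed parameter name
    high_precision_dtype="fp32",  # opt out of the new fp16 default
)
```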
@@ -286,12 +286,13 @@ def quantize(
             Each item should have the format <op_type>:<precision>, where precision can be fp32 (default) or fp16.
             For example: op_type_1:fp16 op_type_2:fp32.
         high_precision_dtype:
-            High precision data type, one of ['fp32', 'fp16']. If high_precision_dtype == 'fp16', model's weight and
-            activation will be converted to fp16.
+            High precision data type of the output model. If high_precision_dtype is 'fp16' or 'bf16'
+            and the input model is of dtype fp32, the model's weights and activations will be converted
+            to 'fp16' or 'bf16', respectively.
         mha_accumulation_dtype:
-            MHA accumulation dtype. One of ['fp32', 'fp16']. 'fp16' by default.
-            If quantize_mode == 'fp8' and mha_accumulation_dtype == 'fp32', Cast nodes will be added to
-            MHA's bmm1 and bmm2's input and output tensors.
+            MHA accumulation dtype, one of ['fp32', 'fp16']; 'fp16' by default. If quantize_mode == 'fp8'
+            and mha_accumulation_dtype == 'fp32', Cast nodes will be added to the input and output
+            tensors of MHA's bmm1 and bmm2.
         disable_mha_qdq:
             Don't add Q/DQ layers to MatMuls in MHA pattern.
         dq_only:
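The fp8/fp32-accumulation interaction described in the docstring can be exercised as below; a sketch under the same assumed parameter names as the previous example:

```python
from modelopt.onnx.quantization import quantize

# Per the docstring: with quantize_mode == 'fp8' and
# mha_accumulation_dtype == 'fp32', Cast nodes are inserted around the
# input/output tensors of MHA's bmm1 and bmm2 so accumulation runs in fp32.
quantize(
    onnx_path="model.onnx",         # assumed parameter name
    quantize_mode="fp8",            # assumed parameter name
    mha_accumulation_dtype="fp32",  # default is "fp16"
)
```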
@@ -461,7 +462,7 @@ def quantize(
         use_external_data_format=use_external_data_format,
         intermediate_generated_files=intermediate_generated_files,
         trt_extra_plugin_lib_paths=trt_plugins,
-        high_precision_dtype=high_precision_dtype,  # type: ignore[arg-type]
+        high_precision_dtype=high_precision_dtype,
         mha_accumulation_dtype=mha_accumulation_dtype,
         passes=passes,
         log_level=log_level,