@@ -89,28 +89,20 @@ def auto_quantize(
     qformat_list = qformat.split(",")
     assert qformat_list, "No quantization formats provided"
     # Check if all provided quantization formats are supported
-    if args.export_fmt == "hf":
-        assert all(
-            qformat
-            in [
-                "fp8",
-                "int4_awq",
-                "nvfp4",
-                "nvfp4_awq",
-                "w4a8_awq",
-                "fp8_pb_wo",
-                "w4a8_mxfp4_fp8",
-                "nvfp4_mlp_only",
-            ]
-            for qformat in qformat_list
-        ), (
-            "One or more quantization formats provided are not supported for unified checkpoint export"
-        )
-    else:
-        assert all(
-            qformat in ["fp8", "int8_sq", "int4_awq", "w4a8_awq", "nvfp4", "nvfp4_awq"]
-            for qformat in qformat_list
-        ), "One or more quantization formats provided are not supported for tensorrt llm export"
+    assert all(
+        qformat
+        in [
+            "fp8",
+            "int4_awq",
+            "nvfp4",
+            "nvfp4_awq",
+            "w4a8_awq",
+            "fp8_pb_wo",
+            "w4a8_mxfp4_fp8",
+            "nvfp4_mlp_only",
+        ]
+        for qformat in qformat_list
+    ), "One or more quantization formats provided are not supported for unified checkpoint export"
 
     def loss_func(output, data):
         # For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
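For readers skimming the hunk above: the two per-backend checks collapse into one membership test. A standalone sketch of the consolidated validation, with the format list copied verbatim from the diff (the helper function and the explicit offender report are mine, not part of the change):

```python
# Standalone sketch of the consolidated check; format names come from the
# diff, the helper itself is hypothetical.
SUPPORTED_QFORMATS = {
    "fp8",
    "int4_awq",
    "nvfp4",
    "nvfp4_awq",
    "w4a8_awq",
    "fp8_pb_wo",
    "w4a8_mxfp4_fp8",
    "nvfp4_mlp_only",
}


def validate_qformats(qformat: str) -> list[str]:
    """Split a comma-separated format string and reject unsupported entries."""
    qformat_list = qformat.split(",")
    unsupported = [q for q in qformat_list if q not in SUPPORTED_QFORMATS]
    assert not unsupported, (
        f"Formats not supported for unified checkpoint export: {unsupported}"
    )
    return qformat_list


print(validate_qformats("fp8,nvfp4"))  # -> ['fp8', 'nvfp4']
```

Collecting the offenders before asserting lets the failure message name the bad formats, which the bare `all(...)` assert cannot.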
@@ -219,27 +211,21 @@ def main(args):
         "Quantization supports only one quantization format."
     )
 
-    # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
-    if args.export_fmt == "hf":
-        assert args.sparsity_fmt == "dense", (
-            f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
-        )
-
-        if not args.auto_quantize_bits:
-            assert (
-                args.qformat
-                in [
-                    "int4_awq",
-                    "fp8",
-                    "nvfp4",
-                    "nvfp4_awq",
-                    "w4a8_awq",
-                    "fp8_pb_wo",
-                    "w4a8_mxfp4_fp8",
-                    "nvfp4_mlp_only",
-                ]
-                or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
-            ), f"Quantization format {args.qformat} not supported for HF export path"
+    if not args.auto_quantize_bits:
+        assert (
+            args.qformat
+            in [
+                "int4_awq",
+                "fp8",
+                "nvfp4",
+                "nvfp4_awq",
+                "w4a8_awq",
+                "fp8_pb_wo",
+                "w4a8_mxfp4_fp8",
+                "nvfp4_mlp_only",
+            ]
+            or args.kv_cache_qformat in KV_QUANT_CFG_CHOICES
+        ), f"Quantization format {args.qformat} not supported for HF export path"
 
     # If low memory mode is enabled, we compress the model while loading the HF checkpoint.
     calibration_only = False
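Note the `or` in the rewritten check: a run whose weight format falls outside the HF list still passes when its KV-cache format is a recognized choice. A minimal sketch of that predicate, assuming `KV_QUANT_CFG_CHOICES` maps KV-cache format names to configs as the membership test implies (the sample keys are illustrative):

```python
HF_EXPORT_QFORMATS = {
    "int4_awq", "fp8", "nvfp4", "nvfp4_awq",
    "w4a8_awq", "fp8_pb_wo", "w4a8_mxfp4_fp8", "nvfp4_mlp_only",
}

# Assumed shape: KV-cache format name -> quantization config.
KV_QUANT_CFG_CHOICES = {"fp8": {}, "nvfp4": {}}


def hf_export_allowed(qformat: str, kv_cache_qformat: str) -> bool:
    """Mirror the assert: supported weight format OR recognized KV-cache format."""
    return qformat in HF_EXPORT_QFORMATS or kv_cache_qformat in KV_QUANT_CFG_CHOICES


print(hf_export_allowed("int8_sq", "fp8"))   # True: the KV-cache format rescues it
print(hf_export_allowed("int8_sq", "none"))  # False
```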
@@ -253,9 +239,6 @@
             attn_implementation=args.attn_implementation,
         )
     else:
-        assert args.export_fmt == "hf", (
-            "Low memory mode is only supported for exporting HF checkpoint."
-        )
         assert args.qformat in QUANT_CFG_CHOICES, (
            f"Quantization format is not supported for low memory mode. Supported formats: {QUANT_CFG_CHOICES.keys()}"
        )
@@ -600,7 +583,10 @@ def output_decode(generated_ids, input_shape):
         setattr(model.config, "architectures", full_model_config.architectures)
 
     start_time = time.time()
-    if args.export_fmt == "tensorrt_llm":
+    if model_type in ["t5", "bart", "whisper"] or args.sparsity_fmt != "dense":
+        # Still export TensorRT-LLM checkpoints for the models not supported by the
+        # TensorRT-LLM torch runtime.
+
         # Move meta tensor back to device before exporting.
         remove_hook_from_module(model, recurse=True)
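With `--export_fmt` gone, the export backend is inferred rather than user-selected. A condensed sketch of the new dispatch predicate (the helper name is mine; the model types and the sparsity condition come straight from the hunk above):

```python
# Encoder-decoder and speech models, plus any sparsified model, still go
# through the legacy TensorRT-LLM checkpoint exporter; everything else
# takes the unified HF export path.
TRTLLM_ONLY_MODEL_TYPES = ("t5", "bart", "whisper")


def uses_trtllm_export(model_type: str, sparsity_fmt: str) -> bool:
    return model_type in TRTLLM_ONLY_MODEL_TYPES or sparsity_fmt != "dense"


assert uses_trtllm_export("t5", "dense")         # not supported by the torch runtime
assert uses_trtllm_export("llama", "sparsegpt")  # sparse weights (illustrative value)
assert not uses_trtllm_export("llama", "dense")  # unified HF export
```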
@@ -621,13 +607,16 @@
             inference_tensor_parallel=args.inference_tensor_parallel,
             inference_pipeline_parallel=args.inference_pipeline_parallel,
         )
-    elif args.export_fmt == "hf":
+    else:
+        # The unified HF export path does not support sparsity; fail fast on sparse formats.
+        assert args.sparsity_fmt == "dense", (
+            f"Sparsity format {args.sparsity_fmt} not supported by unified export api."
+        )
+
         export_hf_checkpoint(
             full_model,
             export_dir=export_path,
         )
-    else:
-        raise NotImplementedError(f"{args.export_fmt} not supported")
 
     # Restore default padding and export the tokenizer as well.
     if tokenizer is not None:
@@ -707,13 +696,6 @@ def output_decode(generated_ids, input_shape):
         choices=KV_QUANT_CFG_CHOICES.keys(),
         help="Specify KV cache quantization format, default to fp8 if not provided",
     )
-    parser.add_argument(
-        "--export_fmt",
-        required=False,
-        default="tensorrt_llm",
-        choices=["tensorrt_llm", "hf"],
-        help=("Checkpoint export format"),
-    )
     parser.add_argument(
         "--trust_remote_code",
         help="Set trust_remote_code for Huggingface models and tokenizers",
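Because the flag is removed outright rather than deprecated, existing invocations that still pass `--export_fmt` will now exit with argparse's standard unrecognized-argument error. A minimal illustration (the parser here is a stand-in, not the script's real one):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--qformat", default="fp8")
# "--export_fmt" is intentionally no longer registered.

try:
    parser.parse_args(["--export_fmt", "hf"])
except SystemExit:
    # argparse prints "error: unrecognized arguments: --export_fmt hf"
    # to stderr and exits with status 2.
    print("old flag rejected")
```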