@@ -89,28 +89,20 @@ def auto_quantize(
89
89
qformat_list = qformat .split ("," )
90
90
assert qformat_list , "No quantization formats provided"
91
91
# Check if all provided quantization formats are supported
92
- if args .export_fmt == "hf" :
93
- assert all (
94
- qformat
95
- in [
96
- "fp8" ,
97
- "int4_awq" ,
98
- "nvfp4" ,
99
- "nvfp4_awq" ,
100
- "w4a8_awq" ,
101
- "fp8_pb_wo" ,
102
- "w4a8_mxfp4_fp8" ,
103
- "nvfp4_mlp_only" ,
104
- ]
105
- for qformat in qformat_list
106
- ), (
107
- "One or more quantization formats provided are not supported for unified checkpoint export"
108
- )
109
- else :
110
- assert all (
111
- qformat in ["fp8" , "int8_sq" , "int4_awq" , "w4a8_awq" , "nvfp4" , "nvfp4_awq" ]
112
- for qformat in qformat_list
113
- ), "One or more quantization formats provided are not supported for tensorrt llm export"
92
+ assert all (
93
+ qformat
94
+ in [
95
+ "fp8" ,
96
+ "int4_awq" ,
97
+ "nvfp4" ,
98
+ "nvfp4_awq" ,
99
+ "w4a8_awq" ,
100
+ "fp8_pb_wo" ,
101
+ "w4a8_mxfp4_fp8" ,
102
+ "nvfp4_mlp_only" ,
103
+ ]
104
+ for qformat in qformat_list
105
+ ), "One or more quantization formats provided are not supported for unified checkpoint export"
114
106
115
107
def loss_func (output , data ):
116
108
# For transformers AutoModelForCausalLM models, the outputs are wrapped in `CausalLMOutputWithPast`
@@ -219,27 +211,21 @@ def main(args):
219
211
"Quantization supports only one quantization format."
220
212
)
221
213
222
- # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
223
- if args .export_fmt == "hf" :
224
- assert args .sparsity_fmt == "dense" , (
225
- f"Sparsity format { args .sparsity_fmt } not supported by unified export api."
226
- )
227
-
228
- if not args .auto_quantize_bits :
229
- assert (
230
- args .qformat
231
- in [
232
- "int4_awq" ,
233
- "fp8" ,
234
- "nvfp4" ,
235
- "nvfp4_awq" ,
236
- "w4a8_awq" ,
237
- "fp8_pb_wo" ,
238
- "w4a8_mxfp4_fp8" ,
239
- "nvfp4_mlp_only" ,
240
- ]
241
- or args .kv_cache_qformat in KV_QUANT_CFG_CHOICES
242
- ), f"Quantization format { args .qformat } not supported for HF export path"
214
+ if not args .auto_quantize_bits :
215
+ assert (
216
+ args .qformat
217
+ in [
218
+ "int4_awq" ,
219
+ "fp8" ,
220
+ "nvfp4" ,
221
+ "nvfp4_awq" ,
222
+ "w4a8_awq" ,
223
+ "fp8_pb_wo" ,
224
+ "w4a8_mxfp4_fp8" ,
225
+ "nvfp4_mlp_only" ,
226
+ ]
227
+ or args .kv_cache_qformat in KV_QUANT_CFG_CHOICES
228
+ ), f"Quantization format { args .qformat } not supported for HF export path"
243
229
244
230
# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
245
231
calibration_only = False
@@ -253,9 +239,6 @@ def main(args):
253
239
attn_implementation = args .attn_implementation ,
254
240
)
255
241
else :
256
- assert args .export_fmt == "hf" , (
257
- "Low memory mode is only supported for exporting HF checkpoint."
258
- )
259
242
assert args .qformat in QUANT_CFG_CHOICES , (
260
243
f"Quantization format is not supported for low memory mode. Supported formats: { QUANT_CFG_CHOICES .keys ()} "
261
244
)
@@ -600,7 +583,10 @@ def output_decode(generated_ids, input_shape):
600
583
setattr (model .config , "architectures" , full_model_config .architectures )
601
584
602
585
start_time = time .time ()
603
- if args .export_fmt == "tensorrt_llm" :
586
+ if model_type in ["t5" , "bart" , "whisper" ] or args .sparsity_fmt != "dense" :
587
+ # Still export TensorRT-LLM checkpoints for the models not supported by the
588
+ # TensorRT-LLM torch runtime.
589
+
604
590
# Move meta tensor back to device before exporting.
605
591
remove_hook_from_module (model , recurse = True )
606
592
@@ -621,13 +607,16 @@ def output_decode(generated_ids, input_shape):
621
607
inference_tensor_parallel = args .inference_tensor_parallel ,
622
608
inference_pipeline_parallel = args .inference_pipeline_parallel ,
623
609
)
624
- elif args .export_fmt == "hf" :
610
+ else :
611
+ # Check arguments for unified_hf export format and set to default if unsupported arguments are provided
612
+ assert args .sparsity_fmt == "dense" , (
613
+ f"Sparsity format { args .sparsity_fmt } not supported by unified export api."
614
+ )
615
+
625
616
export_hf_checkpoint (
626
617
full_model ,
627
618
export_dir = export_path ,
628
619
)
629
- else :
630
- raise NotImplementedError (f"{ args .export_fmt } not supported" )
631
620
632
621
# Restore default padding and export the tokenizer as well.
633
622
if tokenizer is not None :
@@ -707,13 +696,6 @@ def output_decode(generated_ids, input_shape):
707
696
choices = KV_QUANT_CFG_CHOICES .keys (),
708
697
help = "Specify KV cache quantization format, default to fp8 if not provided" ,
709
698
)
710
- parser .add_argument (
711
- "--export_fmt" ,
712
- required = False ,
713
- default = "tensorrt_llm" ,
714
- choices = ["tensorrt_llm" , "hf" ],
715
- help = ("Checkpoint export format" ),
716
- )
717
699
parser .add_argument (
718
700
"--trust_remote_code" ,
719
701
help = "Set trust_remote_code for Huggingface models and tokenizers" ,
0 commit comments