@@ -47,9 +47,9 @@ def llm_arguments():
4747 """Parse the arguments for the llm export script."""
4848 parser = argparse .ArgumentParser ()
4949 parser .add_argument (
50- "--torch_dir " ,
50+ "--hf_model_path " ,
5151 type = str ,
52- help = "The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen2.5 -0.5B-Instruct ')" ,
52+ help = "The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen3 -0.6B ')" ,
5353 required = False ,
5454 )
5555 parser .add_argument (
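Note: the flag rename means existing invocations must be updated. A minimal before/after sketch, assuming the script is run directly (the script filename here is hypothetical; --output_dir matches the args.output_dir usage later in the diff):

    # before
    python llm_export.py --torch_dir /ckpts/qwen3-0.6b --output_dir out/
    # after: accepts a local checkpoint folder or a HuggingFace model name
    python llm_export.py --hf_model_path Qwen/Qwen3-0.6B --output_dir out/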
@@ -110,34 +110,34 @@ def llm_arguments():
 def get_config_path(args):
     """
     Get config.json file path from the arguments.
-    The default priority is: config_path > torch_dir/config.json > onnx_path/../config.json
+    The default priority is: config_path > hf_model_path/config.json > onnx_path/../config.json
     """
     if args.config_path and os.path.exists(args.config_path):
         return args.config_path
-    if args.torch_dir:
-        # Check if torch_dir is a local directory
-        if os.path.isdir(args.torch_dir):
-            torch_config = os.path.join(args.torch_dir, "config.json")
+    if args.hf_model_path:
+        # Check if hf_model_path is a local directory
+        if os.path.isdir(args.hf_model_path):
+            torch_config = os.path.join(args.hf_model_path, "config.json")
             if os.path.exists(torch_config):
                 return torch_config
         else:
             # For HuggingFace model names, download config temporarily
             try:
                 # Download config from HuggingFace
                 config = AutoConfig.from_pretrained(
-                    args.torch_dir, trust_remote_code=args.trust_remote_code
+                    args.hf_model_path, trust_remote_code=args.trust_remote_code
                 )

                 # Save to temporary file
                 temp_config_path = os.path.join(
-                    tempfile.gettempdir(), f"config_{args.torch_dir.replace('/', '_')}.json"
+                    tempfile.gettempdir(), f"config_{args.hf_model_path.replace('/', '_')}.json"
                 )
                 with open(temp_config_path, "w") as f:
                     json.dump(config.to_dict(), f, indent=2)

                 return temp_config_path
             except Exception as e:
-                print(f"Warning: Could not download config for {args.hf_model_path}: {e}")

     if args.onnx_path:
         onnx_config = os.path.join(os.path.dirname(args.onnx_path), "config.json")
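Note: for a hub model name, the downloaded config is written to a flattened filename under the system temp directory. A worked example of the path construction above (tempdir shown for a typical Linux system):

    >>> import os, tempfile
    >>> hf_model_path = "Qwen/Qwen3-0.6B"
    >>> os.path.join(tempfile.gettempdir(), f"config_{hf_model_path.replace('/', '_')}.json")
    '/tmp/config_Qwen_Qwen3-0.6B.json'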
@@ -152,7 +152,7 @@ def export_raw_llm(
     output_dir,
     dtype,
     config_path,
-    torch_dir,
+    hf_model_path,
     lm_head_precision="fp16",
     dataset_dir="",
     wrapper_cls=WrapperModelForCausalLM,
@@ -167,7 +167,7 @@ def export_raw_llm(
         output_dir: str
         dtype: str
         config_path: str
-        torch_dir: str, Used for loading tokenizer for quantization
+        hf_model_path: str, Used for loading tokenizer for quantization
         dataset_dir: str, Used for quantization
         wrapper_cls: class, Used for wrapping the model
         extra_inputs: dict, Used for extra inputs
@@ -187,11 +187,11 @@ def export_raw_llm(
     # Need to quantize model to fp8, int4_awq or nvfp4
     if dtype in ["fp8", "int4_awq", "nvfp4"]:
         tokenizer = AutoTokenizer.from_pretrained(
-            torch_dir, trust_remote_code=args.trust_remote_code
+            hf_model_path, trust_remote_code=args.trust_remote_code
         )
-        # Only check for local modelopt_state if torch_dir is a local directory
-        if os.path.isdir(torch_dir):
-            modelopt_state = os.path.join(torch_dir, "modelopt_state.pth")
+        # Only check for local modelopt_state if hf_model_path is a local directory
+        if os.path.isdir(hf_model_path):
+            modelopt_state = os.path.join(hf_model_path, "modelopt_state.pth")
             model_needs_quantization = not os.path.exists(modelopt_state)
         else:
             # For HuggingFace model names, always quantize as we can't have local state files
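Note: both forms of the argument work here because AutoTokenizer.from_pretrained accepts either a hub model ID or a local directory. A short sketch (the local path is hypothetical):

    from transformers import AutoTokenizer

    # Hub model ID: downloads tokenizer files on first use.
    tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
    # Local checkpoint folder: reads tokenizer files from disk.
    tok = AutoTokenizer.from_pretrained("/ckpts/qwen3-0.6b")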
@@ -345,8 +345,8 @@ def get_modelopt_version():

 def main(args):
     """Main function to export the LLM model to ONNX."""
-    assert args.torch_dir or args.onnx_path, (
-        "You need to provide either --torch_dir or --onnx_path to process the export script."
+    assert args.hf_model_path or args.onnx_path, (
+        "You need to provide either --hf_model_path or --onnx_path to process the export script."
     )
     start_time = time.time()

@@ -356,14 +356,11 @@ def main(args):
     if args.onnx_path:
         raw_onnx_path = args.onnx_path

-    model_loader = ModelLoader(
-        args.torch_dir,
-        args.config_path,
-    )
+    model_loader = ModelLoader(args.hf_model_path, args.config_path)

-    if args.torch_dir:
+    if args.hf_model_path:
         # Exporting ONNX from PyTorch model
-        model = model_loader.load_model()
+        model = model_loader.load_model(trust_remote_code=args.trust_remote_code)
         onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
         # Surgeon graph based on precision
         raw_onnx_path = f"{onnx_dir}/model.onnx"
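Note: the new trust_remote_code keyword is presumably forwarded to the underlying from_pretrained call; ModelLoader's internals are not part of this diff, so the following is only a minimal sketch of the assumed behavior (class attribute and constructor names hypothetical):

    from transformers import AutoModelForCausalLM

    class ModelLoader:
        def __init__(self, model_name_or_path, config_path):
            self.model_name_or_path = model_name_or_path  # hypothetical attribute
            self.config_path = config_path

        def load_model(self, trust_remote_code=False):
            # Forward the flag so hub models with custom modeling code can load.
            return AutoModelForCausalLM.from_pretrained(
                self.model_name_or_path, trust_remote_code=trust_remote_code
            )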
@@ -373,7 +370,7 @@ def main(args):
         output_dir=onnx_dir,
         dtype=args.dtype,
         config_path=args.config_path,
-        torch_dir=args.torch_dir,
+        hf_model_path=args.hf_model_path,
         lm_head_precision=args.lm_head,
         dataset_dir=args.dataset_dir,
         wrapper_cls=WrapperModelForCausalLM,