
Commit 7074615

Fix onnx_ptq llm_export.py and support Qwen3 (#638)
## What does this PR do?

**Type of change:** Bug Fix

## Testing

- [x] Tests run locally in docker
- [x] Tests enabled in github per-PR CICD

---------

Signed-off-by: Keval Morabia <[email protected]>
1 parent 8844a2b commit 7074615

6 files changed (+52, −69 lines)

.github/workflows/example_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        example: [diffusers]
+        example: [diffusers, onnx_ptq]
     uses: ./.github/workflows/_example_tests_runner.yml
     secrets: inherit
     with:

examples/onnx_ptq/llm_export.py

Lines changed: 22 additions & 25 deletions
@@ -47,9 +47,9 @@ def llm_arguments():
     """Parse the arguments for the llm export script."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--torch_dir",
+        "--hf_model_path",
         type=str,
-        help="The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen2.5-0.5B-Instruct')",
+        help="The folder of HF PyTorch model ckpt or HuggingFace model name/path (e.g., 'Qwen/Qwen3-0.6B')",
         required=False,
     )
     parser.add_argument(

@@ -110,34 +110,34 @@ def llm_arguments():
 def get_config_path(args):
     """
     Get config.json file path from the arguments.
-    The default priority is: config_path > torch_dir/config.json > onnx_path/../config.json
+    The default priority is: config_path > hf_model_path/config.json > onnx_path/../config.json
     """
     if args.config_path and os.path.exists(args.config_path):
         return args.config_path
-    if args.torch_dir:
-        # Check if torch_dir is a local directory
-        if os.path.isdir(args.torch_dir):
-            torch_config = os.path.join(args.torch_dir, "config.json")
+    if args.hf_model_path:
+        # Check if hf_model_path is a local directory
+        if os.path.isdir(args.hf_model_path):
+            torch_config = os.path.join(args.hf_model_path, "config.json")
             if os.path.exists(torch_config):
                 return torch_config
         else:
             # For HuggingFace model names, download config temporarily
             try:
                 # Download config from HuggingFace
                 config = AutoConfig.from_pretrained(
-                    args.torch_dir, trust_remote_code=args.trust_remote_code
+                    args.hf_model_path, trust_remote_code=args.trust_remote_code
                 )
 
                 # Save to temporary file
                 temp_config_path = os.path.join(
-                    tempfile.gettempdir(), f"config_{args.torch_dir.replace('/', '_')}.json"
+                    tempfile.gettempdir(), f"config_{args.hf_model_path.replace('/', '_')}.json"
                 )
                 with open(temp_config_path, "w") as f:
                     json.dump(config.to_dict(), f, indent=2)
 
                 return temp_config_path
             except Exception as e:
-                print(f"Warning: Could not download config for {args.torch_dir}: {e}")
+                print(f"Warning: Could not download config for {args.hf_model_path}: {e}")
 
     if args.onnx_path:
         onnx_config = os.path.join(os.path.dirname(args.onnx_path), "config.json")

@@ -152,7 +152,7 @@ def export_raw_llm(
     output_dir,
     dtype,
     config_path,
-    torch_dir,
+    hf_model_path,
     lm_head_precision="fp16",
     dataset_dir="",
     wrapper_cls=WrapperModelForCausalLM,

@@ -167,7 +167,7 @@ def export_raw_llm(
         output_dir: str
         dtype: str
         config_path: str
-        torch_dir: str, Used for loading tokenizer for quantization
+        hf_model_path: str, Used for loading tokenizer for quantization
         dataset_dir: str, Used for quantization
         wrapper_cls: class, Used for wrapping the model
         extra_inputs: dict, Used for extra inputs

@@ -187,11 +187,11 @@ def export_raw_llm(
     # Need to quantize model to fp8, int4_awq or nvfp4
     if dtype in ["fp8", "int4_awq", "nvfp4"]:
         tokenizer = AutoTokenizer.from_pretrained(
-            torch_dir, trust_remote_code=args.trust_remote_code
+            hf_model_path, trust_remote_code=args.trust_remote_code
         )
-        # Only check for local modelopt_state if torch_dir is a local directory
-        if os.path.isdir(torch_dir):
-            modelopt_state = os.path.join(torch_dir, "modelopt_state.pth")
+        # Only check for local modelopt_state if hf_model_path is a local directory
+        if os.path.isdir(hf_model_path):
+            modelopt_state = os.path.join(hf_model_path, "modelopt_state.pth")
             model_needs_quantization = not os.path.exists(modelopt_state)
         else:
             # For HuggingFace model names, always quantize as we can't have local state files

@@ -345,8 +345,8 @@ def get_modelopt_version():
 
 def main(args):
     """Main function to export the LLM model to ONNX."""
-    assert args.torch_dir or args.onnx_path, (
-        "You need to provide either --torch_dir or --onnx_path to process the export script."
+    assert args.hf_model_path or args.onnx_path, (
+        "You need to provide either --hf_model_path or --onnx_path to process the export script."
     )
     start_time = time.time()

@@ -356,14 +356,11 @@ def main(args):
     if args.onnx_path:
         raw_onnx_path = args.onnx_path
 
-    model_loader = ModelLoader(
-        args.torch_dir,
-        args.config_path,
-    )
+    model_loader = ModelLoader(args.hf_model_path, args.config_path)
 
-    if args.torch_dir:
+    if args.hf_model_path:
         # Exporting ONNX from PyTorch model
-        model = model_loader.load_model()
+        model = model_loader.load_model(trust_remote_code=args.trust_remote_code)
         onnx_dir = args.output_dir + "_raw" if args.save_original else args.output_dir
         # Surgeon graph based on precision
         raw_onnx_path = f"{onnx_dir}/model.onnx"

@@ -373,7 +370,7 @@ def main(args):
         output_dir=onnx_dir,
         dtype=args.dtype,
         config_path=args.config_path,
-        torch_dir=args.torch_dir,
+        hf_model_path=args.hf_model_path,
         lm_head_precision=args.lm_head,
         dataset_dir=args.dataset_dir,
         wrapper_cls=WrapperModelForCausalLM,
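
For reference, the resolution order documented by the updated get_config_path (config_path, then hf_model_path/config.json, then config.json next to the ONNX file) can be sketched standalone. The SimpleNamespace below is a hypothetical stand-in for the parsed CLI arguments, not part of the patch:

from types import SimpleNamespace
import os

# Hypothetical stand-in for the argparse.Namespace produced by llm_arguments().
args = SimpleNamespace(config_path=None, hf_model_path="Qwen/Qwen3-0.6B", onnx_path=None)

if args.config_path and os.path.exists(args.config_path):
    source = "explicit --config_path"
elif args.hf_model_path and os.path.isdir(args.hf_model_path):
    source = "config.json inside the local checkpoint folder"
elif args.hf_model_path:
    source = "config downloaded from the HuggingFace Hub into a temp file"
elif args.onnx_path:
    source = "config.json next to the existing ONNX file"
else:
    source = "no config available"

print(source)  # -> "config downloaded from the HuggingFace Hub into a temp file"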

modelopt/onnx/llm_export_utils/export_utils.py

Lines changed: 15 additions & 15 deletions
@@ -21,7 +21,7 @@
 from enum import Enum
 
 import torch
-from transformers import DynamicCache
+from transformers import AutoModelForCausalLM, DynamicCache
 
 
 class RopeType(Enum):

@@ -36,10 +36,10 @@ class RopeType(Enum):
 class ModelLoader:
     """A class to handle HuggingFace model loading and configuration."""
 
-    def __init__(self, torch_dir, config_path):
+    def __init__(self, hf_model_path: str, config_path: str):
         """Initialize the ModelLoader."""
         self.config_path = config_path
-        self.torch_dir = torch_dir
+        self.hf_model_path = hf_model_path
         self.model_type = self.get_model_type()
         self.hf_model = None
         self.rope_type = RopeType.K_ROPE_ROTATE_NEOX

@@ -49,16 +49,14 @@ def get_model_type(self):
         with open(self.config_path) as f:
             return json.load(f).get("model_type")
 
-    def load_model(self):
+    def load_model(self, trust_remote_code: bool = False) -> AutoModelForCausalLM:
         """Load HuggingFace model based on model type."""
-        print(f"Loading HF model from {self.torch_dir} with model type {self.model_type}")
-        from transformers import AutoModelForCausalLM
-
+        print(f"Loading HF model from {self.hf_model_path} with model type {self.model_type}")
         self.hf_model = AutoModelForCausalLM.from_pretrained(
-            self.torch_dir, torch_dtype=torch.float16, trust_remote_code=True
+            self.hf_model_path, torch_dtype=torch.float16, trust_remote_code=trust_remote_code
         )
 
-        return self.hf_model.eval().cuda()
+        return self.hf_model.eval().cuda()  # type: ignore[attr-defined]
 
     def get_rope_type(self):
         """Get rope type."""

@@ -78,13 +76,14 @@ def __init__(self, model):
         self.lm_head = model.lm_head
         self.config = model.config
 
-    def forward(
-        self,
-        input_ids,
-        past_key_values,
-    ):
+    def forward(self, input_ids: torch.Tensor | None, past_key_values: tuple):
         """Forward pass."""
-        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        # Convert tuple cache to DynamicCache for models that require it (e.g., Qwen3)
+        cache = DynamicCache(config=self.config)
+        cache.key_cache = [kv[0] for kv in past_key_values]
+        cache.value_cache = [kv[1] for kv in past_key_values]
+        past_key_values = cache
+
         outputs = self.model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
         hidden_states = outputs[0]
         past_key_values = outputs.past_key_values.to_legacy_cache()

@@ -159,4 +158,5 @@ def torch_to_onnx(model, inputs, onnx_dir, onnx_name, input_names, output_names,
         dynamic_axes=dynamic_axes,
         opset_version=19,
         do_constant_folding=True,
+        dynamo=False,
     )
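
The forward-pass change above (building a DynamicCache by hand instead of calling DynamicCache.from_legacy_cache) is the piece that lets Qwen3 run under the export wrapper. A minimal standalone sketch of the same conversion on dummy tensors follows; the shapes are arbitrary, the no-argument DynamicCache() constructor is a simplification (the patch passes config=self.config), and DynamicCache internals can differ between transformers releases:

import torch
from transformers import DynamicCache

# Arbitrary toy dimensions for the sketch.
num_layers, batch, heads, seq, head_dim = 2, 1, 4, 8, 16

# Legacy cache layout: one (key, value) tensor pair per decoder layer.
legacy_cache = tuple(
    (torch.zeros(batch, heads, seq, head_dim), torch.zeros(batch, heads, seq, head_dim))
    for _ in range(num_layers)
)

# Mirror the patch: populate the cache lists directly from the legacy tuple.
cache = DynamicCache()
cache.key_cache = [kv[0] for kv in legacy_cache]
cache.value_cache = [kv[1] for kv in legacy_cache]

print(len(cache.key_cache), cache.key_cache[0].shape)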

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ disable_error_code = ["attr-defined"]
 # Default additional options
 # Show a short test summary info for all except passed tests with -ra flag
 # print execution time for 20 slowest tests and generate coverage reports
-addopts = "-ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers"
+addopts = "-v -ra --instafail --cov-report=term-missing --cov-report=html --cov-report=xml:coverage.xml --cov-config=pyproject.toml --durations=20 --strict-markers"
 pythonpath = ["tests/"]
 markers = [
     "manual: Only run when --run-manual is given",

tests/_test_utils/examples/run_command.py

Lines changed: 0 additions & 16 deletions
@@ -62,22 +62,6 @@ def run_command_in_background(
     return process
 
 
-def run_onnx_llm_export_command(
-    *, torch_dir: str, dtype: str, lm_head: str, output_dir: str, calib_size: str, **kwargs
-):
-    kwargs.update(
-        {
-            "torch_dir": torch_dir,
-            "dtype": dtype,
-            "lm_head": lm_head,
-            "output_dir": output_dir,
-            "calib_size": calib_size,
-        }
-    )
-    cmd_parts = extend_cmd_parts(["python", "llm_export.py"], **kwargs)
-    run_example_command(cmd_parts, "onnx_ptq")
-
-
 def run_llm_ptq_command(*, model: str, quant: str, **kwargs):
     kwargs.update({"model": model, "quant": quant})
     kwargs.setdefault("tasks", "quant")

tests/examples/onnx_ptq/test_llm_export.py

Lines changed: 13 additions & 11 deletions
@@ -15,23 +15,25 @@
 
 
 import pytest
-from _test_utils.examples.run_command import run_onnx_llm_export_command
+from _test_utils.examples.run_command import extend_cmd_parts, run_example_command
 
 
 @pytest.mark.parametrize(
-    ("torch_dir", "dtype", "lm_head", "output_dir", "calib_size"),
+    ("hf_model_path", "dtype", "lm_head"),
     [
-        ("Qwen/Qwen2-0.5B-Instruct", "fp16", "fp16", "/tmp/qwen2-0.5b-instruct-fp16", "1"),
-        ("Qwen/Qwen2-0.5B-Instruct", "fp8", "fp16", "/tmp/qwen2-0.5b-instruct-fp8", "1"),
-        ("Qwen/Qwen2-0.5B-Instruct", "int4_awq", "fp16", "/tmp/qwen2-0.5b-instruct-int4_awq", "1"),
-        ("Qwen/Qwen2-0.5B-Instruct", "nvfp4", "fp16", "/tmp/qwen2-0.5b-instruct-nvfp4", "1"),
+        ("Qwen/Qwen2-0.5B-Instruct", "fp16", "fp16"),
+        ("Qwen/Qwen2-0.5B-Instruct", "fp8", "fp16"),
+        ("Qwen/Qwen3-0.6B", "int4_awq", "fp16"),
+        ("Qwen/Qwen3-0.6B", "nvfp4", "fp16"),
     ],
 )
-def test_llm_export_onnx(torch_dir, dtype, lm_head, output_dir, calib_size):
-    run_onnx_llm_export_command(
-        torch_dir=torch_dir,
+def test_llm_export_onnx(tmp_path, hf_model_path, dtype, lm_head):
+    cmd_parts = extend_cmd_parts(
+        ["python", "llm_export.py"],
+        hf_model_path=hf_model_path,
         dtype=dtype,
         lm_head=lm_head,
-        output_dir=output_dir,
-        calib_size=calib_size,
+        output_dir=str(tmp_path),
+        calib_size=1,
     )
+    run_example_command(cmd_parts, "onnx_ptq")
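
The test now builds the command inline instead of going through the deleted run_onnx_llm_export_command helper. extend_cmd_parts itself lives in _test_utils and is not shown in this diff; assuming it appends each keyword argument as a "--key value" pair, the nvfp4 case would expand roughly as below (the helper here is a hypothetical stand-in, not the real implementation):

def extend_cmd_parts_sketch(cmd_parts: list[str], **kwargs) -> list[str]:
    # Hypothetical stand-in for _test_utils.examples.run_command.extend_cmd_parts:
    # append each keyword argument as a "--key value" pair.
    for key, value in kwargs.items():
        cmd_parts += [f"--{key}", str(value)]
    return cmd_parts

print(" ".join(extend_cmd_parts_sketch(
    ["python", "llm_export.py"],
    hf_model_path="Qwen/Qwen3-0.6B",
    dtype="nvfp4",
    lm_head="fp16",
    output_dir="/tmp/test-output",  # the real test passes pytest's tmp_path
    calib_size=1,
)))
# -> python llm_export.py --hf_model_path Qwen/Qwen3-0.6B --dtype nvfp4 --lm_head fp16 --output_dir /tmp/test-output --calib_size 1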
