Commit 02fb06d
[NVBUG_5703882] Add INT4QuantExporter to llm_export.py (NVIDIA#631)
## What does this PR do?

**Type of change:** Bug Fix

**Overview:**
- Added `INT4QuantExporter` to the `llm_export.py` example
- Added an E2E integration test for `llm_export.py`

## Testing

```
python llm_export.py --torch_dir=Qwen/Qwen2-0.5B-Instruct --dtype=fp8 --lm_head=fp16 --output_dir=./qwen2-0.5B-Instruct --calib_size=64 --trust_remote_code
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

---------

Signed-off-by: ajrasane <[email protected]>
1 parent 255eb1a commit 02fb06d

4 files changed (+47 −4 lines)

examples/onnx_ptq/llm_export.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -30,14 +30,15 @@
 from transformers import AutoConfig, AutoTokenizer

 import modelopt
+from modelopt.onnx.export import INT4QuantExporter
 from modelopt.onnx.llm_export_utils.export_utils import (
     ModelLoader,
     WrapperModelForCausalLM,
     llm_to_onnx,
 )
 from modelopt.onnx.llm_export_utils.quantization_utils import quantize
 from modelopt.onnx.llm_export_utils.surgeon_utils import fold_fp8_qdq_to_dq
-from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq, quantize_weights_to_int4
+from modelopt.onnx.quantization.qdq_utils import fp4qdq_to_2dq
 from modelopt.torch.export import export_hf_checkpoint
 from modelopt.torch.quantization.utils import is_quantized_linear

@@ -278,7 +279,7 @@ def time_operation(operation_name):

     elif dtype == "int4_awq":
         with time_operation("quantizing weights to int4"):
-            onnx_model = quantize_weights_to_int4(onnx_model)
+            onnx_model = INT4QuantExporter.process_model(onnx_model)

     output_onnx_name = f"{output_dir}/model.onnx"
     print(
```

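The key change here swaps the free function `quantize_weights_to_int4` for the exporter entry point `INT4QuantExporter.process_model`. A minimal sketch of the new call path, assuming an ONNX model with INT4-AWQ Q/DQ nodes already saved to disk (the file paths are hypothetical; `llm_export.py` itself holds the model in memory):

```python
import onnx

from modelopt.onnx.export import INT4QuantExporter

# Hypothetical path: in llm_export.py the model arrives here in memory.
onnx_model = onnx.load("qwen2-0.5B-Instruct/model.onnx")

# Replaces the former quantize_weights_to_int4(onnx_model) call: the exporter
# class post-processes the graph so weights are stored in INT4 form.
onnx_model = INT4QuantExporter.process_model(onnx_model)

onnx.save(onnx_model, "qwen2-0.5B-Instruct/model_int4.onnx")
```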
modelopt/onnx/export/int4_exporter.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -35,7 +35,7 @@ def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     graph = onnx_model.graph
     value_info_map = {value_info.name: value_info for value_info in graph.value_info}
     weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
-    tensor_producer_map = get_tensor_producer_nodes(graph)
+    tensor_producer_map = get_tensor_producer_nodes(graph, get_initializer_producers=True)

     nodes_to_remove = []
     for node in weight_dq_nodes:
@@ -126,7 +126,7 @@ def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
     graph = onnx_model.graph
     initializer_map = {initializer.name: initializer for initializer in graph.initializer}
     weight_dq_nodes = [node for node in graph.node if node.op_type == "DequantizeLinear"]
-    tensor_producer_map = get_tensor_producer_nodes(graph)
+    tensor_producer_map = get_tensor_producer_nodes(graph, get_initializer_producers=True)

     for node in weight_dq_nodes:
         weight_name = node.input[0]
```

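Both `pre_process` and `compute_scales` walk the `DequantizeLinear` nodes and look up the producer of each node's weight input. Quantized weights are usually stored as graph initializers, which have no producing node, so both lookups now pass the new `get_initializer_producers=True` flag. A small illustration of the difference, assuming a quantized model file (path hypothetical):

```python
import onnx

from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes

graph = onnx.load("model_int4.onnx").graph  # hypothetical path

# Default: only tensors produced by nodes appear in the map, so a weight
# initializer feeding a DequantizeLinear node has no entry.
producers = get_tensor_producer_nodes(graph)

# New flag: initializer names now map to their TensorProto objects, letting
# the INT4 exporter resolve DequantizeLinear weight inputs uniformly.
producers = get_tensor_producer_nodes(graph, get_initializer_producers=True)
```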
modelopt/onnx/quantization/graph_utils.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -236,6 +236,7 @@ def get_tensor_from_name(graph: onnx.GraphProto, tensor_name: str) -> onnx.Value

 def get_tensor_producer_nodes(
     graph: onnx.GraphProto,
+    get_initializer_producers: bool = False,
 ) -> dict[str, onnx.NodeProto]:
     """Returns a dictionary of tensor name and their producer node object mapping.

@@ -272,6 +273,10 @@
         for output_name in node.output:
             tensor_producers[output_name] = node

+    if get_initializer_producers:
+        for initializer in graph.initializer:
+            tensor_producers[initializer.name] = initializer
+
     return tensor_producers

```

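One caveat worth noting: with the flag enabled, the returned dictionary mixes value types. Node outputs still map to `onnx.NodeProto`, while initializer names map to `onnx.TensorProto` (the declared return annotation is unchanged). Callers that inspect a producer should therefore check its type first, e.g. (continuing the sketch above; the tensor name is hypothetical):

```python
import onnx
from onnx import numpy_helper

# "producers" comes from the get_tensor_producer_nodes sketch above.
producer = producers["model.layers.0.mlp.up_proj.weight"]

if isinstance(producer, onnx.TensorProto):
    # Initializer entry: the weight data can be read directly.
    weight = numpy_helper.to_array(producer)
else:
    # Node entry: inspect the producing op instead.
    print(producer.op_type)
```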
New file: E2E integration test for llm_export.py

Lines changed: 37 additions & 0 deletions
```diff
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import pytest
+from _test_utils.examples.run_command import run_onnx_llm_export_command
+
+
+@pytest.mark.parametrize(
+    ("torch_dir", "dtype", "lm_head", "output_dir", "calib_size"),
+    [
+        ("Qwen/Qwen2-0.5B-Instruct", "fp16", "fp16", "/tmp/qwen2-0.5b-instruct-fp16", "1"),
+        ("Qwen/Qwen2-0.5B-Instruct", "fp8", "fp16", "/tmp/qwen2-0.5b-instruct-fp8", "1"),
+        ("Qwen/Qwen2-0.5B-Instruct", "int4_awq", "fp16", "/tmp/qwen2-0.5b-instruct-int4_awq", "1"),
+        ("Qwen/Qwen2-0.5B-Instruct", "nvfp4", "fp16", "/tmp/qwen2-0.5b-instruct-nvfp4", "1"),
+    ],
+)
+def test_llm_export_onnx(torch_dir, dtype, lm_head, output_dir, calib_size):
+    run_onnx_llm_export_command(
+        torch_dir=torch_dir,
+        dtype=dtype,
+        lm_head=lm_head,
+        output_dir=output_dir,
+        calib_size=calib_size,
+    )
```
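The four cases mirror the dtype branches in `llm_export.py` (`fp16`, `fp8`, `int4_awq`, `nvfp4`), each with `calib_size=1`. A single case can also be driven directly through the same test utility, assuming the repository's test helpers are importable (a sketch, not part of the commit):

```python
from _test_utils.examples.run_command import run_onnx_llm_export_command

# Same arguments as the int4_awq parametrization above; the helper is assumed
# to invoke examples/onnx_ptq/llm_export.py end to end.
run_onnx_llm_export_command(
    torch_dir="Qwen/Qwen2-0.5B-Instruct",
    dtype="int4_awq",
    lm_head="fp16",
    output_dir="/tmp/qwen2-0.5b-instruct-int4_awq",
    calib_size="1",
)
```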
