
Commit 916e1b5

ynankani authored and kevalmorabia97 committed
[5620217] Mixed-precision handle 8bit layer name matching error (#535)
## What does this PR do?

Handle the 8-bit layer name matching error when running with a mixed-precision config.

**Type of change:** Bug fix

**Overview:** Due to variations in export methods, the model's `weight_tensor.name` may appear either as an ID or as a name, for example `onnx::MatMul_9335` or `model.layers.2.attn.qkv_proj.MatMul.weight`. The comparison of the 8-bit layer list with node names must be adjusted to handle this variation.

## Testing

- Tested using mixed_int4_experiment.py
- Executed with the downloaded model from onnx-community/Qwen2.5-1.5B-Instruct
- Also tested using the onnxruntime-genai exported model from meta-llama/Llama-3.1-8B-Instruct

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

Signed-off-by: unknown <[email protected]>
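To make the failure mode concrete, here is a minimal sketch of the name mismatch the PR fixes (the layer set and names below are illustrative, not taken from the repository):

```python
# The 8-bit layer list is expressed in terms of *node* names.
layers_8bit = {"/model/layers.2/attn/qkv_proj/MatMul"}

# Depending on the exporter, the weight initializer name is either an
# opaque ID or a readable path:
weight_name_id = "onnx::MatMul_9335"
weight_name_readable = "model.layers.2.attn.qkv_proj.MatMul.weight"

# The MatMul node's own name is stable across both export styles:
node_name = "/model/layers.2/attn/qkv_proj/MatMul"

print(weight_name_id in layers_8bit)  # False: an ID can never match a layer pattern
print(node_name in layers_8bit)       # True: matching on node names is robust
```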
1 parent 957ce07 commit 916e1b5

File tree: 2 files changed (+29 −22 lines changed)

modelopt/onnx/quantization/graph_utils.py

Lines changed: 13 additions & 6 deletions
@@ -642,8 +642,12 @@ def _find_nodes_from_op_types_to_exclude(graph: Graph, op_types_to_exclude=None)
 def _find_int4_quantizable_weights(
     graph: onnx.GraphProto,
     nodes_to_exclude: list[str],
-) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int]]:
-    """Finds the int4 quantizable weights from the graph."""
+) -> list[tuple[onnx.ValueInfoProto, onnx.ValueInfoProto, bool, int, str]]:
+    """Finds the int4 quantizable weights from the graph.
+
+    Returns:
+        list of tuples: (act_tensor, weight_tensor, do_transpose, gemm_io_type, node_name)
+    """
     wa_pack = []
     gemm_nodes = [
         node
@@ -674,7 +678,8 @@ def _find_int4_quantizable_weights(
             attr.name == "transB" and attr.i > 0 for attr in gemm.attribute
         )
 
-        wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type))
+        # Include node name for proper matching with layers_8bit_set
+        wa_pack.append((act_tensor, weight_tensor, do_transpose, gemm_io_type, gemm.name))
 
     return wa_pack
 
@@ -762,6 +767,8 @@ def get_layer_precision_mapping(
     pattern_regexes = [
         re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
         re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
+        re.compile(r"^/model/layers\.(\d+)/self_attn/qkv_proj/MatMul$"),
+        re.compile(r"^/model/layers\.(\d+)/self_attn/v_proj/MatMul$"),
         re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
     ]
 
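As a quick illustration of why the two added `self_attn` patterns matter, the snippet below checks one layer path from each naming convention against the updated regex list (the paths are invented examples; exporters differ on whether the attention block is named `attn` or `self_attn`):

```python
import re

pattern_regexes = [
    re.compile(r"^/model/layers\.(\d+)/attn/qkv_proj/MatMul$"),
    re.compile(r"^/model/layers\.(\d+)/attn/v_proj/MatMul$"),
    re.compile(r"^/model/layers\.(\d+)/self_attn/qkv_proj/MatMul$"),
    re.compile(r"^/model/layers\.(\d+)/self_attn/v_proj/MatMul$"),
    re.compile(r"^/model/layers\.(\d+)/mlp/down_proj/MatMul$"),
]

for name in (
    "/model/layers.0/attn/qkv_proj/MatMul",       # "attn"-style export
    "/model/layers.0/self_attn/qkv_proj/MatMul",  # "self_attn"-style export
):
    # Both print True once the self_attn patterns are included
    print(name, any(p.match(name) for p in pattern_regexes))
```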
@@ -812,12 +819,12 @@ def layer_idx(name):
             if (i - rest_start) % 3 == 0:
                 layers_8bit_set.add(names_sorted[i])
     layers_list_8bit = list(layers_8bit_set)
-
     # NEW: Create layer info mapping with precision, block_size, and axis
     layer_info = {}
-    for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type) in enumerate(wa_pack):
+    for i, (act_tensor, weight_tensor, do_transpose, gemm_io_type, node_name) in enumerate(wa_pack):
         weight_name = weight_tensor.name
-        if should_quantize_to_8bit(weight_name, layers_list_8bit):
+        # Use node_name for matching against layers_8bit patterns
+        if should_quantize_to_8bit(node_name, layers_list_8bit):
             layer_info[weight_name] = {
                 "precision": 8,
                 "block_size": -1,  # Per-channel for 8-bit

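After this change, `layer_info` is still keyed by the weight tensor name (so downstream code is unaffected), while membership in the 8-bit set is decided by the node name. An illustrative entry, with invented names and only the fields visible in this diff:

```python
layer_info = {
    "model.layers.2.attn.qkv_proj.MatMul.weight": {
        "precision": 8,
        "block_size": -1,  # per-channel for 8-bit
    },
    # weights whose node names miss the 8-bit patterns keep the int4 path
}
```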
modelopt/onnx/quantization/int4.py

Lines changed: 16 additions & 16 deletions
@@ -445,11 +445,11 @@ def _clip_search(
 
 def _augment_graph(
     graph: onnx.GraphProto,
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
 ):
     """Extend graph outputs with MatMuls activation input."""
     augmented_outputs = {tensor.name for tensor in graph.output}
-    for act_tensor, _, _, _ in wa_pack:
+    for act_tensor, _, _, _, _ in wa_pack:
         if act_tensor.name not in augmented_outputs:
             graph.output.append(act_tensor)
             augmented_outputs.add(act_tensor.name)
@@ -522,7 +522,7 @@ def _quantize_awq_clip(
     t = time.time()
     alphas = {}
     for i in tqdm(range(len(wa_pack)), desc="Running clip search..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
 
         # First capture all the activation values after calibration data sweep
         output_dicts = {}
@@ -554,7 +554,7 @@ def _quantize_awq_clip(
     # Compute quantized weights and scales which are needed for DQ nodes
     t = time.time()
     for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
         gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
 
         if force_fp16:
@@ -707,7 +707,7 @@ def get_scale(x_max, w_max, alpha):
 
 
 def run_awq_scale_search_per_node(
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
     augmented_onnx_path,
     block_size,
     use_zero_point,
@@ -728,7 +728,7 @@ def run_awq_scale_search_per_node(
         range(len(wa_pack)),
         desc="Running AWQ scale search per node" + tqdm_msg_append_str,
     ):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
 
         output_dicts = {}
 
@@ -802,7 +802,7 @@
 
 
 def get_act_to_weight_map_and_act_to_wa_pack_map(
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
 ):
     """Method to return subgraph related maps based on activation-name as key.
 
@@ -813,7 +813,7 @@ def get_act_to_weight_map_and_act_to_wa_pack_map(
     act_to_wa_pack_map = {}
     act_to_quant_nodes_weight_shape_map = {}
     for i in tqdm(range(len(wa_pack)), desc="Getting activation names maps..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
         # wa_pack index is stored in map to represent quant nodes
         act_to_wa_pack_map.setdefault(act_tensor.name, []).append(i)
         act_to_quant_nodes_weight_shape_map.setdefault(act_tensor.name, []).append(
@@ -828,7 +828,7 @@
 
 
 def get_x_w_mean_for_subgraph(
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
     wa_pack_idx_list,
     augmented_onnx_path,
     x,
@@ -842,7 +842,7 @@ def get_x_w_mean_for_subgraph(
 
     w_concatenated = None
     for wa_pack_idx in wa_pack_idx_list:
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[wa_pack_idx]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[wa_pack_idx]
         w = numpy_helper.to_array(
             weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
         ).copy()
@@ -880,7 +880,7 @@
 
 
 def run_awq_scale_search_per_subgraph(
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
    act_to_wa_pack_map,
    act_to_quant_nodes_weight_shape_map,
    augmented_onnx_path,
@@ -931,7 +931,7 @@ def run_awq_scale_search_per_subgraph(
         awq_scale[np.isinf(awq_scale)] = 1
         awq_scale[np.isnan(awq_scale)] = 1
         for wa_pack_idx in wa_pack_idx_list:
-            _, weight_tensor, do_transpose, _ = wa_pack[wa_pack_idx]
+            _, weight_tensor, do_transpose, _, _ = wa_pack[wa_pack_idx]
             w = numpy_helper.to_array(
                 weight_tensor, base_dir=os.path.dirname(augmented_onnx_path)
             ).copy()
@@ -975,15 +975,15 @@
 
 def get_parent_child_nodes_map(
     graph: onnx.GraphProto,
-    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int]],
+    wa_pack: list[tuple[gs.Tensor, gs.Tensor, bool, int, str]],
     nodes_to_exclude: list[str],
 ):
     """Get mapping of parent nodes to their MatMul/Gemm nodes with quantizable weights."""
     parent_child_nodes_map = {}
     output_name_to_node = get_tensor_producer_nodes(graph)
     input_name_to_nodes = get_tensor_consumer_nodes(graph)
 
-    for act_tensor, _, _, _ in wa_pack:
+    for act_tensor, _, _, _, _ in wa_pack:
         parent_name = output_name_to_node[act_tensor.name].name
         parent_child_nodes_map[parent_name] = []
         for node in input_name_to_nodes[act_tensor.name]:
@@ -1069,7 +1069,7 @@ def _quantize_awq_lite(
 
     tensor_names_list = []
     for i in tqdm(range(len(wa_pack)), desc="Getting tensor names..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
         tensor_names_list.append(act_tensor.name)
 
     for i in tqdm(range(len(inputs)), desc="Caching activations..."):
@@ -1157,7 +1157,7 @@
             awq_lite[wa_pack_idx].best_scale = mean_awq_scale
 
     for i in tqdm(range(len(wa_pack)), desc="Quantizing the weights..."):
-        act_tensor, weight_tensor, do_transpose, gemm_io_type = wa_pack[i]
+        act_tensor, weight_tensor, do_transpose, gemm_io_type, _ = wa_pack[i]
         gemm_io_type = cast("onnx.TensorProto.DataType", gemm_io_type)
 
         if force_fp16:
