From aa0582e3c82026630f9912b55d5165f2d7ff6922 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Thu, 26 Feb 2026 11:19:30 -0800 Subject: [PATCH 01/23] Add Qwen3-VL / Qwen2.5-VL ONNX export support - graph_surgeries.py: add QwenVL-specific graph surgery passes for vision embedding merge and positional encoding fixup - rtn_quantization.py: extend RTN quantization for multimodal models, handle vision encoder exclusion patterns - cast_chain_elimination.py: new pass to eliminate redundant Cast chains in Dynamo-exported models (fp32->fp16->fp32 patterns) - olive_config.json: register new passes --- olive/olive_config.json | 8 ++ olive/passes/onnx/cast_chain_elimination.py | 114 ++++++++++++++++++++ olive/passes/onnx/graph_surgeries.py | 83 ++++++++++++++ olive/passes/onnx/rtn_quantization.py | 69 ++++++++++-- 4 files changed, 267 insertions(+), 7 deletions(-) create mode 100644 olive/passes/onnx/cast_chain_elimination.py diff --git a/olive/olive_config.json b/olive/olive_config.json index 77debc1818..4afbfcdc48 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -315,6 +315,14 @@ "supported_quantization_encodings": [ ], "extra_dependencies": [ "onnxoptimizer", "onnxscript" ] }, + "OnnxCastChainElimination": { + "module_path": "olive.passes.onnx.cast_chain_elimination.OnnxCastChainElimination", + "supported_providers": [ "*" ], + "supported_accelerators": [ "*" ], + "supported_precisions": [ "*" ], + "supported_algorithms": [ ], + "supported_quantization_encodings": [ ] + }, "OnnxQuantization": { "module_path": "olive.passes.onnx.quantization.OnnxQuantization", "supported_providers": [ "CPUExecutionProvider" ], diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py new file mode 100644 index 0000000000..7eb65ceead --- /dev/null +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -0,0 +1,114 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""ORT-based Cast chain elimination and com.microsoft opset fixup. + +Dynamo-exported ONNX models often contain redundant Cast chains +(e.g. fp32→fp16→fp32) that double memory traffic and slow inference. +ORT has a graph optimization for this, but it is disabled by default +behind the ``session.enable_cast_chain_elimination`` session config. + +Additionally, Olive ``GraphSurgeries`` may insert ``com.microsoft`` +operators (such as ``LoopMHA``) without registering the custom opset +on every ONNX function scope, causing downstream failures. + +This pass: +1. Ensures ``com.microsoft`` opset version 1 is declared on the model + *and* on every ONNX function scope. +2. Runs ORT ``ORT_ENABLE_BASIC`` optimization with Cast chain + elimination explicitly enabled to produce a cleaned-up model. +""" +import logging +from pathlib import Path + +import onnx +from onnx import helper + +from olive.hardware.accelerator import AcceleratorSpec +from olive.model import ONNXModelHandler +from olive.model.utils import resolve_onnx_path +from olive.passes import Pass +from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +logger = logging.getLogger(__name__) + + +def _ensure_com_microsoft_opset(model: onnx.ModelProto): + """Ensure com.microsoft opset v1 is declared at model and function level.""" + existing = {op.domain for op in model.opset_import} + if "com.microsoft" not in existing: + model.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + for func in model.functions: + func_domains = {op.domain for op in func.opset_import} + if "com.microsoft" not in func_domains: + func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + + +class OnnxCastChainElimination(Pass): + """Fix com.microsoft opset declarations and eliminate redundant Cast chains. + + This pass first ensures the ``com.microsoft`` opset version 1 is + registered on the model graph and every ONNX function scope (needed + after ``GraphSurgeries`` inserts custom ops into dynamo-exported + models). It then runs ORT basic graph optimization with the + ``session.enable_cast_chain_elimination`` flag enabled to collapse + consecutive Cast operators that cancel out. + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "fix_opset": PassConfigParam( + type_=bool, + default_value=True, + description="Ensure com.microsoft opset v1 is declared on all scopes.", + ), + "enable_cast_chain_elimination": PassConfigParam( + type_=bool, + default_value=True, + description="Run ORT basic optimization with Cast chain elimination.", + ), + **get_external_data_config(), + } + + def _run_for_config( + self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str + ) -> ONNXModelHandler: + output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) + + onnx_model = model.load_model() + + # Step 1: Opset fixup + if config.fix_opset: + _ensure_com_microsoft_opset(onnx_model) + + # Step 2: Cast chain elimination via ORT session optimization + if config.enable_cast_chain_elimination: + import tempfile + + import onnxruntime as ort + + # ORT needs the patched model on disk to optimise it. + # Large models (>2 GB) must use external data to avoid protobuf limits. + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = str(Path(tmp_dir) / "model.onnx") + onnx.save_model( + onnx_model, + tmp_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location="model.onnx.data", + convert_attribute=True, + ) + + sess_options = ort.SessionOptions() + sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC + sess_options.optimized_model_filepath = tmp_path + sess_options.add_session_config_entry("session.enable_cast_chain_elimination", "1") + ort.InferenceSession(tmp_path, sess_options, providers=["CPUExecutionProvider"]) + + onnx_model = onnx.load(tmp_path, load_external_data=True) + + return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 7d1c41119f..f0b06bc2a0 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -1233,6 +1233,89 @@ def __call__(self, model: ModelProto): return dag.model +class GemmToMatMulAdd(ProtoSurgeon): + """Replace Gemm with MatMul (+ Add) for INT4 quantization compatibility. + + The INT4 RTN quantizer only recognizes MatMul nodes. This surgeon converts + Gemm nodes back to MatMul+Add so that the weight matrices become eligible + for block-wise quantization. + + Handles transB by transposing constant weights in-place or inserting a + Transpose node for non-constant weights. Skips Gemm nodes whose alpha/beta + are not 1.0 or whose transA is set. + """ + + def __call__(self, model: ModelProto): + from onnx import helper, numpy_helper + + graph = model.graph + initializer_map = {init.name: init for init in graph.initializer} + nodes_to_remove = [] + nodes_to_add = [] + + for node in graph.node: + if node.op_type != "Gemm": + continue + + alpha = beta = 1.0 + transA = transB = 0 + for attr in node.attribute: + if attr.name == "alpha": + alpha = attr.f + elif attr.name == "beta": + beta = attr.f + elif attr.name == "transA": + transA = attr.i + elif attr.name == "transB": + transB = attr.i + + if alpha != 1.0 or beta != 1.0 or transA != 0: + continue + + A, B = node.input[0], node.input[1] + C = node.input[2] if len(node.input) > 2 else None + Y = node.output[0] + + if transB: + if B in initializer_map: + init = initializer_map[B] + w_t = numpy_helper.to_array(init).T.copy() + new_init = numpy_helper.from_array(w_t, name=B) + for i, existing in enumerate(graph.initializer): + if existing.name == B: + graph.initializer[i].CopyFrom(new_init) + break + matmul_B = B + else: + transpose_out = f"{node.name}_transpose_B" + nodes_to_add.append( + helper.make_node("Transpose", [B], [transpose_out], name=f"{node.name}_Transpose", perm=[1, 0]) + ) + matmul_B = transpose_out + else: + matmul_B = B + + if C: + matmul_out = f"{node.name}_matmul_out" + nodes_to_add.append( + helper.make_node("MatMul", [A, matmul_B], [matmul_out], name=f"{node.name}_MatMul") + ) + nodes_to_add.append(helper.make_node("Add", [matmul_out, C], [Y], name=f"{node.name}_Add")) + else: + nodes_to_add.append(helper.make_node("MatMul", [A, matmul_B], [Y], name=f"{node.name}_MatMul")) + + nodes_to_remove.append(node) + + for node in nodes_to_remove: + graph.node.remove(node) + graph.node.extend(nodes_to_add) + + if nodes_to_remove: + logger.debug("Replaced %d Gemm nodes with MatMul + Add nodes", len(nodes_to_remove)) + + return model + + class RemoveRopeMultiCache(ProtoSurgeon): """Remove the multi rope cache from the model.""" diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py index d42bbcc229..47a04b82ab 100644 --- a/olive/passes/onnx/rtn_quantization.py +++ b/olive/passes/onnx/rtn_quantization.py @@ -104,6 +104,10 @@ def _quantize_model( nodes_to_exclude = nodes_to_exclude or [] nodes_to_include = nodes_to_include or [] + # Track initializer names already registered across all nodes + # to handle shared weights (e.g., pos_embed used by multiple Gather nodes). + globally_registered = {} + ir_model.graph.sort() for node in ir_model.graph.all_nodes(): node_name = node.name @@ -115,14 +119,19 @@ def _quantize_model( elif node.op_type in (str(OpType.MatMul), str(OpType.Gather)) and ( node_name in nodes_to_include or not nodes_to_include ): - if (node.op_type == OpType.MatMul and not node.inputs[1].is_initializer()) or ( - node.op_type == OpType.Gather and not node.inputs[0].is_initializer() - ): + # MatMul weight is inputs[1], Gather weight (embedding table) is inputs[0] + weight_idx = 1 if node.op_type == str(OpType.MatMul) else 0 + if not node.inputs[weight_idx].is_initializer(): logger.debug("skip to quantize %s as it has no initializer", node_name) continue - if node.op_type == OpType.Gather and bits != 4: - logger.warning("Gather only supports 4-bit quantization.") + if node.op_type == str(OpType.Gather) and bits != 4: + logger.warning( + "Gather quantization is currently only implemented for 4-bit. " + "(GatherBlockQuantized op supports 4 or 8 bits, but the 8-bit path is not yet " + "implemented here.) Skip node %s.", + node_name, + ) continue quantized_node, initializer_graph = self._quantize( @@ -133,9 +142,16 @@ def _quantize_model( registered = {} for input_value in quantized_node.inputs: if input_value.const_value is not None: - if input_value.name not in registered: + if input_value.name in globally_registered: + # Already registered by a previous node (shared weight), + # replace with the existing initializer. + ir.convenience.replace_all_uses_with( + input_value, globally_registered[input_value.name] + ) + elif input_value.name not in registered: initializer_graph.register_initializer(input_value) registered[input_value.name] = input_value + globally_registered[input_value.name] = input_value else: logger.debug( "Found duplicated initializer %s, replace all uses with the first one.", @@ -149,6 +165,25 @@ def _quantize_model( else: logger.debug("skip to quantize %s ...", node_name) + # Remove initializers that are no longer referenced by any node. + # After quantization, the original FP32 weight initializers become orphaned + # because the old MatMul/Gather nodes were replaced with MatMulNBits/GatherBlockQuantized + # nodes that reference new INT4 weight initializers instead. + used_names: set[str] = set() + for node in ir_model.graph.all_nodes(): + for inp in node.inputs: + if inp is not None and inp.name: + used_names.add(inp.name) + for out in ir_model.graph.outputs: + if out is not None and out.name: + used_names.add(out.name) + + unused = [name for name in ir_model.graph.initializers if name not in used_names] + for name in unused: + del ir_model.graph.initializers[name] + if unused: + logger.info("Removed %d unused initializers after quantization.", len(unused)) + def _quantize( self, node: ir.Node, bits: int, block_size: int, axis: int, accuracy_level: AccuracyLevel, is_symmetric: bool ) -> tuple[ir.Node, ir.Graph]: @@ -221,7 +256,10 @@ def _quantize_gather( node_initializer = node.inputs[0] data_ndarray = node_initializer.const_value.numpy() data_rank = len(data_ndarray.shape) - quantize_axis = axis + + # ORT GatherBlockQuantized requires quantize_axis == last dimension + # when data is packed as uint8 (two int4 values per byte). + quantize_axis = data_rank - 1 assert -data_rank <= quantize_axis < data_rank, "Invalid quantize axis for Gather node." assert block_size >= 16, "Block size must be greater than or equal to 16." @@ -250,6 +288,7 @@ def _quantize_gather( "gather_axis": gather_axis, "quantize_axis": quantize_axis, "block_size": block_size, + "bits": bits, } node.outputs[0].name = node.outputs[0].name + f"_Q{bits}" @@ -385,9 +424,25 @@ def _quantize_ndarray( zero_point_int8[:, j : (j + 1), :] = zero_point_slice # pack int8 to int4 + # ORT GatherBlockQuantized uses unsigned int4 representation [0, 15] + # where zero_point=8 is implied for symmetric quantization. + # Convert signed int8 [-8, 7] to unsigned [0, 15] by adding 8. + if is_symmetric: + quant_data_int8 = (quant_data_int8.astype(np.int16) + 8).astype(np.uint8) quant_data_int4 = self._pack_int8_to_int4(quant_data_int8) zero_point_int4 = None if not is_symmetric: zero_point_int4 = self._pack_int8_to_int4(zero_point_int8) scales = scales.reshape(scales_shape) + + # Reshape packed data to match original rank (GatherBlockQuantized requires rank > 1). + # Two int4 values are packed per uint8, so the quantize_axis dimension is halved. + packed_shape = list(data.shape) + packed_shape[quantize_axis] = (packed_shape[quantize_axis] + 1) // 2 + quant_data_int4 = quant_data_int4.reshape(packed_shape) + if zero_point_int4 is not None: + zp_shape = list(scales_shape) + zp_shape[quantize_axis] = (zp_shape[quantize_axis] + 1) // 2 + zero_point_int4 = zero_point_int4.reshape(zp_shape) + return quant_data_int4, scales, zero_point_int4 From 514362d3bd920be1360aeb0ba320185c7b256db6 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Thu, 26 Feb 2026 17:30:22 -0800 Subject: [PATCH 02/23] Fix ModelBuilder sys.path for ort-genai builders package import --- olive/passes/onnx/model_builder.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 5f8dceabc9..2ad3fec3f0 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -213,6 +213,16 @@ def _run_for_config( output_model_path: str, ) -> ONNXModelHandler: try: + # ort-genai's builder.py uses bare `from builders import ...` (designed + # for script execution). When imported as a module we must add the + # package's models/ directory to sys.path so that the sub-package is + # found correctly. + import sys + import onnxruntime_genai as _og + import os as _os + _models_dir = _os.path.join(_os.path.dirname(_og.__file__), "models") + if _models_dir not in sys.path: + sys.path.insert(0, _models_dir) from onnxruntime_genai.models.builder import create_model except ImportError: raise ImportError( From cb1987b9f897373608c56a73c985c3aff39b58e0 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Thu, 26 Feb 2026 17:41:35 -0800 Subject: [PATCH 03/23] Expose real ModelBuilder import error for debugging --- olive/passes/onnx/model_builder.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 2ad3fec3f0..cdb31bcff6 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -224,12 +224,13 @@ def _run_for_config( if _models_dir not in sys.path: sys.path.insert(0, _models_dir) from onnxruntime_genai.models.builder import create_model - except ImportError: + except Exception as _import_err: raise ImportError( - "onnxruntime-genai package is required to run ModelBuilder pass. Please install the package" - " corresponding to your onnxruntime installation using pip. cpu: onnxruntime-genai, cuda:" - " onnxruntime-genai-cuda, directml: onnxruntime-genai-directml" - ) from None + f"onnxruntime-genai package is required to run ModelBuilder pass." + f" Underlying error: {type(_import_err).__name__}: {_import_err}." + " Please install the package corresponding to your onnxruntime installation using pip." + " cpu: onnxruntime-genai, cuda: onnxruntime-genai-cuda, directml: onnxruntime-genai-directml" + ) from _import_err self.maybe_patch_quant() precision = config.precision From 2c2269ea487d50c881ed71f2f82c6e594be083a5 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Thu, 26 Feb 2026 17:46:51 -0800 Subject: [PATCH 04/23] Clean up ModelBuilder import fix (expose chain, not debug print) --- olive/passes/onnx/model_builder.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index cdb31bcff6..8a0e1be54b 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -224,12 +224,11 @@ def _run_for_config( if _models_dir not in sys.path: sys.path.insert(0, _models_dir) from onnxruntime_genai.models.builder import create_model - except Exception as _import_err: + except ImportError as _import_err: raise ImportError( - f"onnxruntime-genai package is required to run ModelBuilder pass." - f" Underlying error: {type(_import_err).__name__}: {_import_err}." - " Please install the package corresponding to your onnxruntime installation using pip." - " cpu: onnxruntime-genai, cuda: onnxruntime-genai-cuda, directml: onnxruntime-genai-directml" + "onnxruntime-genai package is required to run ModelBuilder pass. Please install the package" + " corresponding to your onnxruntime installation using pip. cpu: onnxruntime-genai, cuda:" + " onnxruntime-genai-cuda, directml: onnxruntime-genai-directml" ) from _import_err self.maybe_patch_quant() From e77864f0df748b32b3d50596c9a9627892c79228 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Fri, 27 Feb 2026 14:01:10 -0800 Subject: [PATCH 05/23] Remove sys.path hack for onnxruntime-genai builder import --- olive/passes/onnx/model_builder.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py index 8a0e1be54b..5f8dceabc9 100644 --- a/olive/passes/onnx/model_builder.py +++ b/olive/passes/onnx/model_builder.py @@ -213,23 +213,13 @@ def _run_for_config( output_model_path: str, ) -> ONNXModelHandler: try: - # ort-genai's builder.py uses bare `from builders import ...` (designed - # for script execution). When imported as a module we must add the - # package's models/ directory to sys.path so that the sub-package is - # found correctly. - import sys - import onnxruntime_genai as _og - import os as _os - _models_dir = _os.path.join(_os.path.dirname(_og.__file__), "models") - if _models_dir not in sys.path: - sys.path.insert(0, _models_dir) from onnxruntime_genai.models.builder import create_model - except ImportError as _import_err: + except ImportError: raise ImportError( "onnxruntime-genai package is required to run ModelBuilder pass. Please install the package" " corresponding to your onnxruntime installation using pip. cpu: onnxruntime-genai, cuda:" " onnxruntime-genai-cuda, directml: onnxruntime-genai-directml" - ) from _import_err + ) from None self.maybe_patch_quant() precision = config.precision From af5983f7b9427ef1dedbc9d9ab333d1b3acffee6 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Tue, 3 Mar 2026 17:33:14 +0000 Subject: [PATCH 06/23] Add 8-bit Gather quantization support, ByteSize crash fix, and graph surgery passes - rtn_quantization.py: Parameterize bits through quantization methods to support 8-bit Gather - common.py: Fix ByteSize() crash for >2GB models, fix FOLDED_FROM_KEY import - graph_surgeries.py: Add ReciprocalMulToDiv, DeduplicateSubgraphInitializers, DeduplicateNodes --- olive/passes/onnx/common.py | 15 ++- olive/passes/onnx/graph_surgeries.py | 160 ++++++++++++++++++++++++++ olive/passes/onnx/rtn_quantization.py | 93 +++++++++------ 3 files changed, 230 insertions(+), 38 deletions(-) diff --git a/olive/passes/onnx/common.py b/olive/passes/onnx/common.py index 6711fee49a..4b6fcd0ecc 100644 --- a/olive/passes/onnx/common.py +++ b/olive/passes/onnx/common.py @@ -12,7 +12,13 @@ import onnx from onnx import external_data_helper from onnxscript import ir -from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY + +# TODO: Remove try/except once onnxscript >= 0.2.0 (which exports FOLDED_FROM_KEY) is the minimum required version. +# After that, replace with: from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY +try: + from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY +except ImportError: + FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.constant_folding.folded_from" from olive.common.utils import StrEnumBase, hardlink_copy_file from olive.model import CompositeModelHandler, ONNXModelHandler @@ -130,7 +136,12 @@ def model_proto_to_file( output_dir = output_path.parent output_dir.mkdir(parents=True, exist_ok=True) - model_size = model.ByteSize() + model_size = 0 + try: + model_size = model.ByteSize() + except Exception: + # ByteSize() fails for models >2GB due to protobuf serialization limits + pass # model size for large models might be negative (overflow?) on Windows # see https://github.com/onnx/onnx/issues/5861 if not save_as_external_data and (model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF): diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index f0b06bc2a0..6fa4fb0355 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -2124,6 +2124,166 @@ def equal_weights(self, dag: OnnxDAG, init0: str, init1: str, transpose: bool = return np.array_equal(arr0.ravel(), arr1.ravel()) +class ReciprocalMulToDiv(ProtoSurgeon): + """Replace Reciprocal(x) * a with Div(a, x). + + Before: + [x] --> Reciprocal --> Mul --> [out] + ^ + | + [a] + + After: + [a] --> Div --> [out] + ^ + | + [x] + + Why this is needed: + PyTorch's ``torch.rsqrt()`` (used by Qwen2.5-VL's ``Qwen2RMSNorm``) decomposes to + ``Sqrt -> Reciprocal -> Mul`` in ONNX. ORT's ``SimplifiedLayerNormFusion`` only + matches the pattern ``Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul`` — it does + **not** recognize the ``Reciprocal -> Mul`` variant (confirmed on ORT main as of + 2025-06). This pass canonicalizes the graph so that the fusion fires, replacing + decomposed RMSNorm with a single ``SimplifiedLayerNormalization`` op. + + When to use: + Run **before** ``OrtTransformersOptimization`` on models whose normalization layers + export ``rsqrt`` as ``Reciprocal`` (e.g. HuggingFace models using ``torch.rsqrt``). + """ + + def __call__(self, model: ModelProto): + modified = 0 + nodes_to_remove = [] + + for node in model.graph.node: + if node.op_type != "Reciprocal": + continue + + recip_input = node.input[0] # x + recip_output = node.output[0] + + # Find Mul consumers of this Reciprocal + mul_nodes = [ + n for n in model.graph.node + if n.op_type == "Mul" and recip_output in n.input + ] + + for mul_node in mul_nodes: + # Identify the other operand (not from Reciprocal) + if mul_node.input[0] == recip_output: + other_input = mul_node.input[1] + else: + other_input = mul_node.input[0] + + # Convert Mul(a, Reciprocal(x)) to Div(a, x) in-place + mul_node.op_type = "Div" + mul_node.input[0] = other_input + mul_node.input[1] = recip_input + if mul_node.name: + mul_node.name = self.create_new_name(mul_node.name, "Mul", "Div") + modified += 1 + + # If no more consumers of Reciprocal output, mark for removal + remaining = [n for n in model.graph.node if n != node and recip_output in n.input] + if not remaining: + nodes_to_remove.append(node) + + for node in nodes_to_remove: + model.graph.node.remove(node) + + if modified > 0: + logger.debug("Replaced %d Reciprocal+Mul patterns with Div", modified) + + return model + + +class DeduplicateSubgraphInitializers(ProtoSurgeon): + """Remove duplicate initializers in Loop / If / Scan subgraphs. + + Why this is needed: + ORT's graph optimizer (constant folding, shape inference, etc.) may copy + initializers into subgraphs that already contain them, creating entries with + identical names. ORT's ``ConstantSharing`` pass explicitly skips subgraph + usage (``constant_sharing.cc``: "If usage is from subgraph, skip it now"), + so these duplicates are never cleaned up. Duplicate initializers violate + the ONNX spec's unique-name requirement and can cause validation failures + or silent data corruption. + + What it does: + For every ``Loop`` / ``If`` / ``Scan`` subgraph, keeps the first initializer + with a given name and removes all subsequent duplicates. + + When to use: + Run **after** ``OrtTransformersOptimization`` (which introduces the duplicates) + and **before** any pass that serializes or validates the model. + """ + + def __call__(self, model: ModelProto): + removed = 0 + for node in model.graph.node: + for attr in node.attribute: + if attr.g and attr.g.initializer: + seen = set() + to_remove = [] + for init in attr.g.initializer: + if init.name in seen: + to_remove.append(init) + else: + seen.add(init.name) + for init in to_remove: + attr.g.initializer.remove(init) + removed += 1 + if removed > 0: + logger.debug("Removed %d duplicate subgraph initializers", removed) + return model + + +class DeduplicateNodes(ProtoSurgeon): + """Remove nodes whose output tensors are already produced by an earlier node. + + Before (invalid — two nodes define the same tensor ``/Cast_output_0``): + NodeA --> Cast --> /Cast_output_0 + NodeB --> Cast --> /Cast_output_0 (duplicate, removed) + + After: + NodeA --> Cast --> /Cast_output_0 + + Why this is needed: + ORT's ``convert_float_to_float16`` (``float16.py``) may insert identical + ``Cast`` nodes in parallel branches that each declare the same output tensor + name. The ONNX spec requires every tensor to have a unique producer; loading + a model with duplicate producers causes ``onnxruntime.InferenceSession`` to + fail with a duplicate-definition error. + + What it does: + Scans nodes in graph order and records each output tensor name. If a later + node produces a tensor name that was already seen, the entire node is removed. + + When to use: + Run **after** ``OnnxFloatToFloat16`` as a cleanup step. + """ + + def __call__(self, model: ModelProto): + output_seen: set[str] = set() + indices_to_remove: list[int] = [] + for i, node in enumerate(model.graph.node): + dup = False + for o in node.output: + if o and o in output_seen: + dup = True + break + if o: + output_seen.add(o) + if dup: + indices_to_remove.append(i) + for i in reversed(indices_to_remove): + del model.graph.node[i] + if indices_to_remove: + logger.debug("Removed %d duplicate nodes", len(indices_to_remove)) + return model + + class PackedAttentionToLoopMHA(Surgeon): """Replace custom::PackedAttention with a loop calling com.microsoft::MultiHeadAttention. diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py index 47a04b82ab..78aa3adfcf 100644 --- a/olive/passes/onnx/rtn_quantization.py +++ b/olive/passes/onnx/rtn_quantization.py @@ -125,12 +125,12 @@ def _quantize_model( logger.debug("skip to quantize %s as it has no initializer", node_name) continue - if node.op_type == str(OpType.Gather) and bits != 4: + if node.op_type == str(OpType.Gather) and bits not in (4, 8): logger.warning( - "Gather quantization is currently only implemented for 4-bit. " - "(GatherBlockQuantized op supports 4 or 8 bits, but the 8-bit path is not yet " - "implemented here.) Skip node %s.", + "Gather quantization is only implemented for 4-bit and 8-bit. " + "Skip node %s (bits=%d).", node_name, + bits, ) continue @@ -267,7 +267,7 @@ def _quantize_gather( quantize_axis = (quantize_axis + data_rank) % data_rank quantized_data, scales, zero_point = self._quantize_ndarray( - data_ndarray, quantize_axis, block_size, is_symmetric + data_ndarray, quantize_axis, block_size, is_symmetric, bits ) quantized_data_tensorproto = ir.Value( @@ -345,24 +345,28 @@ def _qbits_block_quant( return (packed, scales, zero_point) @staticmethod - def _quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def _quant_slice_symmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray]: + qmin = -(1 << (bits - 1)) # -8 for 4-bit, -128 for 8-bit + qmax = (1 << (bits - 1)) - 1 # 7 for 4-bit, 127 for 8-bit max_val = np.max(data, axis=1, keepdims=True) min_val = np.min(data, axis=1, keepdims=True) abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val) - scale = abs_max / -8.0 # if max == min, max may be clipped - quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8) + scale = abs_max / float(qmin) # if max == min, max may be clipped + quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(qmin, qmax).astype(np.int8) return quantized_slice, scale @staticmethod - def _quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _quant_slice_asymmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + qmax = (1 << bits) - 1 # 15 for 4-bit, 255 for 8-bit + mid = 1 << (bits - 1) # 8 for 4-bit, 128 for 8-bit min_val = np.minimum(data.min(axis=1, keepdims=True), 0) max_val = np.maximum(data.max(axis=1, keepdims=True), 0) - scale = (max_val - min_val) / 15.0 - zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8) - quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8) + scale = (max_val - min_val) / float(qmax) + zero_point = np.where(scale == 0, mid, -min_val / scale).round().clip(0, qmax).astype(np.uint8) + quantized_slice = np.where(scale == 0, mid, data / scale + zero_point).round().clip(0, qmax).astype(np.uint8) return quantized_slice, scale, zero_point @@ -382,8 +386,9 @@ def _quantize_ndarray( quantize_axis: int, block_size: int, is_symmetric: bool, + bits: int = 4, ) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray]]: - """Quantize ndarray data to int4 using numpy, return (quantized data, scales).""" + """Quantize ndarray data to int4/int8 using numpy, return (quantized data, scales).""" # Get the shape of the matrix m = 1 # dimension of the matrix before the quantize axis k = data.shape[quantize_axis] # dimension of the matrix along the quantize axis @@ -413,9 +418,9 @@ def _quantize_ndarray( block_slice = data_reshape[:, i:end_idx, :] zero_point_slice = None if is_symmetric: - quantized_slice_int8, scale_slice = self._quant_slice_symmetric(block_slice) + quantized_slice_int8, scale_slice = self._quant_slice_symmetric(block_slice, bits) else: - quantized_slice_int8, scale_slice, zero_point_slice = self._quant_slice_asymmetric(block_slice) + quantized_slice_int8, scale_slice, zero_point_slice = self._quant_slice_asymmetric(block_slice, bits) quant_data_int8[:, i:end_idx, :] = quantized_slice_int8 j = i // block_size @@ -423,26 +428,42 @@ def _quantize_ndarray( if not is_symmetric: zero_point_int8[:, j : (j + 1), :] = zero_point_slice - # pack int8 to int4 - # ORT GatherBlockQuantized uses unsigned int4 representation [0, 15] - # where zero_point=8 is implied for symmetric quantization. - # Convert signed int8 [-8, 7] to unsigned [0, 15] by adding 8. - if is_symmetric: - quant_data_int8 = (quant_data_int8.astype(np.int16) + 8).astype(np.uint8) - quant_data_int4 = self._pack_int8_to_int4(quant_data_int8) - zero_point_int4 = None - if not is_symmetric: - zero_point_int4 = self._pack_int8_to_int4(zero_point_int8) scales = scales.reshape(scales_shape) - # Reshape packed data to match original rank (GatherBlockQuantized requires rank > 1). - # Two int4 values are packed per uint8, so the quantize_axis dimension is halved. - packed_shape = list(data.shape) - packed_shape[quantize_axis] = (packed_shape[quantize_axis] + 1) // 2 - quant_data_int4 = quant_data_int4.reshape(packed_shape) - if zero_point_int4 is not None: - zp_shape = list(scales_shape) - zp_shape[quantize_axis] = (zp_shape[quantize_axis] + 1) // 2 - zero_point_int4 = zero_point_int4.reshape(zp_shape) - - return quant_data_int4, scales, zero_point_int4 + if bits <= 4: + # pack int8 to int4 + # ORT GatherBlockQuantized uses unsigned int4 representation [0, 15] + # where zero_point=8 is implied for symmetric quantization. + # Convert signed int8 [-8, 7] to unsigned [0, 15] by adding 8. + if is_symmetric: + quant_data_int8 = (quant_data_int8.astype(np.int16) + 8).astype(np.uint8) + quant_data_int4 = self._pack_int8_to_int4(quant_data_int8) + zero_point_int4 = None + if not is_symmetric: + zero_point_int4 = self._pack_int8_to_int4(zero_point_int8) + + # Reshape packed data to match original rank (GatherBlockQuantized requires rank > 1). + # Two int4 values are packed per uint8, so the quantize_axis dimension is halved. + packed_shape = list(data.shape) + packed_shape[quantize_axis] = (packed_shape[quantize_axis] + 1) // 2 + quant_data_int4 = quant_data_int4.reshape(packed_shape) + if zero_point_int4 is not None: + zp_shape = list(scales_shape) + zp_shape[quantize_axis] = (zp_shape[quantize_axis] + 1) // 2 + zero_point_int4 = zero_point_int4.reshape(zp_shape) + + return quant_data_int4, scales, zero_point_int4 + else: + # 8-bit: no packing needed, one value per byte. + # Convert signed int8 [-128, 127] to unsigned uint8 [0, 255] by adding 128. + if is_symmetric: + quant_data_uint8 = (quant_data_int8.astype(np.int16) + 128).astype(np.uint8) + else: + quant_data_uint8 = quant_data_int8 # already uint8 + + quant_data_uint8 = quant_data_uint8.reshape(data.shape) + zero_point_out = None + if not is_symmetric: + zero_point_out = zero_point_int8.reshape(scales_shape) + + return quant_data_uint8, scales, zero_point_out From 4d5283e43175f3b8b5a6a8f4c85093dd5b00c01b Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Tue, 3 Mar 2026 22:08:34 -0800 Subject: [PATCH 07/23] Add unit tests for Qwen3-VL graph surgery and quantization passes --- .../onnx/test_cast_chain_elimination.py | 135 +++++++++ test/passes/onnx/test_graph_surgeries.py | 270 ++++++++++++++++++ test/passes/onnx/test_rtn_quantization.py | 145 ++++++++++ 3 files changed, 550 insertions(+) create mode 100644 test/passes/onnx/test_cast_chain_elimination.py diff --git a/test/passes/onnx/test_cast_chain_elimination.py b/test/passes/onnx/test_cast_chain_elimination.py new file mode 100644 index 0000000000..25ae896d8f --- /dev/null +++ b/test/passes/onnx/test_cast_chain_elimination.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for OnnxCastChainElimination pass.""" + +import numpy as np +import onnx +import onnxruntime as ort +import pytest +from onnx import TensorProto, helper + +from olive.model import ONNXModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.cast_chain_elimination import OnnxCastChainElimination, _ensure_com_microsoft_opset + + +class TestEnsureComMicrosoftOpset: + """Unit tests for _ensure_com_microsoft_opset helper.""" + + def test_adds_opset_when_missing(self): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + assert not any(op.domain == "com.microsoft" for op in model.opset_import) + + _ensure_com_microsoft_opset(model) + + assert any(op.domain == "com.microsoft" for op in model.opset_import) + + def test_does_not_duplicate_opset(self): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)], + ) + + _ensure_com_microsoft_opset(model) + + ms_opsets = [op for op in model.opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + def test_adds_opset_to_functions(self): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + _ensure_com_microsoft_opset(model) + + func_domains = {op.domain for op in model.functions[0].opset_import} + assert "com.microsoft" in func_domains + + def test_skips_function_with_existing_opset(self): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + _ensure_com_microsoft_opset(model) + + ms_opsets = [op for op in model.functions[0].opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + +class TestOnnxCastChainElimination: + """Integration tests for OnnxCastChainElimination pass.""" + + @pytest.fixture + def cast_chain_model_path(self, tmp_path): + """Model with a redundant Cast chain: fp32 → fp16 → fp32.""" + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + cast_to_fp16 = helper.make_node("Cast", ["X"], ["x_fp16"], to=TensorProto.FLOAT16) + cast_back = helper.make_node("Cast", ["x_fp16"], ["Y"], to=TensorProto.FLOAT) + + graph = helper.make_graph([cast_to_fp16, cast_back], "cast_chain", [input_tensor], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + onnx.checker.check_model(model) + + path = tmp_path / "cast_chain.onnx" + onnx.save(model, str(path)) + return path + + def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxCastChainElimination, {}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + # The pass should produce a valid, runnable model. + # Actual cast elimination depends on the ORT version; at minimum the + # output graph must not have *more* nodes than the input. + assert len(result_model.graph.node) <= 2 + + def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict( + OnnxCastChainElimination, + {"fix_opset": True, "enable_cast_chain_elimination": False}, + disable_search=True, + ) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + assert any(op.domain == "com.microsoft" for op in result_model.opset_import) + + def test_numerical_correctness(self, cast_chain_model_path, tmp_path): + """Optimized model should produce the same output as the original.""" + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxCastChainElimination, {}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + x = np.random.randn(2, 4).astype(np.float32) + + orig_sess = ort.InferenceSession(str(cast_chain_model_path), providers=["CPUExecutionProvider"]) + opt_sess = ort.InferenceSession(output.model_path, providers=["CPUExecutionProvider"]) + + orig_out = orig_sess.run(None, {"X": x})[0] + opt_out = opt_sess.run(None, {"X": x})[0] + np.testing.assert_allclose(orig_out, opt_out, atol=1e-3) diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index b1202c45b5..d3c2a8124a 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -2295,3 +2295,273 @@ def test_packed_attention_to_packed_mha(tmp_path): op_types = [node.op_type for node in output_model_def.graph.node] assert OpType.PackedAttention not in op_types assert OpType.PackedMultiHeadAttention in op_types + + +# ── GemmToMatMulAdd ────────────────────────────────────────────────────── + + +def test_gemm_to_matmul_add(tmp_path): + """Gemm(A, B, C) with transB=1 → MatMul(A, B^T) + Add(C).""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + b_data = np.random.randn(4, 3).astype(np.float32) # transB shape + c_data = np.random.randn(4).astype(np.float32) + + initializers = [ + numpy_helper.from_array(b_data, name="B"), + numpy_helper.from_array(c_data, name="C"), + ] + + gemm_node = helper.make_node( + "Gemm", ["A", "B", "C"], ["Y"], name="Gemm0", transB=1, alpha=1.0, beta=1.0 + ) + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + onnx.checker.check_model(model) + + model_path = tmp_path / "gemm.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" not in op_types + assert "MatMul" in op_types + assert "Add" in op_types + + # Numerical check + a = np.random.randn(2, 3).astype(np.float32) + expected = a @ b_data.T + c_data + + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + actual = sess.run(None, {"A": a})[0] + np.testing.assert_allclose(actual, expected, atol=1e-5) + + +def test_gemm_to_matmul_add_no_bias(tmp_path): + """Gemm(A, B) without bias → MatMul only.""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + b_data = np.eye(3, dtype=np.float32) + initializers = [numpy_helper.from_array(b_data, name="B")] + + gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], name="Gemm0") + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + onnx.checker.check_model(model) + + model_path = tmp_path / "gemm_no_bias.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" not in op_types + assert "MatMul" in op_types + assert "Add" not in op_types + + +def test_gemm_to_matmul_add_skips_non_unit_alpha(tmp_path): + """Gemm with alpha != 1.0 should be left unchanged.""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + b_data = np.eye(3, dtype=np.float32) + initializers = [numpy_helper.from_array(b_data, name="B")] + + gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], name="Gemm0", alpha=2.0) + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + model_path = tmp_path / "gemm_alpha.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" in op_types, "Gemm with alpha != 1.0 should be preserved" + + +# ── ReciprocalMulToDiv ─────────────────────────────────────────────────── + + +def test_reciprocal_mul_to_div(tmp_path): + """Reciprocal(x) * a → Div(a, x).""" + input_x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + nodes = [ + helper.make_node("Reciprocal", ["X"], ["recip_out"], name="Recip"), + helper.make_node("Mul", ["A", "recip_out"], ["Y"], name="MulRecip"), + ] + + graph = helper.make_graph(nodes, "test", [input_x, input_a], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + onnx.checker.check_model(model) + + model_path = tmp_path / "recip_mul.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Reciprocal" not in op_types + assert "Mul" not in op_types + assert "Div" in op_types + + # Numerical check: Div(a, x) == a * (1/x) + x = np.random.randn(2, 4).astype(np.float32) + 2.0 # avoid near-zero + a = np.random.randn(2, 4).astype(np.float32) + expected = a / x + + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + actual = sess.run(None, {"X": x, "A": a})[0] + np.testing.assert_allclose(actual, expected, atol=1e-5) + + +def test_reciprocal_mul_to_div_reversed_order(tmp_path): + """Mul(recip_out, a) — Reciprocal output on the left side.""" + input_x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4]) + input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4]) + + nodes = [ + helper.make_node("Reciprocal", ["X"], ["recip_out"], name="Recip"), + helper.make_node("Mul", ["recip_out", "A"], ["Y"], name="MulRecip"), + ] + + graph = helper.make_graph(nodes, "test", [input_x, input_a], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + model_path = tmp_path / "recip_mul_rev.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Div" in op_types + + # Check Div inputs: Div(A, X) regardless of Mul operand order + div_node = next(n for n in out_proto.graph.node if n.op_type == "Div") + assert div_node.input[0] == "A" + assert div_node.input[1] == "X" + + +# ── DeduplicateSubgraphInitializers ────────────────────────────────────── + + +def test_deduplicate_subgraph_initializers(tmp_path): + """Duplicate initializers inside a Loop subgraph should be removed.""" + # Build a minimal Loop body with duplicate initializers + body_input = helper.make_tensor_value_info("i", TensorProto.INT64, []) + body_cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + body_cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + body_out = helper.make_tensor_value_info("body_out", TensorProto.FLOAT, [2]) + + const_data = np.array([1.0, 2.0], dtype=np.float32) + init1 = numpy_helper.from_array(const_data, name="dup_init") + init2 = numpy_helper.from_array(const_data, name="dup_init") # duplicate + + body_node = helper.make_node("Identity", ["dup_init"], ["body_out"]) + cond_true = numpy_helper.from_array(np.array(True), name="cond_true") + cond_node = helper.make_node("Identity", ["cond_true"], ["cond_out"]) + + body_graph = helper.make_graph( + [cond_node, body_node], "body", [body_input, body_cond_in], [body_cond_out, body_out], + initializer=[init1, init2, cond_true], + ) + + # Outer graph with Loop node + trip_count = helper.make_tensor_value_info("trip_count", TensorProto.INT64, []) + cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, []) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 2]) + + loop_node = helper.make_node("Loop", ["trip_count", "cond"], ["output"], body=body_graph) + + main_graph = helper.make_graph([loop_node], "main", [trip_count, cond], [output]) + model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 17)]) + + model_path = tmp_path / "loop_dup.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "DeduplicateSubgraphInitializers"}]}, + disable_search=True, + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + # Check the Loop body: should have only 1 copy of "dup_init" + loop = next(n for n in out_proto.graph.node if n.op_type == "Loop") + body = loop.attribute[0].g + dup_inits = [i for i in body.initializer if i.name == "dup_init"] + assert len(dup_inits) == 1, f"Expected 1, got {len(dup_inits)} copies of dup_init" + + +# ── DeduplicateNodes ───────────────────────────────────────────────────── + + +def test_deduplicate_nodes(tmp_path): + """Nodes that produce the same output tensor name should be deduplicated.""" + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 4]) + + # Two Cast nodes with the same output tensor name (simulating float16.py bug) + cast1 = helper.make_node("Cast", ["X"], ["cast_out"], name="Cast1", to=TensorProto.FLOAT16) + cast2 = helper.make_node("Cast", ["X"], ["cast_out"], name="Cast2", to=TensorProto.FLOAT16) + identity = helper.make_node("Identity", ["cast_out"], ["Y"], name="Id") + + graph = helper.make_graph([cast1, cast2, identity], "test", [input_tensor], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + + model_path = tmp_path / "dup_nodes.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, {"surgeries": [{"surgeon": "DeduplicateNodes"}]}, disable_search=True + ) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + cast_nodes = [n for n in out_proto.graph.node if n.op_type == "Cast"] + assert len(cast_nodes) == 1, f"Expected 1 Cast, got {len(cast_nodes)}" + + # The remaining model should still be runnable + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + x = np.random.randn(2, 4).astype(np.float32) + result = sess.run(None, {"X": x})[0] + np.testing.assert_allclose(result, x.astype(np.float16), atol=1e-3) diff --git a/test/passes/onnx/test_rtn_quantization.py b/test/passes/onnx/test_rtn_quantization.py index 2dc6f333e9..413396fb2f 100644 --- a/test/passes/onnx/test_rtn_quantization.py +++ b/test/passes/onnx/test_rtn_quantization.py @@ -204,3 +204,148 @@ def test_rtn_quantization_with_exclusion(self, matmul_model_path, tmp_path): assert not found_matmul_nbits, "MatMulNBits node found despite exclusion" assert found_original_matmul, "Original MatMul node should still exist when excluded" + + @pytest.mark.parametrize("is_symmetric", [True, False]) + def test_rtn_quantization_gather_8bit(self, gather_model_path, tmp_path, is_symmetric): + """8-bit Gather quantization should produce GatherBlockQuantized with bits=8.""" + olive_model = ONNXModelHandler(model_path=str(gather_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + pass_config = {"bits": 8, "block_size": 128, "axis": 0, "is_symmetric": is_symmetric} + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, pass_config, disable_search=True, accelerator_spec=accelerator_spec + ) + + output_path = tmp_path / "quantized_gather_8bit.onnx" + quantized_model = p.run(olive_model, output_path) + + assert os.path.exists(quantized_model.model_path) + + ir_model = ir.load(quantized_model.model_path) + + found = False + for node in ir_model.graph.all_nodes(): + if node.op_type == str(OpType.GatherBlockQuantized): + found = True + assert node.domain == MSFT_DOMAIN + # bits attribute must be 8 + assert any( + attr.name == "bits" and attr.value == 8 + for attr in node.attributes.values() + ), "GatherBlockQuantized should have bits=8" + # quantize_axis must be last dimension (data_rank - 1) + assert any( + attr.name == "quantize_axis" and attr.value == 1 + for attr in node.attributes.values() + ), "quantize_axis should be forced to last dim (1 for 2-D embedding)" + break + + assert found, "No GatherBlockQuantized node found for 8-bit quantization" + + def test_rtn_quantization_gather_quantize_axis_forced_to_last_dim(self, gather_model_path, tmp_path): + """Regardless of axis config, gather quantize_axis is forced to data_rank - 1.""" + olive_model = ONNXModelHandler(model_path=str(gather_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + # Set axis=0, but the code should force quantize_axis to last dim + pass_config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True} + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, pass_config, disable_search=True, accelerator_spec=accelerator_spec + ) + + output_path = tmp_path / "quantized_gather_axis.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + for node in ir_model.graph.all_nodes(): + if node.op_type == str(OpType.GatherBlockQuantized): + qa = [attr for attr in node.attributes.values() if attr.name == "quantize_axis"] + assert len(qa) == 1 + assert qa[0].value == 1, ( + f"quantize_axis should be 1 (last dim of 2-D data), got {qa[0].value}" + ) + break + + def test_rtn_quantization_shared_gather_weights(self, tmp_path): + """Two Gather nodes sharing the same weight should not duplicate initializers.""" + data_shape = [100, 64] + data_tensor = np.random.randn(*data_shape).astype(np.float32) + data_name = "shared_data" + + data_init = onnx.helper.make_tensor( + name=data_name, data_type=onnx.TensorProto.FLOAT, dims=data_shape, + vals=data_tensor.flatten().tolist(), + ) + indices1 = onnx.helper.make_tensor_value_info("indices1", onnx.TensorProto.INT64, [1, 5]) + indices2 = onnx.helper.make_tensor_value_info("indices2", onnx.TensorProto.INT64, [1, 5]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [1, 5, 64]) + out2 = onnx.helper.make_tensor_value_info("out2", onnx.TensorProto.FLOAT, [1, 5, 64]) + + gather1 = onnx.helper.make_node("Gather", [data_name, "indices1"], ["out1"], name="Gather1") + gather2 = onnx.helper.make_node("Gather", [data_name, "indices2"], ["out2"], name="Gather2") + + graph = onnx.helper.make_graph( + [gather1, gather2], "shared_weight_test", + [indices1, indices2], [out1, out2], + initializer=[data_init], + ) + model = onnx.helper.make_model(graph, producer_name="olive-test") + model.opset_import[0].version = 13 + + model_path = tmp_path / "shared_gather.onnx" + onnx.save(model, str(model_path)) + + olive_model = ONNXModelHandler(model_path=str(model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", execution_provider="CPUExecutionProvider", + ) + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, + {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}, + disable_search=True, + accelerator_spec=accelerator_spec, + ) + + output_path = tmp_path / "shared_gather_quantized.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + # Both nodes should be GatherBlockQuantized + gbq_nodes = [n for n in ir_model.graph.all_nodes() if n.op_type == str(OpType.GatherBlockQuantized)] + assert len(gbq_nodes) == 2, f"Expected 2 GatherBlockQuantized nodes, got {len(gbq_nodes)}" + + # The quantized data inputs (first input) should refer to the same name + quant_data_names = [n.inputs[0].name for n in gbq_nodes] + assert quant_data_names[0] == quant_data_names[1], ( + f"Shared weight should produce same quantized initializer name: {quant_data_names}" + ) + + def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, tmp_path): + """After quantization, original FP32 weight initializers should be removed.""" + olive_model = ONNXModelHandler(model_path=str(matmul_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", execution_provider="CPUExecutionProvider", + ) + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, + {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}, + disable_search=True, + accelerator_spec=accelerator_spec, + ) + + output_path = tmp_path / "unused_init_test.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + # The original "weight" initializer should be gone + init_names = set(ir_model.graph.initializers.keys()) + assert "weight" not in init_names, ( + f"Original FP32 'weight' initializer should have been removed, found: {init_names}" + ) From 9fc9bd374a950f0eb5ce6c56dc02280b21a8851d Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Tue, 3 Mar 2026 22:55:02 -0800 Subject: [PATCH 08/23] Fix lintrunner warnings: rename uppercase variables (N806), add TODO author (TD002), fix formatting --- olive/passes/onnx/common.py | 4 +-- olive/passes/onnx/graph_surgeries.py | 49 ++++++++++++++-------------- 2 files changed, 27 insertions(+), 26 deletions(-) diff --git a/olive/passes/onnx/common.py b/olive/passes/onnx/common.py index 4b6fcd0ecc..c207881420 100644 --- a/olive/passes/onnx/common.py +++ b/olive/passes/onnx/common.py @@ -13,8 +13,8 @@ from onnx import external_data_helper from onnxscript import ir -# TODO: Remove try/except once onnxscript >= 0.2.0 (which exports FOLDED_FROM_KEY) is the minimum required version. -# After that, replace with: from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY +# TODO(sunghcho): Remove try/except once onnxscript >= 0.2.0 (which exports FOLDED_FROM_KEY) is the minimum +# required version. After that, replace with: from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY try: from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY except ImportError: diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 6fa4fb0355..fb4518934b 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -1258,51 +1258,55 @@ def __call__(self, model: ModelProto): continue alpha = beta = 1.0 - transA = transB = 0 + trans_a = trans_b = 0 for attr in node.attribute: if attr.name == "alpha": alpha = attr.f elif attr.name == "beta": beta = attr.f elif attr.name == "transA": - transA = attr.i + trans_a = attr.i elif attr.name == "transB": - transB = attr.i + trans_b = attr.i - if alpha != 1.0 or beta != 1.0 or transA != 0: + if alpha != 1.0 or beta != 1.0 or trans_a != 0: continue - A, B = node.input[0], node.input[1] - C = node.input[2] if len(node.input) > 2 else None - Y = node.output[0] + inp_a, inp_b = node.input[0], node.input[1] + inp_c = node.input[2] if len(node.input) > 2 else None + out_y = node.output[0] - if transB: - if B in initializer_map: - init = initializer_map[B] + if trans_b: + if inp_b in initializer_map: + init = initializer_map[inp_b] w_t = numpy_helper.to_array(init).T.copy() - new_init = numpy_helper.from_array(w_t, name=B) + new_init = numpy_helper.from_array(w_t, name=inp_b) for i, existing in enumerate(graph.initializer): - if existing.name == B: + if existing.name == inp_b: graph.initializer[i].CopyFrom(new_init) break - matmul_B = B + matmul_rhs = inp_b else: transpose_out = f"{node.name}_transpose_B" nodes_to_add.append( - helper.make_node("Transpose", [B], [transpose_out], name=f"{node.name}_Transpose", perm=[1, 0]) + helper.make_node( + "Transpose", [inp_b], [transpose_out], name=f"{node.name}_Transpose", perm=[1, 0] + ) ) - matmul_B = transpose_out + matmul_rhs = transpose_out else: - matmul_B = B + matmul_rhs = inp_b - if C: + if inp_c: matmul_out = f"{node.name}_matmul_out" nodes_to_add.append( - helper.make_node("MatMul", [A, matmul_B], [matmul_out], name=f"{node.name}_MatMul") + helper.make_node("MatMul", [inp_a, matmul_rhs], [matmul_out], name=f"{node.name}_MatMul") ) - nodes_to_add.append(helper.make_node("Add", [matmul_out, C], [Y], name=f"{node.name}_Add")) + nodes_to_add.append(helper.make_node("Add", [matmul_out, inp_c], [out_y], name=f"{node.name}_Add")) else: - nodes_to_add.append(helper.make_node("MatMul", [A, matmul_B], [Y], name=f"{node.name}_MatMul")) + nodes_to_add.append( + helper.make_node("MatMul", [inp_a, matmul_rhs], [out_y], name=f"{node.name}_MatMul") + ) nodes_to_remove.append(node) @@ -2164,10 +2168,7 @@ def __call__(self, model: ModelProto): recip_output = node.output[0] # Find Mul consumers of this Reciprocal - mul_nodes = [ - n for n in model.graph.node - if n.op_type == "Mul" and recip_output in n.input - ] + mul_nodes = [n for n in model.graph.node if n.op_type == "Mul" and recip_output in n.input] for mul_node in mul_nodes: # Identify the other operand (not from Reciprocal) From 74b257c6b429c645264b7ebc0889dd187bed5584 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Tue, 3 Mar 2026 23:43:27 -0800 Subject: [PATCH 09/23] Fix ruff formatting, int4 packing bug, and test assertion - Apply ruff format to 4 files (cast_chain_elimination.py, rtn_quantization.py, test_graph_surgeries.py, test_rtn_quantization.py) - Fix _pack_int8_to_int4 reshape bug: replace global flatten+pack with axis-aware _pack_int4_along_axis that correctly packs zero_point when k_blocks is small (e.g. 1), avoiding ValueError on reshape - Fix test_rtn_quantization_pass_gather assertion: GatherBlockQuantized always uses quantize_axis=data_rank-1, not pass_config['axis'] --- olive/passes/onnx/cast_chain_elimination.py | 1 + olive/passes/onnx/rtn_quantization.py | 43 ++++++++++++--------- test/passes/onnx/test_graph_surgeries.py | 33 ++++++---------- test/passes/onnx/test_rtn_quantization.py | 41 ++++++++++---------- 4 files changed, 57 insertions(+), 61 deletions(-) diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py index 7eb65ceead..4fa2368aeb 100644 --- a/olive/passes/onnx/cast_chain_elimination.py +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -19,6 +19,7 @@ 2. Runs ORT ``ORT_ENABLE_BASIC`` optimization with Cast chain elimination explicitly enabled to produce a cleaned-up model. """ + import logging from pathlib import Path diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py index 78aa3adfcf..d4742aeebc 100644 --- a/olive/passes/onnx/rtn_quantization.py +++ b/olive/passes/onnx/rtn_quantization.py @@ -127,8 +127,7 @@ def _quantize_model( if node.op_type == str(OpType.Gather) and bits not in (4, 8): logger.warning( - "Gather quantization is only implemented for 4-bit and 8-bit. " - "Skip node %s (bits=%d).", + "Gather quantization is only implemented for 4-bit and 8-bit. Skip node %s (bits=%d).", node_name, bits, ) @@ -145,9 +144,7 @@ def _quantize_model( if input_value.name in globally_registered: # Already registered by a previous node (shared weight), # replace with the existing initializer. - ir.convenience.replace_all_uses_with( - input_value, globally_registered[input_value.name] - ) + ir.convenience.replace_all_uses_with(input_value, globally_registered[input_value.name]) elif input_value.name not in registered: initializer_graph.register_initializer(input_value) registered[input_value.name] = input_value @@ -346,8 +343,8 @@ def _qbits_block_quant( @staticmethod def _quant_slice_symmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray]: - qmin = -(1 << (bits - 1)) # -8 for 4-bit, -128 for 8-bit - qmax = (1 << (bits - 1)) - 1 # 7 for 4-bit, 127 for 8-bit + qmin = -(1 << (bits - 1)) # -8 for 4-bit, -128 for 8-bit + qmax = (1 << (bits - 1)) - 1 # 7 for 4-bit, 127 for 8-bit max_val = np.max(data, axis=1, keepdims=True) min_val = np.min(data, axis=1, keepdims=True) abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val) @@ -359,8 +356,8 @@ def _quant_slice_symmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, @staticmethod def _quant_slice_asymmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - qmax = (1 << bits) - 1 # 15 for 4-bit, 255 for 8-bit - mid = 1 << (bits - 1) # 8 for 4-bit, 128 for 8-bit + qmax = (1 << bits) - 1 # 15 for 4-bit, 255 for 8-bit + mid = 1 << (bits - 1) # 8 for 4-bit, 128 for 8-bit min_val = np.minimum(data.min(axis=1, keepdims=True), 0) max_val = np.maximum(data.max(axis=1, keepdims=True), 0) @@ -371,14 +368,20 @@ def _quant_slice_asymmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray return quantized_slice, scale, zero_point @staticmethod - def _pack_int8_to_int4(data: np.ndarray) -> np.ndarray: - """Pack int8 data to int4 and store in uint8 ndarray.""" - data_flat = data.reshape(-1) - if len(data_flat) % 2 != 0: - data_flat = np.append(data_flat, 0) - quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4) + def _pack_int4_along_axis(data: np.ndarray, axis: int = 1) -> np.ndarray: + """Pack pairs of int4 values into uint8 along the specified axis. - return quant_data_int4.astype("uint8") + Unlike a flat pack, this correctly handles cases where the packing dimension is small + (e.g., zero_point with k_blocks=1) by only pairing values within the same axis slice. + """ + k = data.shape[axis] + if k % 2 != 0: + pad_width = [(0, 0)] * len(data.shape) + pad_width[axis] = (0, 1) + data = np.pad(data, pad_width) + low = np.take(data, range(0, data.shape[axis], 2), axis=axis) + high = np.take(data, range(1, data.shape[axis], 2), axis=axis) + return ((low & 0xF) | ((high & 0xF) << 4)).astype("uint8") def _quantize_ndarray( self, @@ -437,13 +440,15 @@ def _quantize_ndarray( # Convert signed int8 [-8, 7] to unsigned [0, 15] by adding 8. if is_symmetric: quant_data_int8 = (quant_data_int8.astype(np.int16) + 8).astype(np.uint8) - quant_data_int4 = self._pack_int8_to_int4(quant_data_int8) + + # Pack along axis=1 (the quantize_axis in the 3D view: m, k, n). + # This ensures packing pairs values within the same row, not across rows. + quant_data_int4 = self._pack_int4_along_axis(quant_data_int8, axis=1) zero_point_int4 = None if not is_symmetric: - zero_point_int4 = self._pack_int8_to_int4(zero_point_int8) + zero_point_int4 = self._pack_int4_along_axis(zero_point_int8, axis=1) # Reshape packed data to match original rank (GatherBlockQuantized requires rank > 1). - # Two int4 values are packed per uint8, so the quantize_axis dimension is halved. packed_shape = list(data.shape) packed_shape[quantize_axis] = (packed_shape[quantize_axis] + 1) // 2 quant_data_int4 = quant_data_int4.reshape(packed_shape) diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index d3c2a8124a..d852d8664c 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -2313,9 +2313,7 @@ def test_gemm_to_matmul_add(tmp_path): numpy_helper.from_array(c_data, name="C"), ] - gemm_node = helper.make_node( - "Gemm", ["A", "B", "C"], ["Y"], name="Gemm0", transB=1, alpha=1.0, beta=1.0 - ) + gemm_node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], name="Gemm0", transB=1, alpha=1.0, beta=1.0) graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) @@ -2325,9 +2323,7 @@ def test_gemm_to_matmul_add(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() @@ -2363,9 +2359,7 @@ def test_gemm_to_matmul_add_no_bias(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() @@ -2392,9 +2386,7 @@ def test_gemm_to_matmul_add_skips_non_unit_alpha(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() @@ -2424,9 +2416,7 @@ def test_reciprocal_mul_to_div(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() @@ -2463,9 +2453,7 @@ def test_reciprocal_mul_to_div_reversed_order(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() @@ -2498,7 +2486,10 @@ def test_deduplicate_subgraph_initializers(tmp_path): cond_node = helper.make_node("Identity", ["cond_true"], ["cond_out"]) body_graph = helper.make_graph( - [cond_node, body_node], "body", [body_input, body_cond_in], [body_cond_out, body_out], + [cond_node, body_node], + "body", + [body_input, body_cond_in], + [body_cond_out, body_out], initializer=[init1, init2, cond_true], ) @@ -2551,9 +2542,7 @@ def test_deduplicate_nodes(tmp_path): onnx.save(model, str(model_path)) input_model = ONNXModelHandler(model_path=str(model_path)) - p = create_pass_from_dict( - GraphSurgeries, {"surgeries": [{"surgeon": "DeduplicateNodes"}]}, disable_search=True - ) + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "DeduplicateNodes"}]}, disable_search=True) output_model = p.run(input_model, str(tmp_path / "out")) out_proto = output_model.load_model() diff --git a/test/passes/onnx/test_rtn_quantization.py b/test/passes/onnx/test_rtn_quantization.py index 413396fb2f..bfdd49fa87 100644 --- a/test/passes/onnx/test_rtn_quantization.py +++ b/test/passes/onnx/test_rtn_quantization.py @@ -154,6 +154,8 @@ def test_rtn_quantization_pass_gather(self, gather_model_path, tmp_path, is_symm ir_model = ir.load(quantized_model.model_path) # Assert + # ORT GatherBlockQuantized requires quantize_axis == last dimension (data_rank - 1). + # The gather model fixture uses 2D data [100, 64], so quantize_axis = 1. found_gather_block_quantized = False for node in ir_model.graph.all_nodes(): if node.op_type == str(OpType.GatherBlockQuantized): @@ -163,10 +165,7 @@ def test_rtn_quantization_pass_gather(self, gather_model_path, tmp_path, is_symm attr.name == "block_size" and attr.value == pass_config["block_size"] for attr in node.attributes.values() ) - assert any( - attr.name == "quantize_axis" and attr.value == pass_config["axis"] - for attr in node.attributes.values() - ) + assert any(attr.name == "quantize_axis" and attr.value == 1 for attr in node.attributes.values()) break assert found_gather_block_quantized, "No GatherBlockQuantized node found in quantized model" @@ -231,15 +230,13 @@ def test_rtn_quantization_gather_8bit(self, gather_model_path, tmp_path, is_symm found = True assert node.domain == MSFT_DOMAIN # bits attribute must be 8 - assert any( - attr.name == "bits" and attr.value == 8 - for attr in node.attributes.values() - ), "GatherBlockQuantized should have bits=8" + assert any(attr.name == "bits" and attr.value == 8 for attr in node.attributes.values()), ( + "GatherBlockQuantized should have bits=8" + ) # quantize_axis must be last dimension (data_rank - 1) - assert any( - attr.name == "quantize_axis" and attr.value == 1 - for attr in node.attributes.values() - ), "quantize_axis should be forced to last dim (1 for 2-D embedding)" + assert any(attr.name == "quantize_axis" and attr.value == 1 for attr in node.attributes.values()), ( + "quantize_axis should be forced to last dim (1 for 2-D embedding)" + ) break assert found, "No GatherBlockQuantized node found for 8-bit quantization" @@ -266,9 +263,7 @@ def test_rtn_quantization_gather_quantize_axis_forced_to_last_dim(self, gather_m if node.op_type == str(OpType.GatherBlockQuantized): qa = [attr for attr in node.attributes.values() if attr.name == "quantize_axis"] assert len(qa) == 1 - assert qa[0].value == 1, ( - f"quantize_axis should be 1 (last dim of 2-D data), got {qa[0].value}" - ) + assert qa[0].value == 1, f"quantize_axis should be 1 (last dim of 2-D data), got {qa[0].value}" break def test_rtn_quantization_shared_gather_weights(self, tmp_path): @@ -278,7 +273,9 @@ def test_rtn_quantization_shared_gather_weights(self, tmp_path): data_name = "shared_data" data_init = onnx.helper.make_tensor( - name=data_name, data_type=onnx.TensorProto.FLOAT, dims=data_shape, + name=data_name, + data_type=onnx.TensorProto.FLOAT, + dims=data_shape, vals=data_tensor.flatten().tolist(), ) indices1 = onnx.helper.make_tensor_value_info("indices1", onnx.TensorProto.INT64, [1, 5]) @@ -290,8 +287,10 @@ def test_rtn_quantization_shared_gather_weights(self, tmp_path): gather2 = onnx.helper.make_node("Gather", [data_name, "indices2"], ["out2"], name="Gather2") graph = onnx.helper.make_graph( - [gather1, gather2], "shared_weight_test", - [indices1, indices2], [out1, out2], + [gather1, gather2], + "shared_weight_test", + [indices1, indices2], + [out1, out2], initializer=[data_init], ) model = onnx.helper.make_model(graph, producer_name="olive-test") @@ -302,7 +301,8 @@ def test_rtn_quantization_shared_gather_weights(self, tmp_path): olive_model = ONNXModelHandler(model_path=str(model_path)) accelerator_spec = AcceleratorSpec( - accelerator_type="CPU", execution_provider="CPUExecutionProvider", + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", ) p = create_pass_from_dict( OnnxBlockWiseRtnQuantization, @@ -330,7 +330,8 @@ def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, t """After quantization, original FP32 weight initializers should be removed.""" olive_model = ONNXModelHandler(model_path=str(matmul_model_path)) accelerator_spec = AcceleratorSpec( - accelerator_type="CPU", execution_provider="CPUExecutionProvider", + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", ) p = create_pass_from_dict( OnnxBlockWiseRtnQuantization, From 62544da40f714fa519ce2419d8e376bbc204ff48 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Wed, 4 Mar 2026 00:19:50 -0800 Subject: [PATCH 10/23] Add linkcheck_ignore for broken intel/neural-compressor URL The upstream tuning_strategies.md page no longer exists, causing the Sphinx linkcheck to fail with -W (warnings-as-errors). --- docs/source/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index fb441ece3c..d3ee0557ca 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -132,6 +132,8 @@ r"https://docs.qualcomm.com/*", # TODO(jambayk): remove this when the issue is fixed r"https://www.intel.com/*", + # TODO(sunghcho): remove this when the upstream repo restores the page + r"https://github.com/intel/neural-compressor/*", # TODO(team): html files are generated after doc build. Linkcheck doesn't work for them. # Remove this when linkcheck works for html files. r"^(?!https).*\.html$", From 3d0029c864c6e935f57c41317b33f45e00cbe5ac Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Thu, 5 Mar 2026 21:55:22 -0800 Subject: [PATCH 11/23] Remove neural-compressor linkcheck_ignore (fixed upstream in #2351) --- docs/source/conf.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index d3ee0557ca..fb441ece3c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -132,8 +132,6 @@ r"https://docs.qualcomm.com/*", # TODO(jambayk): remove this when the issue is fixed r"https://www.intel.com/*", - # TODO(sunghcho): remove this when the upstream repo restores the page - r"https://github.com/intel/neural-compressor/*", # TODO(team): html files are generated after doc build. Linkcheck doesn't work for them. # Remove this when linkcheck works for html files. r"^(?!https).*\.html$", From 448e8a273b8c97fcd99a3c82be974b9e431324ca Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Wed, 11 Mar 2026 22:18:05 -0700 Subject: [PATCH 12/23] Trigger CI rebuild From b41c25fe5c95a3868ed45929cf7842b56b53377c Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Wed, 11 Mar 2026 22:26:28 -0700 Subject: [PATCH 13/23] Trigger CI rebuild (lint) From a35f6e9939d5a661f8d6c01b61dcfc2126cd3ccd Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Wed, 11 Mar 2026 22:37:12 -0700 Subject: [PATCH 14/23] Trigger CI rebuild (all green) From 9846f3166b1dc8169b7e64f37f657b93e3f1e165 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Wed, 11 Mar 2026 22:44:29 -0700 Subject: [PATCH 15/23] Trigger CI rebuild (CodeQL) From d5d1e583e0cc6bae43af507f282cf5e6121ca627 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Fri, 13 Mar 2026 16:37:58 -0700 Subject: [PATCH 16/23] Replace ORT-based cast chain elimination with onnxscript optimizer Address PR review feedback from @devang-ml and @justinchuby: use onnxscript.optimizer.optimize() instead of ORT InferenceSession with session.enable_cast_chain_elimination to eliminate redundant Cast chains. - Remove onnxruntime dependency from cast_chain_elimination pass - Use onnxscript.optimizer.optimize() with TypeInferenceError fallback (same pattern as OnnxPeepholeOptimizer) - Update test comment to reflect onnxscript optimizer - Verified: numerically identical outputs (0.00 max abs diff) - Verified: no eval regression (69% on AI2D 100 samples) --- olive/passes/onnx/cast_chain_elimination.py | 55 +++++++------------ .../onnx/test_cast_chain_elimination.py | 5 +- 2 files changed, 23 insertions(+), 37 deletions(-) diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py index 4fa2368aeb..00fd88ba18 100644 --- a/olive/passes/onnx/cast_chain_elimination.py +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -2,12 +2,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""ORT-based Cast chain elimination and com.microsoft opset fixup. +"""Cast chain elimination and com.microsoft opset fixup. Dynamo-exported ONNX models often contain redundant Cast chains (e.g. fp32→fp16→fp32) that double memory traffic and slow inference. -ORT has a graph optimization for this, but it is disabled by default -behind the ``session.enable_cast_chain_elimination`` session config. Additionally, Olive ``GraphSurgeries`` may insert ``com.microsoft`` operators (such as ``LoopMHA``) without registering the custom opset @@ -16,8 +14,8 @@ This pass: 1. Ensures ``com.microsoft`` opset version 1 is declared on the model *and* on every ONNX function scope. -2. Runs ORT ``ORT_ENABLE_BASIC`` optimization with Cast chain - elimination explicitly enabled to produce a cleaned-up model. +2. Runs the ``onnxscript`` optimizer to fold or remove redundant Cast + chains and other constant-foldable patterns. """ import logging @@ -53,9 +51,9 @@ class OnnxCastChainElimination(Pass): This pass first ensures the ``com.microsoft`` opset version 1 is registered on the model graph and every ONNX function scope (needed after ``GraphSurgeries`` inserts custom ops into dynamo-exported - models). It then runs ORT basic graph optimization with the - ``session.enable_cast_chain_elimination`` flag enabled to collapse - consecutive Cast operators that cancel out. + models). It then runs the ``onnxscript`` optimizer to collapse + consecutive Cast operators that cancel out and perform other + peephole optimizations. """ @classmethod @@ -69,7 +67,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "enable_cast_chain_elimination": PassConfigParam( type_=bool, default_value=True, - description="Run ORT basic optimization with Cast chain elimination.", + description="Run onnxscript optimizer to eliminate redundant Cast chains.", ), **get_external_data_config(), } @@ -85,31 +83,20 @@ def _run_for_config( if config.fix_opset: _ensure_com_microsoft_opset(onnx_model) - # Step 2: Cast chain elimination via ORT session optimization + # Step 2: Cast chain elimination via onnxscript optimizer if config.enable_cast_chain_elimination: - import tempfile - - import onnxruntime as ort - - # ORT needs the patched model on disk to optimise it. - # Large models (>2 GB) must use external data to avoid protobuf limits. - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = str(Path(tmp_dir) / "model.onnx") - onnx.save_model( - onnx_model, - tmp_path, - save_as_external_data=True, - all_tensors_to_one_file=True, - location="model.onnx.data", - convert_attribute=True, - ) - - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC - sess_options.optimized_model_filepath = tmp_path - sess_options.add_session_config_entry("session.enable_cast_chain_elimination", "1") - ort.InferenceSession(tmp_path, sess_options, providers=["CPUExecutionProvider"]) - - onnx_model = onnx.load(tmp_path, load_external_data=True) + import onnxscript + + try: + onnx_model = onnxscript.optimizer.optimize(onnx_model) + except Exception as e: + if "TypeInferenceError" in str(e): + logger.info( + "onnxscript optimizer failed with %s. Rerunning with shape inference disabled.", + str(e), + ) + onnx_model = onnxscript.optimizer.optimize(onnx_model, onnx_shape_inference=False) + else: + raise return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/test/passes/onnx/test_cast_chain_elimination.py b/test/passes/onnx/test_cast_chain_elimination.py index 25ae896d8f..5c5a6895cd 100644 --- a/test/passes/onnx/test_cast_chain_elimination.py +++ b/test/passes/onnx/test_cast_chain_elimination.py @@ -102,9 +102,8 @@ def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): output = p.run(olive_model, str(tmp_path / "out.onnx")) result_model = onnx.load(output.model_path) - # The pass should produce a valid, runnable model. - # Actual cast elimination depends on the ORT version; at minimum the - # output graph must not have *more* nodes than the input. + # The onnxscript optimizer should fold the redundant fp32→fp16→fp32 + # chain into an identity (0 nodes) or at most leave the original 2. assert len(result_model.graph.node) <= 2 def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): From 15975c8913474a172fa515e045e1ca627fc22ab1 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Sun, 15 Mar 2026 22:47:14 -0700 Subject: [PATCH 17/23] Replace onnxscript optimizer with targeted rewrite rule for cast chain elimination Use a custom CastCastRoundTrip rewrite rule instead of the full onnxscript.optimizer.optimize() call. The rewrite rule specifically targets round-trip Cast chains (e.g. fp32->fp16->fp32) by checking that the final cast type matches the original input type, and replaces them with Identity. This is simpler, faster, and avoids the TypeInferenceError fallback that was needed with the full optimizer. The onnxscript rewrite() function also runs RemoveUnusedNodesPass and RemoveUnusedOpsetsPass automatically. Validated: weights identical, 0.00 max abs diff, eval 69% unchanged. --- olive/passes/onnx/cast_chain_elimination.py | 68 +++++++++++++------ .../onnx/test_cast_chain_elimination.py | 7 +- 2 files changed, 52 insertions(+), 23 deletions(-) diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py index 00fd88ba18..7fd1a6457c 100644 --- a/olive/passes/onnx/cast_chain_elimination.py +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -14,8 +14,8 @@ This pass: 1. Ensures ``com.microsoft`` opset version 1 is declared on the model *and* on every ONNX function scope. -2. Runs the ``onnxscript`` optimizer to fold or remove redundant Cast - chains and other constant-foldable patterns. +2. Applies targeted onnxscript rewrite rules to eliminate redundant + round-trip Cast chains (e.g. fp32→fp16→fp32 → identity). """ import logging @@ -45,15 +45,52 @@ def _ensure_com_microsoft_opset(model: onnx.ModelProto): func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) +def _get_cast_chain_rewrite_rules(): + """Build onnxscript rewrite rules for eliminating redundant Cast chains. + + Returns a list of ``RewriteRule`` instances that target round-trip + Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. + """ + from onnxscript import ir + from onnxscript.rewriter import RewriteRuleClassBase + from onnxscript.rewriter._basics import MatchResult + + class _CastCastRoundTrip(RewriteRuleClassBase): + """Collapse ``Cast(Cast(x, to=T2), to=T3)`` to ``Identity(x)`` when T3 matches x's type. + + Dynamo-exported models frequently insert unnecessary cast round-trips + (e.g. fp32→fp16→fp32). When the final cast type equals the original + input type the entire chain is a no-op and can be replaced by Identity. + """ + + def pattern(self, op, x, to, to_ignored): + return op.Cast(op.Cast(x, to=to_ignored), to=to) + + def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: + check_result = MatchResult() + if x.dtype is None: + return check_result.fail("Input dtype unknown; cannot verify round-trip") + if x.dtype != to.as_int(): + return check_result.fail( + f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}" + ) + return check_result + + def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): + return op.Identity(x) + + return [_CastCastRoundTrip().rule()] + + class OnnxCastChainElimination(Pass): """Fix com.microsoft opset declarations and eliminate redundant Cast chains. This pass first ensures the ``com.microsoft`` opset version 1 is registered on the model graph and every ONNX function scope (needed after ``GraphSurgeries`` inserts custom ops into dynamo-exported - models). It then runs the ``onnxscript`` optimizer to collapse - consecutive Cast operators that cancel out and perform other - peephole optimizations. + models). It then applies targeted onnxscript rewrite rules to + collapse consecutive Cast operators that form a round-trip + (e.g. fp32→fp16→fp32 → identity). """ @classmethod @@ -67,7 +104,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassCon "enable_cast_chain_elimination": PassConfigParam( type_=bool, default_value=True, - description="Run onnxscript optimizer to eliminate redundant Cast chains.", + description="Apply rewrite rules to eliminate redundant round-trip Cast chains.", ), **get_external_data_config(), } @@ -83,20 +120,11 @@ def _run_for_config( if config.fix_opset: _ensure_com_microsoft_opset(onnx_model) - # Step 2: Cast chain elimination via onnxscript optimizer + # Step 2: Cast chain elimination via targeted onnxscript rewrite rules if config.enable_cast_chain_elimination: - import onnxscript - - try: - onnx_model = onnxscript.optimizer.optimize(onnx_model) - except Exception as e: - if "TypeInferenceError" in str(e): - logger.info( - "onnxscript optimizer failed with %s. Rerunning with shape inference disabled.", - str(e), - ) - onnx_model = onnxscript.optimizer.optimize(onnx_model, onnx_shape_inference=False) - else: - raise + from onnxscript.rewriter import rewrite + + rules = _get_cast_chain_rewrite_rules() + onnx_model = rewrite(onnx_model, pattern_rewrite_rules=rules) return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/test/passes/onnx/test_cast_chain_elimination.py b/test/passes/onnx/test_cast_chain_elimination.py index 5c5a6895cd..d8182e9fb8 100644 --- a/test/passes/onnx/test_cast_chain_elimination.py +++ b/test/passes/onnx/test_cast_chain_elimination.py @@ -102,9 +102,10 @@ def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): output = p.run(olive_model, str(tmp_path / "out.onnx")) result_model = onnx.load(output.model_path) - # The onnxscript optimizer should fold the redundant fp32→fp16→fp32 - # chain into an identity (0 nodes) or at most leave the original 2. - assert len(result_model.graph.node) <= 2 + # The rewrite rule collapses the round-trip fp32→fp16→fp32 chain + # into a single Identity node. + assert len(result_model.graph.node) == 1 + assert result_model.graph.node[0].op_type == "Identity" def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) From 4ecba49d321328d1f35d53d4f559aaf15c96cb2f Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 12:59:45 -0700 Subject: [PATCH 18/23] Fix lint: move onnxscript imports to top level (PLC0415) --- olive/passes/onnx/cast_chain_elimination.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py index 7fd1a6457c..a3758f1980 100644 --- a/olive/passes/onnx/cast_chain_elimination.py +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -23,6 +23,9 @@ import onnx from onnx import helper +from onnxscript import ir +from onnxscript.rewriter import RewriteRuleClassBase, rewrite +from onnxscript.rewriter._basics import MatchResult from olive.hardware.accelerator import AcceleratorSpec from olive.model import ONNXModelHandler @@ -51,9 +54,6 @@ def _get_cast_chain_rewrite_rules(): Returns a list of ``RewriteRule`` instances that target round-trip Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. """ - from onnxscript import ir - from onnxscript.rewriter import RewriteRuleClassBase - from onnxscript.rewriter._basics import MatchResult class _CastCastRoundTrip(RewriteRuleClassBase): """Collapse ``Cast(Cast(x, to=T2), to=T3)`` to ``Identity(x)`` when T3 matches x's type. @@ -71,9 +71,7 @@ def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> Match if x.dtype is None: return check_result.fail("Input dtype unknown; cannot verify round-trip") if x.dtype != to.as_int(): - return check_result.fail( - f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}" - ) + return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") return check_result def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): @@ -122,8 +120,6 @@ def _run_for_config( # Step 2: Cast chain elimination via targeted onnxscript rewrite rules if config.enable_cast_chain_elimination: - from onnxscript.rewriter import rewrite - rules = _get_cast_chain_rewrite_rules() onnx_model = rewrite(onnx_model, pattern_rewrite_rules=rules) From 054bd7cbce668e3316f09158c4e31513102863fa Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 16:07:17 -0700 Subject: [PATCH 19/23] Fix lint: use functional RewriteRule API to avoid pylint W0221 (arguments-differ) --- olive/passes/onnx/cast_chain_elimination.py | 50 +++++++++++---------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py index a3758f1980..32451c973c 100644 --- a/olive/passes/onnx/cast_chain_elimination.py +++ b/olive/passes/onnx/cast_chain_elimination.py @@ -24,7 +24,7 @@ import onnx from onnx import helper from onnxscript import ir -from onnxscript.rewriter import RewriteRuleClassBase, rewrite +from onnxscript.rewriter import RewriteRule, rewrite from onnxscript.rewriter._basics import MatchResult from olive.hardware.accelerator import AcceleratorSpec @@ -55,29 +55,31 @@ def _get_cast_chain_rewrite_rules(): Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. """ - class _CastCastRoundTrip(RewriteRuleClassBase): - """Collapse ``Cast(Cast(x, to=T2), to=T3)`` to ``Identity(x)`` when T3 matches x's type. - - Dynamo-exported models frequently insert unnecessary cast round-trips - (e.g. fp32→fp16→fp32). When the final cast type equals the original - input type the entire chain is a no-op and can be replaced by Identity. - """ - - def pattern(self, op, x, to, to_ignored): - return op.Cast(op.Cast(x, to=to_ignored), to=to) - - def check(self, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: - check_result = MatchResult() - if x.dtype is None: - return check_result.fail("Input dtype unknown; cannot verify round-trip") - if x.dtype != to.as_int(): - return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") - return check_result - - def rewrite(self, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr): - return op.Identity(x) - - return [_CastCastRoundTrip().rule()] + def _cast_cast_round_trip_pattern(op, x, to, to_ignored): + """Match ``Cast(Cast(x, to=T2), to=T3)``.""" + return op.Cast(op.Cast(x, to=to_ignored), to=to) + + def _cast_cast_round_trip_check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: + """Only match when the final cast type equals the original input type (round-trip).""" + check_result = MatchResult() + if x.dtype is None: + return check_result.fail("Input dtype unknown; cannot verify round-trip") + if x.dtype != to.as_int(): + return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") + return check_result + + def _cast_cast_round_trip_replacement(op, x, **_): + """Replace the round-trip cast chain with Identity.""" + return op.Identity(x) + + return [ + RewriteRule( + _cast_cast_round_trip_pattern, + _cast_cast_round_trip_replacement, + _cast_cast_round_trip_check, + name="CastCastRoundTrip", + ) + ] class OnnxCastChainElimination(Pass): From 9c54059322c69e64afbd09481c52e3c36f162294 Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 17:06:51 -0700 Subject: [PATCH 20/23] Merge cast chain elimination into OnnxPeepholeOptimizer Move _ensure_com_microsoft_opset and eliminate_cast_chains into ModelOptimizer class. Add fix_com_microsoft_opset and cast_chain_elimination config flags to OnnxPeepholeOptimizer. Remove standalone OnnxCastChainElimination pass, its olive_config entry, and its test file. Move tests into test_peephole_optimizer.py. Per devang-ml's review: consolidate into existing pass to avoid introducing a new one. --- olive/olive_config.json | 8 - olive/passes/onnx/cast_chain_elimination.py | 128 ---------------- olive/passes/onnx/peephole_optimizer.py | 102 ++++++++++++- .../onnx/test_cast_chain_elimination.py | 135 ---------------- test/passes/onnx/test_peephole_optimizer.py | 144 +++++++++++++++++- 5 files changed, 243 insertions(+), 274 deletions(-) delete mode 100644 olive/passes/onnx/cast_chain_elimination.py delete mode 100644 test/passes/onnx/test_cast_chain_elimination.py diff --git a/olive/olive_config.json b/olive/olive_config.json index a41eccc993..2748c39101 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -315,14 +315,6 @@ "supported_quantization_encodings": [ ], "extra_dependencies": [ "onnxoptimizer", "onnxscript" ] }, - "OnnxCastChainElimination": { - "module_path": "olive.passes.onnx.cast_chain_elimination.OnnxCastChainElimination", - "supported_providers": [ "*" ], - "supported_accelerators": [ "*" ], - "supported_precisions": [ "*" ], - "supported_algorithms": [ ], - "supported_quantization_encodings": [ ] - }, "OnnxQuantization": { "module_path": "olive.passes.onnx.quantization.OnnxQuantization", "supported_providers": [ "CPUExecutionProvider" ], diff --git a/olive/passes/onnx/cast_chain_elimination.py b/olive/passes/onnx/cast_chain_elimination.py deleted file mode 100644 index 32451c973c..0000000000 --- a/olive/passes/onnx/cast_chain_elimination.py +++ /dev/null @@ -1,128 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Cast chain elimination and com.microsoft opset fixup. - -Dynamo-exported ONNX models often contain redundant Cast chains -(e.g. fp32→fp16→fp32) that double memory traffic and slow inference. - -Additionally, Olive ``GraphSurgeries`` may insert ``com.microsoft`` -operators (such as ``LoopMHA``) without registering the custom opset -on every ONNX function scope, causing downstream failures. - -This pass: -1. Ensures ``com.microsoft`` opset version 1 is declared on the model - *and* on every ONNX function scope. -2. Applies targeted onnxscript rewrite rules to eliminate redundant - round-trip Cast chains (e.g. fp32→fp16→fp32 → identity). -""" - -import logging -from pathlib import Path - -import onnx -from onnx import helper -from onnxscript import ir -from onnxscript.rewriter import RewriteRule, rewrite -from onnxscript.rewriter._basics import MatchResult - -from olive.hardware.accelerator import AcceleratorSpec -from olive.model import ONNXModelHandler -from olive.model.utils import resolve_onnx_path -from olive.passes import Pass -from olive.passes.onnx.common import get_external_data_config, model_proto_to_olive_model -from olive.passes.pass_config import BasePassConfig, PassConfigParam - -logger = logging.getLogger(__name__) - - -def _ensure_com_microsoft_opset(model: onnx.ModelProto): - """Ensure com.microsoft opset v1 is declared at model and function level.""" - existing = {op.domain for op in model.opset_import} - if "com.microsoft" not in existing: - model.opset_import.append(helper.make_opsetid("com.microsoft", 1)) - for func in model.functions: - func_domains = {op.domain for op in func.opset_import} - if "com.microsoft" not in func_domains: - func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) - - -def _get_cast_chain_rewrite_rules(): - """Build onnxscript rewrite rules for eliminating redundant Cast chains. - - Returns a list of ``RewriteRule`` instances that target round-trip - Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. - """ - - def _cast_cast_round_trip_pattern(op, x, to, to_ignored): - """Match ``Cast(Cast(x, to=T2), to=T3)``.""" - return op.Cast(op.Cast(x, to=to_ignored), to=to) - - def _cast_cast_round_trip_check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: - """Only match when the final cast type equals the original input type (round-trip).""" - check_result = MatchResult() - if x.dtype is None: - return check_result.fail("Input dtype unknown; cannot verify round-trip") - if x.dtype != to.as_int(): - return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") - return check_result - - def _cast_cast_round_trip_replacement(op, x, **_): - """Replace the round-trip cast chain with Identity.""" - return op.Identity(x) - - return [ - RewriteRule( - _cast_cast_round_trip_pattern, - _cast_cast_round_trip_replacement, - _cast_cast_round_trip_check, - name="CastCastRoundTrip", - ) - ] - - -class OnnxCastChainElimination(Pass): - """Fix com.microsoft opset declarations and eliminate redundant Cast chains. - - This pass first ensures the ``com.microsoft`` opset version 1 is - registered on the model graph and every ONNX function scope (needed - after ``GraphSurgeries`` inserts custom ops into dynamo-exported - models). It then applies targeted onnxscript rewrite rules to - collapse consecutive Cast operators that form a round-trip - (e.g. fp32→fp16→fp32 → identity). - """ - - @classmethod - def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: - return { - "fix_opset": PassConfigParam( - type_=bool, - default_value=True, - description="Ensure com.microsoft opset v1 is declared on all scopes.", - ), - "enable_cast_chain_elimination": PassConfigParam( - type_=bool, - default_value=True, - description="Apply rewrite rules to eliminate redundant round-trip Cast chains.", - ), - **get_external_data_config(), - } - - def _run_for_config( - self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str - ) -> ONNXModelHandler: - output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - - onnx_model = model.load_model() - - # Step 1: Opset fixup - if config.fix_opset: - _ensure_com_microsoft_opset(onnx_model) - - # Step 2: Cast chain elimination via targeted onnxscript rewrite rules - if config.enable_cast_chain_elimination: - rules = _get_cast_chain_rewrite_rules() - onnx_model = rewrite(onnx_model, pattern_rewrite_rules=rules) - - return model_proto_to_olive_model(onnx_model, output_model_path, config) diff --git a/olive/passes/onnx/peephole_optimizer.py b/olive/passes/onnx/peephole_optimizer.py index 47bdf097b0..8fe88abdd2 100644 --- a/olive/passes/onnx/peephole_optimizer.py +++ b/olive/passes/onnx/peephole_optimizer.py @@ -7,6 +7,10 @@ import numpy as np import onnx +from onnx import helper +from onnxscript import ir +from onnxscript.rewriter import RewriteRule, rewrite +from onnxscript.rewriter._basics import MatchResult from olive.hardware.accelerator import AcceleratorSpec from olive.model import ONNXModelHandler @@ -18,6 +22,40 @@ logger = logging.getLogger(__name__) +def _get_cast_chain_rewrite_rules(): + """Build onnxscript rewrite rules for eliminating redundant Cast chains. + + Returns a list of ``RewriteRule`` instances that target round-trip + Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. + """ + + def _cast_cast_round_trip_pattern(op, x, to, to_ignored): + """Match ``Cast(Cast(x, to=T2), to=T3)``.""" + return op.Cast(op.Cast(x, to=to_ignored), to=to) + + def _cast_cast_round_trip_check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: + """Only match when the final cast type equals the original input type (round-trip).""" + check_result = MatchResult() + if x.dtype is None: + return check_result.fail("Input dtype unknown; cannot verify round-trip") + if x.dtype != to.as_int(): + return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") + return check_result + + def _cast_cast_round_trip_replacement(op, x, **_): + """Replace the round-trip cast chain with Identity.""" + return op.Identity(x) + + return [ + RewriteRule( + _cast_cast_round_trip_pattern, + _cast_cast_round_trip_replacement, + _cast_cast_round_trip_check, + name="CastCastRoundTrip", + ) + ] + + # TODO(anyone): Move from onnxruntime.transformers.onnx_model.OnnxModel to OnnxDAG # or reimplement logic using onnx-rewriter # no need to create a new instance of OnnxModel for each optimization @@ -26,6 +64,32 @@ def __init__(self, source_model_path): self.source_model_path = str(source_model_path) self.model = onnx.load(self.source_model_path) + def ensure_com_microsoft_opset(self): + """Ensure com.microsoft opset v1 is declared at model and function level. + + Olive ``GraphSurgeries`` may insert ``com.microsoft`` operators (such as + ``LoopMHA``) without registering the custom opset on every ONNX function + scope. This method fixes the declarations so that downstream passes and + validators do not fail. + """ + existing = {op.domain for op in self.model.opset_import} + if "com.microsoft" not in existing: + self.model.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + for func in self.model.functions: + func_domains = {op.domain for op in func.opset_import} + if "com.microsoft" not in func_domains: + func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + + def eliminate_cast_chains(self): + """Eliminate redundant round-trip Cast chains (e.g. fp32→fp16→fp32). + + Dynamo-exported ONNX models often contain unnecessary cast round-trips. + This method applies a targeted onnxscript rewrite rule to collapse them + into Identity nodes. + """ + rules = _get_cast_chain_rewrite_rules() + self.model = rewrite(self.model, pattern_rewrite_rules=rules) + def fuse_reshape_operations(self): # Remove unnecessary Reshape operator. Consecutive Reshape operators with latter's input being "[-1]" # i.e. flatten the input, the former Reshape operator is useless.""" @@ -85,11 +149,37 @@ def onnxoptimizer_optimize(self): class OnnxPeepholeOptimizer(Pass): - """Optimize ONNX model by fusing nodes.""" + """Optimize ONNX model by fusing nodes. + + Runs a combination of onnxscript optimizer, onnxoptimizer, reshape + fusion, and optionally: + - ``com.microsoft`` opset fixup (for models that use custom ops in + function scopes after ``GraphSurgeries``). + - Cast chain elimination (collapses round-trip Cast chains like + fp32→fp16→fp32 produced by dynamo export). + """ @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: - return get_external_data_config() + return { + "fix_com_microsoft_opset": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Ensure com.microsoft opset v1 is declared on the model and all function scopes. " + "Enable this when GraphSurgeries inserts custom ops (e.g. LoopMHA) into function scopes." + ), + ), + "cast_chain_elimination": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Apply a targeted rewrite rule to eliminate redundant round-trip Cast chains " + "(e.g. fp32→fp16→fp32 → identity) produced by dynamo export." + ), + ), + **get_external_data_config(), + } def _run_for_config( self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str @@ -102,5 +192,13 @@ def _run_for_config( peephole_optimizer.onnxoptimizer_optimize() peephole_optimizer.fuse_reshape_operations() + # Optional: fix com.microsoft opset declarations + if config.fix_com_microsoft_opset: + peephole_optimizer.ensure_com_microsoft_opset() + + # Optional: eliminate redundant round-trip Cast chains + if config.cast_chain_elimination: + peephole_optimizer.eliminate_cast_chains() + # save the model to the output path and return the model return model_proto_to_olive_model(peephole_optimizer.model, output_model_path, config) diff --git a/test/passes/onnx/test_cast_chain_elimination.py b/test/passes/onnx/test_cast_chain_elimination.py deleted file mode 100644 index d8182e9fb8..0000000000 --- a/test/passes/onnx/test_cast_chain_elimination.py +++ /dev/null @@ -1,135 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Tests for OnnxCastChainElimination pass.""" - -import numpy as np -import onnx -import onnxruntime as ort -import pytest -from onnx import TensorProto, helper - -from olive.model import ONNXModelHandler -from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx.cast_chain_elimination import OnnxCastChainElimination, _ensure_com_microsoft_opset - - -class TestEnsureComMicrosoftOpset: - """Unit tests for _ensure_com_microsoft_opset helper.""" - - def test_adds_opset_when_missing(self): - model = helper.make_model( - helper.make_graph([], "g", [], []), - opset_imports=[helper.make_opsetid("", 17)], - ) - assert not any(op.domain == "com.microsoft" for op in model.opset_import) - - _ensure_com_microsoft_opset(model) - - assert any(op.domain == "com.microsoft" for op in model.opset_import) - - def test_does_not_duplicate_opset(self): - model = helper.make_model( - helper.make_graph([], "g", [], []), - opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)], - ) - - _ensure_com_microsoft_opset(model) - - ms_opsets = [op for op in model.opset_import if op.domain == "com.microsoft"] - assert len(ms_opsets) == 1 - - def test_adds_opset_to_functions(self): - func = onnx.FunctionProto() - func.name = "test_func" - func.domain = "test.domain" - func.opset_import.append(helper.make_opsetid("", 17)) - - model = helper.make_model( - helper.make_graph([], "g", [], []), - opset_imports=[helper.make_opsetid("", 17)], - ) - model.functions.append(func) - - _ensure_com_microsoft_opset(model) - - func_domains = {op.domain for op in model.functions[0].opset_import} - assert "com.microsoft" in func_domains - - def test_skips_function_with_existing_opset(self): - func = onnx.FunctionProto() - func.name = "test_func" - func.domain = "test.domain" - func.opset_import.append(helper.make_opsetid("", 17)) - func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) - - model = helper.make_model( - helper.make_graph([], "g", [], []), - opset_imports=[helper.make_opsetid("", 17)], - ) - model.functions.append(func) - - _ensure_com_microsoft_opset(model) - - ms_opsets = [op for op in model.functions[0].opset_import if op.domain == "com.microsoft"] - assert len(ms_opsets) == 1 - - -class TestOnnxCastChainElimination: - """Integration tests for OnnxCastChainElimination pass.""" - - @pytest.fixture - def cast_chain_model_path(self, tmp_path): - """Model with a redundant Cast chain: fp32 → fp16 → fp32.""" - input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) - output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) - - cast_to_fp16 = helper.make_node("Cast", ["X"], ["x_fp16"], to=TensorProto.FLOAT16) - cast_back = helper.make_node("Cast", ["x_fp16"], ["Y"], to=TensorProto.FLOAT) - - graph = helper.make_graph([cast_to_fp16, cast_back], "cast_chain", [input_tensor], [output_tensor]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) - onnx.checker.check_model(model) - - path = tmp_path / "cast_chain.onnx" - onnx.save(model, str(path)) - return path - - def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): - olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) - p = create_pass_from_dict(OnnxCastChainElimination, {}, disable_search=True) - output = p.run(olive_model, str(tmp_path / "out.onnx")) - - result_model = onnx.load(output.model_path) - # The rewrite rule collapses the round-trip fp32→fp16→fp32 chain - # into a single Identity node. - assert len(result_model.graph.node) == 1 - assert result_model.graph.node[0].op_type == "Identity" - - def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): - olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) - p = create_pass_from_dict( - OnnxCastChainElimination, - {"fix_opset": True, "enable_cast_chain_elimination": False}, - disable_search=True, - ) - output = p.run(olive_model, str(tmp_path / "out.onnx")) - - result_model = onnx.load(output.model_path) - assert any(op.domain == "com.microsoft" for op in result_model.opset_import) - - def test_numerical_correctness(self, cast_chain_model_path, tmp_path): - """Optimized model should produce the same output as the original.""" - olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) - p = create_pass_from_dict(OnnxCastChainElimination, {}, disable_search=True) - output = p.run(olive_model, str(tmp_path / "out.onnx")) - - x = np.random.randn(2, 4).astype(np.float32) - - orig_sess = ort.InferenceSession(str(cast_chain_model_path), providers=["CPUExecutionProvider"]) - opt_sess = ort.InferenceSession(output.model_path, providers=["CPUExecutionProvider"]) - - orig_out = orig_sess.run(None, {"X": x})[0] - opt_out = opt_sess.run(None, {"X": x})[0] - np.testing.assert_allclose(orig_out, opt_out, atol=1e-3) diff --git a/test/passes/onnx/test_peephole_optimizer.py b/test/passes/onnx/test_peephole_optimizer.py index 745289a455..8bef54305b 100644 --- a/test/passes/onnx/test_peephole_optimizer.py +++ b/test/passes/onnx/test_peephole_optimizer.py @@ -5,13 +5,17 @@ from pathlib import Path from unittest.mock import patch +import numpy as np +import onnx +import onnxruntime as ort import pytest from onnx import TensorProto, helper from olive.hardware import DEFAULT_CPU_ACCELERATOR +from olive.model import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.common import model_proto_to_olive_model -from olive.passes.onnx.peephole_optimizer import OnnxPeepholeOptimizer +from olive.passes.onnx.peephole_optimizer import ModelOptimizer, OnnxPeepholeOptimizer from test.utils import get_onnx_model @@ -138,3 +142,141 @@ def test_onnxoptimizer(mock_onnxscript, mock_onnxoptimizer, mock_model_proto_to_ # assert mock_onnxoptimizer.assert_called_once() + + +# ── _ensure_com_microsoft_opset ──────────────────────────────────────────── + + +class TestEnsureComMicrosoftOpset: + """Unit tests for ModelOptimizer.ensure_com_microsoft_opset.""" + + def _make_optimizer_with_model(self, model, tmp_path): + """Save a model to disk and create a ModelOptimizer around it.""" + path = tmp_path / "model.onnx" + onnx.save(model, str(path)) + opt = ModelOptimizer(str(path)) + opt.model = model # use the in-memory model directly + return opt + + def test_adds_opset_when_missing(self, tmp_path): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + assert not any(op.domain == "com.microsoft" for op in model.opset_import) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + assert any(op.domain == "com.microsoft" for op in opt.model.opset_import) + + def test_does_not_duplicate_opset(self, tmp_path): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)], + ) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + ms_opsets = [op for op in opt.model.opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + def test_adds_opset_to_functions(self, tmp_path): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + func_domains = {op.domain for op in opt.model.functions[0].opset_import} + assert "com.microsoft" in func_domains + + def test_skips_function_with_existing_opset(self, tmp_path): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + ms_opsets = [op for op in opt.model.functions[0].opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + +# ── Cast chain elimination via OnnxPeepholeOptimizer ─────────────────────── + + +class TestCastChainElimination: + """Tests for cast chain elimination integrated into OnnxPeepholeOptimizer.""" + + @pytest.fixture + def cast_chain_model_path(self, tmp_path): + """Model with a redundant Cast chain: fp32 → fp16 → fp32.""" + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + cast_to_fp16 = helper.make_node("Cast", ["X"], ["x_fp16"], to=TensorProto.FLOAT16) + cast_back = helper.make_node("Cast", ["x_fp16"], ["Y"], to=TensorProto.FLOAT) + + graph = helper.make_graph([cast_to_fp16, cast_back], "cast_chain", [input_tensor], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + onnx.checker.check_model(model) + + path = tmp_path / "cast_chain.onnx" + onnx.save(model, str(path)) + return path + + def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxPeepholeOptimizer, {"cast_chain_elimination": True}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + # The rewrite rule collapses the round-trip fp32→fp16→fp32 chain + # into a single Identity node. + assert len(result_model.graph.node) == 1 + assert result_model.graph.node[0].op_type == "Identity" + + def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict( + OnnxPeepholeOptimizer, + {"fix_com_microsoft_opset": True, "cast_chain_elimination": False}, + disable_search=True, + ) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + assert any(op.domain == "com.microsoft" for op in result_model.opset_import) + + def test_numerical_correctness(self, cast_chain_model_path, tmp_path): + """Optimized model should produce the same output as the original.""" + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxPeepholeOptimizer, {"cast_chain_elimination": True}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + x = np.random.randn(2, 4).astype(np.float32) + + orig_sess = ort.InferenceSession(str(cast_chain_model_path), providers=["CPUExecutionProvider"]) + opt_sess = ort.InferenceSession(output.model_path, providers=["CPUExecutionProvider"]) + + orig_out = orig_sess.run(None, {"X": x})[0] + opt_out = opt_sess.run(None, {"X": x})[0] + np.testing.assert_allclose(orig_out, opt_out, atol=1e-3) From 95784978f942d751df9b6f0e137b2f4566fa176e Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 17:14:11 -0700 Subject: [PATCH 21/23] Move _get_cast_chain_rewrite_rules into ModelOptimizer as static method --- olive/passes/onnx/peephole_optimizer.py | 63 +++++++++++-------------- 1 file changed, 28 insertions(+), 35 deletions(-) diff --git a/olive/passes/onnx/peephole_optimizer.py b/olive/passes/onnx/peephole_optimizer.py index 8fe88abdd2..d9d9bc1fd8 100644 --- a/olive/passes/onnx/peephole_optimizer.py +++ b/olive/passes/onnx/peephole_optimizer.py @@ -22,40 +22,6 @@ logger = logging.getLogger(__name__) -def _get_cast_chain_rewrite_rules(): - """Build onnxscript rewrite rules for eliminating redundant Cast chains. - - Returns a list of ``RewriteRule`` instances that target round-trip - Cast patterns (e.g. fp32→fp16→fp32) produced by dynamo export. - """ - - def _cast_cast_round_trip_pattern(op, x, to, to_ignored): - """Match ``Cast(Cast(x, to=T2), to=T3)``.""" - return op.Cast(op.Cast(x, to=to_ignored), to=to) - - def _cast_cast_round_trip_check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: - """Only match when the final cast type equals the original input type (round-trip).""" - check_result = MatchResult() - if x.dtype is None: - return check_result.fail("Input dtype unknown; cannot verify round-trip") - if x.dtype != to.as_int(): - return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") - return check_result - - def _cast_cast_round_trip_replacement(op, x, **_): - """Replace the round-trip cast chain with Identity.""" - return op.Identity(x) - - return [ - RewriteRule( - _cast_cast_round_trip_pattern, - _cast_cast_round_trip_replacement, - _cast_cast_round_trip_check, - name="CastCastRoundTrip", - ) - ] - - # TODO(anyone): Move from onnxruntime.transformers.onnx_model.OnnxModel to OnnxDAG # or reimplement logic using onnx-rewriter # no need to create a new instance of OnnxModel for each optimization @@ -87,9 +53,36 @@ def eliminate_cast_chains(self): This method applies a targeted onnxscript rewrite rule to collapse them into Identity nodes. """ - rules = _get_cast_chain_rewrite_rules() + rules = self._get_cast_chain_rewrite_rules() self.model = rewrite(self.model, pattern_rewrite_rules=rules) + @staticmethod + def _get_cast_chain_rewrite_rules(): + """Build onnxscript rewrite rules for eliminating redundant Cast chains.""" + + def _cast_cast_round_trip_pattern(op, x, to, to_ignored): + return op.Cast(op.Cast(x, to=to_ignored), to=to) + + def _cast_cast_round_trip_check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult: + check_result = MatchResult() + if x.dtype is None: + return check_result.fail("Input dtype unknown; cannot verify round-trip") + if x.dtype != to.as_int(): + return check_result.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}") + return check_result + + def _cast_cast_round_trip_replacement(op, x, **_): + return op.Identity(x) + + return [ + RewriteRule( + _cast_cast_round_trip_pattern, + _cast_cast_round_trip_replacement, + _cast_cast_round_trip_check, + name="CastCastRoundTrip", + ) + ] + def fuse_reshape_operations(self): # Remove unnecessary Reshape operator. Consecutive Reshape operators with latter's input being "[-1]" # i.e. flatten the input, the former Reshape operator is useless.""" From 7a2e634681a416d9e1bf50aa86afe7976ac1d7db Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 17:24:03 -0700 Subject: [PATCH 22/23] Make all ModelOptimizer steps configurable in OnnxPeepholeOptimizer Add onnxscript_optimize, onnxoptimizer_optimize, and fuse_reshape_operations config flags (default True for backward compatibility). This allows recipe configs to disable the default optimizations and only run opset fixup + cast chain elimination, producing byte-identical models to the old standalone pass. --- olive/passes/onnx/peephole_optimizer.py | 30 ++++++++++++++++++++----- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/olive/passes/onnx/peephole_optimizer.py b/olive/passes/onnx/peephole_optimizer.py index d9d9bc1fd8..e08a9a6451 100644 --- a/olive/passes/onnx/peephole_optimizer.py +++ b/olive/passes/onnx/peephole_optimizer.py @@ -155,6 +155,21 @@ class OnnxPeepholeOptimizer(Pass): @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: return { + "onnxscript_optimize": PassConfigParam( + type_=bool, + default_value=True, + description="Run onnxscript optimizer for general graph optimizations.", + ), + "onnxoptimizer_optimize": PassConfigParam( + type_=bool, + default_value=True, + description="Run onnxoptimizer for additional graph optimizations.", + ), + "fuse_reshape_operations": PassConfigParam( + type_=bool, + default_value=True, + description="Fuse consecutive Reshape operators where the latter flattens to [-1].", + ), "fix_com_microsoft_opset": PassConfigParam( type_=bool, default_value=False, @@ -179,17 +194,20 @@ def _run_for_config( ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - # optimize model peephole_optimizer = ModelOptimizer(model.model_path) - peephole_optimizer.onnxscript_optimize() - peephole_optimizer.onnxoptimizer_optimize() - peephole_optimizer.fuse_reshape_operations() - # Optional: fix com.microsoft opset declarations + if config.onnxscript_optimize: + peephole_optimizer.onnxscript_optimize() + + if config.onnxoptimizer_optimize: + peephole_optimizer.onnxoptimizer_optimize() + + if config.fuse_reshape_operations: + peephole_optimizer.fuse_reshape_operations() + if config.fix_com_microsoft_opset: peephole_optimizer.ensure_com_microsoft_opset() - # Optional: eliminate redundant round-trip Cast chains if config.cast_chain_elimination: peephole_optimizer.eliminate_cast_chains() From f50743d7d0a14997fe76371f921a58eeb5859aba Mon Sep 17 00:00:00 2001 From: Sunghoon Choi Date: Mon, 16 Mar 2026 17:42:53 -0700 Subject: [PATCH 23/23] Fix lint: remove duplicate numpy import in test (W0621/W0404) --- test/passes/onnx/test_peephole_optimizer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/passes/onnx/test_peephole_optimizer.py b/test/passes/onnx/test_peephole_optimizer.py index 8bef54305b..a124cd5e42 100644 --- a/test/passes/onnx/test_peephole_optimizer.py +++ b/test/passes/onnx/test_peephole_optimizer.py @@ -42,8 +42,6 @@ def test_onnx_peephole_optimizer_pass(tmp_path): # TODO(team): this test will creat an unnecessary intermediate model file. Need to optimize it. def test_onnx_peephole_optimizer_pass_fuse_reshape_operations(tmp_path, external_data_config): - import numpy as np - X = helper.make_tensor_value_info("X", TensorProto.INT64, [None]) # noqa: N806 Y = helper.make_tensor_value_info("Y", TensorProto.INT64, [None]) # noqa: N806