class GemmToMatMulAdd(ProtoSurgeon):
    """Replace Gemm with MatMul (+ Add) for INT4 quantization compatibility.

    The INT4 RTN quantizer only recognizes MatMul nodes. This surgeon converts
    Gemm nodes back to MatMul+Add so that the weight matrices become eligible
    for block-wise quantization.

    Handles transB by creating a transposed copy of constant weights, or by
    inserting a Transpose node for non-constant weights. Skips Gemm nodes whose
    alpha/beta are not 1.0 or whose transA is set.
    """

    def __call__(self, model: ModelProto):
        from onnx import helper, numpy_helper

        graph = model.graph
        initializer_map = {init.name: init for init in graph.initializer}
        # Names already taken anywhere in the graph; used to mint fresh names
        # for transposed weight copies.
        existing_names = (
            {init.name for init in graph.initializer}
            | {vi.name for vi in graph.input}
            | {vi.name for vi in graph.output}
            | {vi.name for vi in graph.value_info}
        )
        # Index of the Gemm node in graph.node -> its replacement nodes.
        # Replacements are spliced in at the original position instead of being
        # appended at the end of graph.node: appending would break the
        # topological node ordering required by the ONNX spec whenever the
        # Gemm output feeds a node that appears before the end of the graph.
        replacements: dict[int, list] = {}
        gemm_rewrite_idx = 0

        for node_idx, node in enumerate(graph.node):
            if node.op_type != "Gemm":
                continue

            # Gemm attribute defaults per the ONNX spec.
            alpha = beta = 1.0
            trans_a = trans_b = 0
            for attr in node.attribute:
                if attr.name == "alpha":
                    alpha = attr.f
                elif attr.name == "beta":
                    beta = attr.f
                elif attr.name == "transA":
                    trans_a = attr.i
                elif attr.name == "transB":
                    trans_b = attr.i

            # MatMul+Add can only express Y = A @ B + C; skip anything else.
            if alpha != 1.0 or beta != 1.0 or trans_a != 0:
                continue

            inp_a, inp_b = node.input[0], node.input[1]
            inp_c = node.input[2] if len(node.input) > 2 else None
            out_y = node.output[0]

            # Derive a stable base name for new nodes/tensors.
            base_name = node.name or out_y or f"gemm_rewrite_{gemm_rewrite_idx}"

            new_nodes = []
            if trans_b:
                if inp_b in initializer_map:
                    # Create a new transposed initializer to avoid mutating
                    # a potentially shared initializer in-place.
                    init = initializer_map[inp_b]
                    w_t = numpy_helper.to_array(init).T.copy()
                    new_name = f"{inp_b}_transposed"
                    suffix = 0
                    while new_name in existing_names:
                        suffix += 1
                        new_name = f"{inp_b}_transposed_{suffix}"
                    new_init = numpy_helper.from_array(w_t, name=new_name)
                    graph.initializer.append(new_init)
                    initializer_map[new_name] = new_init
                    existing_names.add(new_name)
                    matmul_rhs = new_name
                else:
                    transpose_out = f"{base_name}_transpose_B"
                    new_nodes.append(
                        helper.make_node(
                            "Transpose", [inp_b], [transpose_out], name=f"{base_name}_Transpose", perm=[1, 0]
                        )
                    )
                    matmul_rhs = transpose_out
            else:
                matmul_rhs = inp_b

            if inp_c:
                matmul_out = f"{base_name}_matmul_out"
                new_nodes.append(
                    helper.make_node("MatMul", [inp_a, matmul_rhs], [matmul_out], name=f"{base_name}_MatMul")
                )
                new_nodes.append(helper.make_node("Add", [matmul_out, inp_c], [out_y], name=f"{base_name}_Add"))
            else:
                new_nodes.append(
                    helper.make_node("MatMul", [inp_a, matmul_rhs], [out_y], name=f"{base_name}_MatMul")
                )

            replacements[node_idx] = new_nodes
            gemm_rewrite_idx += 1

        # Splice in place, iterating indices in reverse so earlier indices stay
        # valid while later ones are rewritten.
        for idx in sorted(replacements, reverse=True):
            del graph.node[idx]
            for offset, new_node in enumerate(replacements[idx]):
                graph.node.insert(idx + offset, new_node)

        if replacements:
            logger.debug("Replaced %d Gemm nodes with MatMul + Add nodes", len(replacements))

        return model
class ReciprocalMulToDiv(ProtoSurgeon):
    """Replace Reciprocal(x) * a with Div(a, x).

    Before:
        [x] --> Reciprocal --> Mul --> [out]
                                ^
                                |
                               [a]

    After:
        [a] --> Div --> [out]
                 ^
                 |
                [x]

    Why this is needed:
        PyTorch's ``torch.rsqrt()`` (used by Qwen2.5-VL's ``Qwen2RMSNorm``) decomposes to
        ``Sqrt -> Reciprocal -> Mul`` in ONNX. ORT's ``SimplifiedLayerNormFusion`` only
        matches the pattern ``Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul`` — it does
        **not** recognize the ``Reciprocal -> Mul`` variant (confirmed on ORT main as of
        2025-06). This pass canonicalizes the graph so that the fusion fires, replacing
        decomposed RMSNorm with a single ``SimplifiedLayerNormalization`` op.

    When to use:
        Run **before** ``OrtTransformersOptimization`` on models whose normalization layers
        export ``rsqrt`` as ``Reciprocal`` (e.g. HuggingFace models using ``torch.rsqrt``).
    """

    def __call__(self, model: ModelProto):
        from collections import defaultdict

        modified = 0
        nodes_to_remove = []
        graph_outputs = {out.name for out in model.graph.output}

        # Build a map from tensor name to consuming nodes (avoids O(N^2) scans).
        consumers: dict[str, list] = defaultdict(list)
        for n in model.graph.node:
            for input_name in n.input:
                if input_name:
                    consumers[input_name].append(n)

        for node in model.graph.node:
            if node.op_type != "Reciprocal":
                continue

            recip_input = node.input[0]  # x
            recip_output = node.output[0]

            # Find Mul consumers of this Reciprocal using the precomputed consumer map
            mul_nodes = [n for n in consumers.get(recip_output, []) if n.op_type == "Mul"]

            for mul_node in mul_nodes:
                # Identify the other operand (not from Reciprocal)
                if mul_node.input[0] == recip_output:
                    other_input = mul_node.input[1]
                else:
                    other_input = mul_node.input[0]

                # Convert Mul(a, Reciprocal(x)) to Div(a, x) in-place
                mul_node.op_type = "Div"
                mul_node.input[0] = other_input
                mul_node.input[1] = recip_input
                if mul_node.name:
                    mul_node.name = self.create_new_name(mul_node.name, "Mul", "Div")
                modified += 1

            # Never remove a Reciprocal whose output is a graph output —
            # the output would be left without a producer.
            if recip_output in graph_outputs:
                continue

            # If no more consumers of Reciprocal output, mark for removal.
            # Note: consumer map may be stale after in-place input rewrites,
            # so re-check actual inputs of remaining consumers.
            remaining = [n for n in consumers.get(recip_output, []) if n is not node and recip_output in list(n.input)]
            if not remaining:
                nodes_to_remove.append(node)

        for node in nodes_to_remove:
            model.graph.node.remove(node)

        if modified > 0:
            logger.debug("Replaced %d Reciprocal+Mul patterns with Div", modified)

        return model


class DeduplicateSubgraphInitializers(ProtoSurgeon):
    """Remove duplicate initializers in Loop / If / Scan subgraphs.

    Why this is needed:
        ORT's graph optimizer (constant folding, shape inference, etc.) may copy
        initializers into subgraphs that already contain them, creating entries with
        identical names. ORT's ``ConstantSharing`` pass explicitly skips subgraph
        usage (``constant_sharing.cc``: "If usage is from subgraph, skip it now"),
        so these duplicates are never cleaned up. Duplicate initializers violate
        the ONNX spec's unique-name requirement and can cause validation failures
        or silent data corruption.

    What it does:
        For every ``Loop`` / ``If`` / ``Scan`` subgraph, keeps the first initializer
        with a given name and removes all subsequent duplicates.

    When to use:
        Run **after** ``OrtTransformersOptimization`` (which introduces the duplicates)
        and **before** any pass that serializes or validates the model.
    """

    def __call__(self, model: ModelProto):
        removed = 0
        # NOTE(review): only the singular graph attribute (`attr.g`) is scanned,
        # which covers Loop/Scan bodies and If branches; list-of-graph
        # attributes (`attr.graphs`) are not traversed — confirm no target op
        # needs them.
        for node in model.graph.node:
            for attr in node.attribute:
                if attr.g and attr.g.initializer:
                    seen = set()
                    to_remove = []
                    for init in attr.g.initializer:
                        if init.name in seen:
                            to_remove.append(init)
                        else:
                            seen.add(init.name)
                    for init in to_remove:
                        attr.g.initializer.remove(init)
                        removed += 1
        if removed > 0:
            logger.debug("Removed %d duplicate subgraph initializers", removed)
        return model
class DeduplicateNodes(ProtoSurgeon):
    """Remove nodes whose output tensors are already produced by an earlier node.

    Before (invalid — two nodes define the same tensor ``/Cast_output_0``):
        NodeA --> Cast --> /Cast_output_0
        NodeB --> Cast --> /Cast_output_0   (duplicate, removed)

    After:
        NodeA --> Cast --> /Cast_output_0

    Why this is needed:
        ORT's ``convert_float_to_float16`` (``float16.py``) may insert identical
        ``Cast`` nodes in parallel branches that each declare the same output tensor
        name. The ONNX spec requires every tensor to have a unique producer; loading
        a model with duplicate producers causes ``onnxruntime.InferenceSession`` to
        fail with a duplicate-definition error.

    What it does:
        Scans nodes in graph order and records each output tensor name. If a later
        node produces a tensor name that was already seen, the entire node is removed.

    When to use:
        Run **after** ``OnnxFloatToFloat16`` as a cleanup step.
    """

    def __call__(self, model: ModelProto):
        output_seen: set[str] = set()
        indices_to_remove: list[int] = []
        for i, node in enumerate(model.graph.node):
            # A node is a duplicate if any of its outputs already has a producer.
            if any(o and o in output_seen for o in node.output):
                indices_to_remove.append(i)
                # Do not record this node's outputs: the node is being removed,
                # so a later producer of one of its (previously unseen) output
                # names must not be misclassified as a duplicate.
                continue
            for o in node.output:
                if o:
                    output_seen.add(o)
        # Delete from the back so earlier indices remain valid.
        for i in reversed(indices_to_remove):
            del model.graph.node[i]
        if indices_to_remove:
            logger.debug("Removed %d duplicate nodes", len(indices_to_remove))
        return model
def ensure_com_microsoft_opset(self):
    """Ensure com.microsoft opset v1 is declared at model and function level.

    Olive ``GraphSurgeries`` may insert ``com.microsoft`` operators (such as
    ``LoopMHA``) without registering the custom opset on every ONNX function
    scope. This method fixes the declarations so that downstream passes and
    validators do not fail.
    """
    # The model itself and every function each carry an independent
    # opset_import list; all of them must declare the custom domain.
    for scope in (self.model, *self.model.functions):
        if all(entry.domain != "com.microsoft" for entry in scope.opset_import):
            scope.opset_import.append(helper.make_opsetid("com.microsoft", 1))

def eliminate_cast_chains(self):
    """Eliminate redundant round-trip Cast chains (e.g. fp32→fp16→fp32).

    Dynamo-exported ONNX models often contain unnecessary cast round-trips.
    This method applies a targeted onnxscript rewrite rule to collapse them
    into Identity nodes.
    """
    self.model = rewrite(self.model, pattern_rewrite_rules=self._get_cast_chain_rewrite_rules())

@staticmethod
def _get_cast_chain_rewrite_rules():
    """Build onnxscript rewrite rules for eliminating redundant Cast chains.

    Matches ``Cast(Cast(x, to=T1), to=T2)`` where ``T2`` equals the dtype of
    ``x`` (a round trip) and rewrites the chain to ``Identity(x)``.
    """

    def _pattern(op, x, to, to_ignored):
        # Inner cast's target dtype is irrelevant; only the outer one must
        # restore the input dtype.
        return op.Cast(op.Cast(x, to=to_ignored), to=to)

    def _check(context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> MatchResult:
        verdict = MatchResult()
        if x.dtype is None:
            return verdict.fail("Input dtype unknown; cannot verify round-trip")
        if x.dtype != to.as_int():
            return verdict.fail(f"Not a round-trip cast: input dtype {x.dtype} != final cast to={to.as_int()}")
        return verdict

    def _replacement(op, x, **_):
        return op.Identity(x)

    return [RewriteRule(_pattern, _replacement, _check, name="CastCastRoundTrip")]
+ """ @classmethod def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: - return get_external_data_config() + return { + "onnxscript_optimize": PassConfigParam( + type_=bool, + default_value=True, + description="Run onnxscript optimizer for general graph optimizations.", + ), + "onnxoptimizer_optimize": PassConfigParam( + type_=bool, + default_value=True, + description="Run onnxoptimizer for additional graph optimizations.", + ), + "fuse_reshape_operations": PassConfigParam( + type_=bool, + default_value=True, + description="Fuse consecutive Reshape operators where the latter flattens to [-1].", + ), + "fix_com_microsoft_opset": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Ensure com.microsoft opset v1 is declared on the model and all function scopes. " + "Enable this when GraphSurgeries inserts custom ops (e.g. LoopMHA) into function scopes." + ), + ), + "cast_chain_elimination": PassConfigParam( + type_=bool, + default_value=False, + description=( + "Apply a targeted rewrite rule to eliminate redundant round-trip Cast chains " + "(e.g. fp32→fp16→fp32 → identity) produced by dynamo export." 
+ ), + ), + **get_external_data_config(), + } def _run_for_config( self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str ) -> ONNXModelHandler: output_model_path = resolve_onnx_path(output_model_path, Path(model.model_path).name) - # optimize model peephole_optimizer = ModelOptimizer(model.model_path) - peephole_optimizer.onnxscript_optimize() - peephole_optimizer.onnxoptimizer_optimize() - peephole_optimizer.fuse_reshape_operations() + + if config.onnxscript_optimize: + peephole_optimizer.onnxscript_optimize() + + if config.onnxoptimizer_optimize: + peephole_optimizer.onnxoptimizer_optimize() + + if config.fuse_reshape_operations: + peephole_optimizer.fuse_reshape_operations() + + if config.fix_com_microsoft_opset: + peephole_optimizer.ensure_com_microsoft_opset() + + if config.cast_chain_elimination: + peephole_optimizer.eliminate_cast_chains() # save the model to the output path and return the model return model_proto_to_olive_model(peephole_optimizer.model, output_model_path, config) diff --git a/olive/passes/onnx/rtn_quantization.py b/olive/passes/onnx/rtn_quantization.py index d42bbcc22..d4742aeeb 100644 --- a/olive/passes/onnx/rtn_quantization.py +++ b/olive/passes/onnx/rtn_quantization.py @@ -104,6 +104,10 @@ def _quantize_model( nodes_to_exclude = nodes_to_exclude or [] nodes_to_include = nodes_to_include or [] + # Track initializer names already registered across all nodes + # to handle shared weights (e.g., pos_embed used by multiple Gather nodes). 
+ globally_registered = {} + ir_model.graph.sort() for node in ir_model.graph.all_nodes(): node_name = node.name @@ -115,14 +119,18 @@ def _quantize_model( elif node.op_type in (str(OpType.MatMul), str(OpType.Gather)) and ( node_name in nodes_to_include or not nodes_to_include ): - if (node.op_type == OpType.MatMul and not node.inputs[1].is_initializer()) or ( - node.op_type == OpType.Gather and not node.inputs[0].is_initializer() - ): + # MatMul weight is inputs[1], Gather weight (embedding table) is inputs[0] + weight_idx = 1 if node.op_type == str(OpType.MatMul) else 0 + if not node.inputs[weight_idx].is_initializer(): logger.debug("skip to quantize %s as it has no initializer", node_name) continue - if node.op_type == OpType.Gather and bits != 4: - logger.warning("Gather only supports 4-bit quantization.") + if node.op_type == str(OpType.Gather) and bits not in (4, 8): + logger.warning( + "Gather quantization is only implemented for 4-bit and 8-bit. Skip node %s (bits=%d).", + node_name, + bits, + ) continue quantized_node, initializer_graph = self._quantize( @@ -133,9 +141,14 @@ def _quantize_model( registered = {} for input_value in quantized_node.inputs: if input_value.const_value is not None: - if input_value.name not in registered: + if input_value.name in globally_registered: + # Already registered by a previous node (shared weight), + # replace with the existing initializer. + ir.convenience.replace_all_uses_with(input_value, globally_registered[input_value.name]) + elif input_value.name not in registered: initializer_graph.register_initializer(input_value) registered[input_value.name] = input_value + globally_registered[input_value.name] = input_value else: logger.debug( "Found duplicated initializer %s, replace all uses with the first one.", @@ -149,6 +162,25 @@ def _quantize_model( else: logger.debug("skip to quantize %s ...", node_name) + # Remove initializers that are no longer referenced by any node. 
+ # After quantization, the original FP32 weight initializers become orphaned + # because the old MatMul/Gather nodes were replaced with MatMulNBits/GatherBlockQuantized + # nodes that reference new INT4 weight initializers instead. + used_names: set[str] = set() + for node in ir_model.graph.all_nodes(): + for inp in node.inputs: + if inp is not None and inp.name: + used_names.add(inp.name) + for out in ir_model.graph.outputs: + if out is not None and out.name: + used_names.add(out.name) + + unused = [name for name in ir_model.graph.initializers if name not in used_names] + for name in unused: + del ir_model.graph.initializers[name] + if unused: + logger.info("Removed %d unused initializers after quantization.", len(unused)) + def _quantize( self, node: ir.Node, bits: int, block_size: int, axis: int, accuracy_level: AccuracyLevel, is_symmetric: bool ) -> tuple[ir.Node, ir.Graph]: @@ -221,7 +253,10 @@ def _quantize_gather( node_initializer = node.inputs[0] data_ndarray = node_initializer.const_value.numpy() data_rank = len(data_ndarray.shape) - quantize_axis = axis + + # ORT GatherBlockQuantized requires quantize_axis == last dimension + # when data is packed as uint8 (two int4 values per byte). + quantize_axis = data_rank - 1 assert -data_rank <= quantize_axis < data_rank, "Invalid quantize axis for Gather node." assert block_size >= 16, "Block size must be greater than or equal to 16." 
@@ -229,7 +264,7 @@ def _quantize_gather( quantize_axis = (quantize_axis + data_rank) % data_rank quantized_data, scales, zero_point = self._quantize_ndarray( - data_ndarray, quantize_axis, block_size, is_symmetric + data_ndarray, quantize_axis, block_size, is_symmetric, bits ) quantized_data_tensorproto = ir.Value( @@ -250,6 +285,7 @@ def _quantize_gather( "gather_axis": gather_axis, "quantize_axis": quantize_axis, "block_size": block_size, + "bits": bits, } node.outputs[0].name = node.outputs[0].name + f"_Q{bits}" @@ -306,36 +342,46 @@ def _qbits_block_quant( return (packed, scales, zero_point) @staticmethod - def _quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def _quant_slice_symmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray]: + qmin = -(1 << (bits - 1)) # -8 for 4-bit, -128 for 8-bit + qmax = (1 << (bits - 1)) - 1 # 7 for 4-bit, 127 for 8-bit max_val = np.max(data, axis=1, keepdims=True) min_val = np.min(data, axis=1, keepdims=True) abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val) - scale = abs_max / -8.0 # if max == min, max may be clipped - quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8) + scale = abs_max / float(qmin) # if max == min, max may be clipped + quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(qmin, qmax).astype(np.int8) return quantized_slice, scale @staticmethod - def _quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + def _quant_slice_asymmetric(data: np.ndarray, bits: int = 4) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + qmax = (1 << bits) - 1 # 15 for 4-bit, 255 for 8-bit + mid = 1 << (bits - 1) # 8 for 4-bit, 128 for 8-bit min_val = np.minimum(data.min(axis=1, keepdims=True), 0) max_val = np.maximum(data.max(axis=1, keepdims=True), 0) - scale = (max_val - min_val) / 15.0 - zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8) 
- quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8) + scale = (max_val - min_val) / float(qmax) + zero_point = np.where(scale == 0, mid, -min_val / scale).round().clip(0, qmax).astype(np.uint8) + quantized_slice = np.where(scale == 0, mid, data / scale + zero_point).round().clip(0, qmax).astype(np.uint8) return quantized_slice, scale, zero_point @staticmethod - def _pack_int8_to_int4(data: np.ndarray) -> np.ndarray: - """Pack int8 data to int4 and store in uint8 ndarray.""" - data_flat = data.reshape(-1) - if len(data_flat) % 2 != 0: - data_flat = np.append(data_flat, 0) - quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4) + def _pack_int4_along_axis(data: np.ndarray, axis: int = 1) -> np.ndarray: + """Pack pairs of int4 values into uint8 along the specified axis. - return quant_data_int4.astype("uint8") + Unlike a flat pack, this correctly handles cases where the packing dimension is small + (e.g., zero_point with k_blocks=1) by only pairing values within the same axis slice. 
+ """ + k = data.shape[axis] + if k % 2 != 0: + pad_width = [(0, 0)] * len(data.shape) + pad_width[axis] = (0, 1) + data = np.pad(data, pad_width) + low = np.take(data, range(0, data.shape[axis], 2), axis=axis) + high = np.take(data, range(1, data.shape[axis], 2), axis=axis) + return ((low & 0xF) | ((high & 0xF) << 4)).astype("uint8") def _quantize_ndarray( self, @@ -343,8 +389,9 @@ def _quantize_ndarray( quantize_axis: int, block_size: int, is_symmetric: bool, + bits: int = 4, ) -> Optional[tuple[np.ndarray, np.ndarray, np.ndarray]]: - """Quantize ndarray data to int4 using numpy, return (quantized data, scales).""" + """Quantize ndarray data to int4/int8 using numpy, return (quantized data, scales).""" # Get the shape of the matrix m = 1 # dimension of the matrix before the quantize axis k = data.shape[quantize_axis] # dimension of the matrix along the quantize axis @@ -374,9 +421,9 @@ def _quantize_ndarray( block_slice = data_reshape[:, i:end_idx, :] zero_point_slice = None if is_symmetric: - quantized_slice_int8, scale_slice = self._quant_slice_symmetric(block_slice) + quantized_slice_int8, scale_slice = self._quant_slice_symmetric(block_slice, bits) else: - quantized_slice_int8, scale_slice, zero_point_slice = self._quant_slice_asymmetric(block_slice) + quantized_slice_int8, scale_slice, zero_point_slice = self._quant_slice_asymmetric(block_slice, bits) quant_data_int8[:, i:end_idx, :] = quantized_slice_int8 j = i // block_size @@ -384,10 +431,44 @@ def _quantize_ndarray( if not is_symmetric: zero_point_int8[:, j : (j + 1), :] = zero_point_slice - # pack int8 to int4 - quant_data_int4 = self._pack_int8_to_int4(quant_data_int8) - zero_point_int4 = None - if not is_symmetric: - zero_point_int4 = self._pack_int8_to_int4(zero_point_int8) scales = scales.reshape(scales_shape) - return quant_data_int4, scales, zero_point_int4 + + if bits <= 4: + # pack int8 to int4 + # ORT GatherBlockQuantized uses unsigned int4 representation [0, 15] + # where zero_point=8 is 
implied for symmetric quantization. + # Convert signed int8 [-8, 7] to unsigned [0, 15] by adding 8. + if is_symmetric: + quant_data_int8 = (quant_data_int8.astype(np.int16) + 8).astype(np.uint8) + + # Pack along axis=1 (the quantize_axis in the 3D view: m, k, n). + # This ensures packing pairs values within the same row, not across rows. + quant_data_int4 = self._pack_int4_along_axis(quant_data_int8, axis=1) + zero_point_int4 = None + if not is_symmetric: + zero_point_int4 = self._pack_int4_along_axis(zero_point_int8, axis=1) + + # Reshape packed data to match original rank (GatherBlockQuantized requires rank > 1). + packed_shape = list(data.shape) + packed_shape[quantize_axis] = (packed_shape[quantize_axis] + 1) // 2 + quant_data_int4 = quant_data_int4.reshape(packed_shape) + if zero_point_int4 is not None: + zp_shape = list(scales_shape) + zp_shape[quantize_axis] = (zp_shape[quantize_axis] + 1) // 2 + zero_point_int4 = zero_point_int4.reshape(zp_shape) + + return quant_data_int4, scales, zero_point_int4 + else: + # 8-bit: no packing needed, one value per byte. + # Convert signed int8 [-128, 127] to unsigned uint8 [0, 255] by adding 128. 
+ if is_symmetric: + quant_data_uint8 = (quant_data_int8.astype(np.int16) + 128).astype(np.uint8) + else: + quant_data_uint8 = quant_data_int8 # already uint8 + + quant_data_uint8 = quant_data_uint8.reshape(data.shape) + zero_point_out = None + if not is_symmetric: + zero_point_out = zero_point_int8.reshape(scales_shape) + + return quant_data_uint8, scales, zero_point_out diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index b1202c45b..a32a06004 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -2295,3 +2295,269 @@ def test_packed_attention_to_packed_mha(tmp_path): op_types = [node.op_type for node in output_model_def.graph.node] assert OpType.PackedAttention not in op_types assert OpType.PackedMultiHeadAttention in op_types + + +# ── GemmToMatMulAdd ────────────────────────────────────────────────────── + + +def test_gemm_to_matmul_add(tmp_path): + """Gemm(A, B, C) with transB=1 → MatMul(A, B^T) + Add(C).""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + b_data = np.random.randn(4, 3).astype(np.float32) # transB shape + c_data = np.random.randn(4).astype(np.float32) + + initializers = [ + numpy_helper.from_array(b_data, name="B"), + numpy_helper.from_array(c_data, name="C"), + ] + + gemm_node = helper.make_node("Gemm", ["A", "B", "C"], ["Y"], name="Gemm0", transB=1, alpha=1.0, beta=1.0) + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 10 + onnx.checker.check_model(model) + + model_path = tmp_path / "gemm.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": 
"GemmToMatMulAdd"}]}, disable_search=True) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" not in op_types + assert "MatMul" in op_types + assert "Add" in op_types + + # Numerical check + a = np.random.randn(2, 3).astype(np.float32) + expected = a @ b_data.T + c_data + + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + actual = sess.run(None, {"A": a})[0] + np.testing.assert_allclose(actual, expected, atol=1e-5) + + +def test_gemm_to_matmul_add_no_bias(tmp_path): + """Gemm(A, B) without bias → MatMul only.""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + b_data = np.eye(3, dtype=np.float32) + initializers = [numpy_helper.from_array(b_data, name="B")] + + gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], name="Gemm0") + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 10 + onnx.checker.check_model(model) + + model_path = tmp_path / "gemm_no_bias.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" not in op_types + assert "MatMul" in op_types + assert "Add" not in op_types + + +def test_gemm_to_matmul_add_skips_non_unit_alpha(tmp_path): + """Gemm with alpha != 1.0 should be left unchanged.""" + input_tensor = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 3]) + output_tensor = 
helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 3]) + + b_data = np.eye(3, dtype=np.float32) + initializers = [numpy_helper.from_array(b_data, name="B")] + + gemm_node = helper.make_node("Gemm", ["A", "B"], ["Y"], name="Gemm0", alpha=2.0) + + graph = helper.make_graph([gemm_node], "test", [input_tensor], [output_tensor], initializer=initializers) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 10 + + model_path = tmp_path / "gemm_alpha.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "GemmToMatMulAdd"}]}, disable_search=True) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Gemm" in op_types, "Gemm with alpha != 1.0 should be preserved" + + +# ── ReciprocalMulToDiv ─────────────────────────────────────────────────── + + +def test_reciprocal_mul_to_div(tmp_path): + """Reciprocal(x) * a → Div(a, x).""" + input_x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + nodes = [ + helper.make_node("Reciprocal", ["X"], ["recip_out"], name="Recip"), + helper.make_node("Mul", ["A", "recip_out"], ["Y"], name="MulRecip"), + ] + + graph = helper.make_graph(nodes, "test", [input_x, input_a], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 10 + onnx.checker.check_model(model) + + model_path = tmp_path / "recip_mul.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True) + 
output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Reciprocal" not in op_types + assert "Mul" not in op_types + assert "Div" in op_types + + # Numerical check: Div(a, x) == a * (1/x) + x = np.random.randn(2, 4).astype(np.float32) + 2.0 # avoid near-zero + a = np.random.randn(2, 4).astype(np.float32) + expected = a / x + + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + actual = sess.run(None, {"X": x, "A": a})[0] + np.testing.assert_allclose(actual, expected, atol=1e-5) + + +def test_reciprocal_mul_to_div_reversed_order(tmp_path): + """Mul(recip_out, a) — Reciprocal output on the left side.""" + input_x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [4]) + input_a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [4]) + + nodes = [ + helper.make_node("Reciprocal", ["X"], ["recip_out"], name="Recip"), + helper.make_node("Mul", ["recip_out", "A"], ["Y"], name="MulRecip"), + ] + + graph = helper.make_graph(nodes, "test", [input_x, input_a], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 10 + + model_path = tmp_path / "recip_mul_rev.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "ReciprocalMulToDiv"}]}, disable_search=True) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + op_types = [n.op_type for n in out_proto.graph.node] + assert "Div" in op_types + + # Check Div inputs: Div(A, X) regardless of Mul operand order + div_node = next(n for n in out_proto.graph.node if n.op_type == "Div") + assert div_node.input[0] == "A" + assert div_node.input[1] == "X" + + +# ── 
DeduplicateSubgraphInitializers ────────────────────────────────────── + + +def test_deduplicate_subgraph_initializers(tmp_path): + """Duplicate initializers inside a Loop subgraph should be removed.""" + # Build a minimal Loop body with duplicate initializers + body_input = helper.make_tensor_value_info("i", TensorProto.INT64, []) + body_cond_in = helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []) + body_cond_out = helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []) + body_out = helper.make_tensor_value_info("body_out", TensorProto.FLOAT, [2]) + + const_data = np.array([1.0, 2.0], dtype=np.float32) + init1 = numpy_helper.from_array(const_data, name="dup_init") + init2 = numpy_helper.from_array(const_data, name="dup_init") # duplicate + + body_node = helper.make_node("Identity", ["dup_init"], ["body_out"]) + cond_true = numpy_helper.from_array(np.array(True), name="cond_true") + cond_node = helper.make_node("Identity", ["cond_true"], ["cond_out"]) + + body_graph = helper.make_graph( + [cond_node, body_node], + "body", + [body_input, body_cond_in], + [body_cond_out, body_out], + initializer=[init1, init2, cond_true], + ) + + # Outer graph with Loop node + trip_count = helper.make_tensor_value_info("trip_count", TensorProto.INT64, []) + cond = helper.make_tensor_value_info("cond", TensorProto.BOOL, []) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [None, 2]) + + loop_node = helper.make_node("Loop", ["trip_count", "cond"], ["output"], body=body_graph) + + main_graph = helper.make_graph([loop_node], "main", [trip_count, cond], [output]) + model = helper.make_model(main_graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 10 + + model_path = tmp_path / "loop_dup.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "DeduplicateSubgraphInitializers"}]}, + disable_search=True, + ) 
+ output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + # Check the Loop body: should have only 1 copy of "dup_init" + loop = next(n for n in out_proto.graph.node if n.op_type == "Loop") + body = loop.attribute[0].g + dup_inits = [i for i in body.initializer if i.name == "dup_init"] + assert len(dup_inits) == 1, f"Expected 1, got {len(dup_inits)} copies of dup_init" + + +# ── DeduplicateNodes ───────────────────────────────────────────────────── + + +def test_deduplicate_nodes(tmp_path): + """Nodes that produce the same output tensor name should be deduplicated.""" + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT16, [2, 4]) + + # Two Cast nodes with the same output tensor name (simulating float16.py bug) + cast1 = helper.make_node("Cast", ["X"], ["cast_out"], name="Cast1", to=TensorProto.FLOAT16) + cast2 = helper.make_node("Cast", ["X"], ["cast_out"], name="Cast2", to=TensorProto.FLOAT16) + identity = helper.make_node("Identity", ["cast_out"], ["Y"], name="Id") + + graph = helper.make_graph([cast1, cast2, identity], "test", [input_tensor], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 10 + + model_path = tmp_path / "dup_nodes.onnx" + onnx.save(model, str(model_path)) + input_model = ONNXModelHandler(model_path=str(model_path)) + + p = create_pass_from_dict(GraphSurgeries, {"surgeries": [{"surgeon": "DeduplicateNodes"}]}, disable_search=True) + output_model = p.run(input_model, str(tmp_path / "out")) + out_proto = output_model.load_model() + + cast_nodes = [n for n in out_proto.graph.node if n.op_type == "Cast"] + assert len(cast_nodes) == 1, f"Expected 1 Cast, got {len(cast_nodes)}" + + # The remaining model should still be runnable + sess = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + x = np.random.randn(2, 
4).astype(np.float32) + result = sess.run(None, {"X": x})[0] + np.testing.assert_allclose(result, x.astype(np.float16), atol=1e-3) diff --git a/test/passes/onnx/test_peephole_optimizer.py b/test/passes/onnx/test_peephole_optimizer.py index 745289a45..7d54b6909 100644 --- a/test/passes/onnx/test_peephole_optimizer.py +++ b/test/passes/onnx/test_peephole_optimizer.py @@ -5,13 +5,17 @@ from pathlib import Path from unittest.mock import patch +import numpy as np +import onnx +import onnxruntime as ort import pytest from onnx import TensorProto, helper from olive.hardware import DEFAULT_CPU_ACCELERATOR +from olive.model import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict from olive.passes.onnx.common import model_proto_to_olive_model -from olive.passes.onnx.peephole_optimizer import OnnxPeepholeOptimizer +from olive.passes.onnx.peephole_optimizer import ModelOptimizer, OnnxPeepholeOptimizer from test.utils import get_onnx_model @@ -38,8 +42,6 @@ def test_onnx_peephole_optimizer_pass(tmp_path): # TODO(team): this test will creat an unnecessary intermediate model file. Need to optimize it. 
def test_onnx_peephole_optimizer_pass_fuse_reshape_operations(tmp_path, external_data_config): - import numpy as np - X = helper.make_tensor_value_info("X", TensorProto.INT64, [None]) # noqa: N806 Y = helper.make_tensor_value_info("Y", TensorProto.INT64, [None]) # noqa: N806 @@ -138,3 +140,142 @@ def test_onnxoptimizer(mock_onnxscript, mock_onnxoptimizer, mock_model_proto_to_ # assert mock_onnxoptimizer.assert_called_once() + + +# ── _ensure_com_microsoft_opset ──────────────────────────────────────────── + + +class TestEnsureComMicrosoftOpset: + """Unit tests for ModelOptimizer.ensure_com_microsoft_opset.""" + + def _make_optimizer_with_model(self, model, tmp_path): + """Save a model to disk and create a ModelOptimizer around it.""" + path = tmp_path / "model.onnx" + onnx.save(model, str(path)) + opt = ModelOptimizer(str(path)) + opt.model = model # use the in-memory model directly + return opt + + def test_adds_opset_when_missing(self, tmp_path): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + assert not any(op.domain == "com.microsoft" for op in model.opset_import) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + assert any(op.domain == "com.microsoft" for op in opt.model.opset_import) + + def test_does_not_duplicate_opset(self, tmp_path): + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)], + ) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + ms_opsets = [op for op in opt.model.opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + def test_adds_opset_to_functions(self, tmp_path): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + + model = helper.make_model( + 
helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + func_domains = {op.domain for op in opt.model.functions[0].opset_import} + assert "com.microsoft" in func_domains + + def test_skips_function_with_existing_opset(self, tmp_path): + func = onnx.FunctionProto() + func.name = "test_func" + func.domain = "test.domain" + func.opset_import.append(helper.make_opsetid("", 17)) + func.opset_import.append(helper.make_opsetid("com.microsoft", 1)) + + model = helper.make_model( + helper.make_graph([], "g", [], []), + opset_imports=[helper.make_opsetid("", 17)], + ) + model.functions.append(func) + + opt = self._make_optimizer_with_model(model, tmp_path) + opt.ensure_com_microsoft_opset() + + ms_opsets = [op for op in opt.model.functions[0].opset_import if op.domain == "com.microsoft"] + assert len(ms_opsets) == 1 + + +# ── Cast chain elimination via OnnxPeepholeOptimizer ─────────────────────── + + +class TestCastChainElimination: + """Tests for cast chain elimination integrated into OnnxPeepholeOptimizer.""" + + @pytest.fixture + def cast_chain_model_path(self, tmp_path): + """Model with a redundant Cast chain: fp32 → fp16 → fp32.""" + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 4]) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + cast_to_fp16 = helper.make_node("Cast", ["X"], ["x_fp16"], to=TensorProto.FLOAT16) + cast_back = helper.make_node("Cast", ["x_fp16"], ["Y"], to=TensorProto.FLOAT) + + graph = helper.make_graph([cast_to_fp16, cast_back], "cast_chain", [input_tensor], [output_tensor]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)]) + model.ir_version = 10 # Compatible with ORT version in CI + onnx.checker.check_model(model) + + path = tmp_path / "cast_chain.onnx" + onnx.save(model, str(path)) + return path + + 
def test_eliminates_redundant_cast_chain(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxPeepholeOptimizer, {"cast_chain_elimination": True}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + # The rewrite rule collapses the round-trip fp32→fp16→fp32 chain + # into a single Identity node. + assert len(result_model.graph.node) == 1 + assert result_model.graph.node[0].op_type == "Identity" + + def test_opset_fixup_applied(self, cast_chain_model_path, tmp_path): + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict( + OnnxPeepholeOptimizer, + {"fix_com_microsoft_opset": True, "cast_chain_elimination": False}, + disable_search=True, + ) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + result_model = onnx.load(output.model_path) + assert any(op.domain == "com.microsoft" for op in result_model.opset_import) + + def test_numerical_correctness(self, cast_chain_model_path, tmp_path): + """Optimized model should produce the same output as the original.""" + olive_model = ONNXModelHandler(model_path=str(cast_chain_model_path)) + p = create_pass_from_dict(OnnxPeepholeOptimizer, {"cast_chain_elimination": True}, disable_search=True) + output = p.run(olive_model, str(tmp_path / "out.onnx")) + + x = np.random.randn(2, 4).astype(np.float32) + + orig_sess = ort.InferenceSession(str(cast_chain_model_path), providers=["CPUExecutionProvider"]) + opt_sess = ort.InferenceSession(output.model_path, providers=["CPUExecutionProvider"]) + + orig_out = orig_sess.run(None, {"X": x})[0] + opt_out = opt_sess.run(None, {"X": x})[0] + np.testing.assert_allclose(orig_out, opt_out, atol=1e-3) diff --git a/test/passes/onnx/test_rtn_quantization.py b/test/passes/onnx/test_rtn_quantization.py index 2dc6f333e..97928ee18 100644 --- 
a/test/passes/onnx/test_rtn_quantization.py +++ b/test/passes/onnx/test_rtn_quantization.py @@ -154,6 +154,8 @@ def test_rtn_quantization_pass_gather(self, gather_model_path, tmp_path, is_symm ir_model = ir.load(quantized_model.model_path) # Assert + # ORT GatherBlockQuantized requires quantize_axis == last dimension (data_rank - 1). + # The gather model fixture uses 2D data [100, 64], so quantize_axis = 1. found_gather_block_quantized = False for node in ir_model.graph.all_nodes(): if node.op_type == str(OpType.GatherBlockQuantized): @@ -163,10 +165,7 @@ def test_rtn_quantization_pass_gather(self, gather_model_path, tmp_path, is_symm attr.name == "block_size" and attr.value == pass_config["block_size"] for attr in node.attributes.values() ) - assert any( - attr.name == "quantize_axis" and attr.value == pass_config["axis"] - for attr in node.attributes.values() - ) + assert any(attr.name == "quantize_axis" and attr.value == 1 for attr in node.attributes.values()) break assert found_gather_block_quantized, "No GatherBlockQuantized node found in quantized model" @@ -204,3 +203,155 @@ def test_rtn_quantization_with_exclusion(self, matmul_model_path, tmp_path): assert not found_matmul_nbits, "MatMulNBits node found despite exclusion" assert found_original_matmul, "Original MatMul node should still exist when excluded" + + @pytest.mark.parametrize("is_symmetric", [True, False]) + def test_rtn_quantization_gather_8bit(self, gather_model_path, tmp_path, is_symmetric): + """8-bit Gather quantization should produce GatherBlockQuantized with bits=8.""" + olive_model = ONNXModelHandler(model_path=str(gather_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + pass_config = {"bits": 8, "block_size": 128, "axis": 0, "is_symmetric": is_symmetric} + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, pass_config, disable_search=True, accelerator_spec=accelerator_spec + ) + + output_path = tmp_path / 
"quantized_gather_8bit.onnx" + quantized_model = p.run(olive_model, output_path) + + assert os.path.exists(quantized_model.model_path) + + ir_model = ir.load(quantized_model.model_path) + + found = False + for node in ir_model.graph.all_nodes(): + if node.op_type == str(OpType.GatherBlockQuantized): + found = True + assert node.domain == MSFT_DOMAIN + # bits attribute must be 8 + assert any(attr.name == "bits" and attr.value == 8 for attr in node.attributes.values()), ( + "GatherBlockQuantized should have bits=8" + ) + # quantize_axis must be last dimension (data_rank - 1) + assert any(attr.name == "quantize_axis" and attr.value == 1 for attr in node.attributes.values()), ( + "quantize_axis should be forced to last dim (1 for 2-D embedding)" + ) + break + + assert found, "No GatherBlockQuantized node found for 8-bit quantization" + + def test_rtn_quantization_gather_quantize_axis_forced_to_last_dim(self, gather_model_path, tmp_path): + """Regardless of axis config, gather quantize_axis is forced to data_rank - 1.""" + olive_model = ONNXModelHandler(model_path=str(gather_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + # Set axis=0, but the code should force quantize_axis to last dim + pass_config = {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True} + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, pass_config, disable_search=True, accelerator_spec=accelerator_spec + ) + + output_path = tmp_path / "quantized_gather_axis.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + found = False + for node in ir_model.graph.all_nodes(): + if node.op_type == str(OpType.GatherBlockQuantized): + found = True + qa = [attr for attr in node.attributes.values() if attr.name == "quantize_axis"] + assert len(qa) == 1 + assert qa[0].value == 1, f"quantize_axis should be 1 (last dim of 2-D data), got {qa[0].value}" + break + + 
assert found, "No GatherBlockQuantized node found for axis/quantize_axis test" + + def test_rtn_quantization_shared_gather_weights(self, tmp_path): + """Two Gather nodes sharing the same weight should not duplicate initializers.""" + data_shape = [100, 64] + data_tensor = np.random.randn(*data_shape).astype(np.float32) + data_name = "shared_data" + + data_init = onnx.helper.make_tensor( + name=data_name, + data_type=onnx.TensorProto.FLOAT, + dims=data_shape, + vals=data_tensor.flatten().tolist(), + ) + indices1 = onnx.helper.make_tensor_value_info("indices1", onnx.TensorProto.INT64, [1, 5]) + indices2 = onnx.helper.make_tensor_value_info("indices2", onnx.TensorProto.INT64, [1, 5]) + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [1, 5, 64]) + out2 = onnx.helper.make_tensor_value_info("out2", onnx.TensorProto.FLOAT, [1, 5, 64]) + + gather1 = onnx.helper.make_node("Gather", [data_name, "indices1"], ["out1"], name="Gather1") + gather2 = onnx.helper.make_node("Gather", [data_name, "indices2"], ["out2"], name="Gather2") + + graph = onnx.helper.make_graph( + [gather1, gather2], + "shared_weight_test", + [indices1, indices2], + [out1, out2], + initializer=[data_init], + ) + model = onnx.helper.make_model(graph, producer_name="olive-test") + model.opset_import[0].version = 13 + model.ir_version = 10 + + model_path = tmp_path / "shared_gather.onnx" + onnx.save(model, str(model_path)) + + olive_model = ONNXModelHandler(model_path=str(model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, + {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}, + disable_search=True, + accelerator_spec=accelerator_spec, + ) + + output_path = tmp_path / "shared_gather_quantized.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + # Both nodes should be 
GatherBlockQuantized + gbq_nodes = [n for n in ir_model.graph.all_nodes() if n.op_type == str(OpType.GatherBlockQuantized)] + assert len(gbq_nodes) == 2, f"Expected 2 GatherBlockQuantized nodes, got {len(gbq_nodes)}" + + # The quantized data inputs (first input) should refer to the same name + quant_data_names = [n.inputs[0].name for n in gbq_nodes] + assert quant_data_names[0] == quant_data_names[1], ( + f"Shared weight should produce same quantized initializer name: {quant_data_names}" + ) + + def test_rtn_quantization_removes_unused_initializers(self, matmul_model_path, tmp_path): + """After quantization, original FP32 weight initializers should be removed.""" + olive_model = ONNXModelHandler(model_path=str(matmul_model_path)) + accelerator_spec = AcceleratorSpec( + accelerator_type="CPU", + execution_provider="CPUExecutionProvider", + ) + p = create_pass_from_dict( + OnnxBlockWiseRtnQuantization, + {"bits": 4, "block_size": 128, "axis": 0, "is_symmetric": True}, + disable_search=True, + accelerator_spec=accelerator_spec, + ) + + output_path = tmp_path / "unused_init_test.onnx" + quantized_model = p.run(olive_model, output_path) + + ir_model = ir.load(quantized_model.model_path) + + # The original "weight" initializer should be gone + init_names = set(ir_model.graph.initializers.keys()) + assert "weight" not in init_names, ( + f"Original FP32 'weight' initializer should have been removed, found: {init_names}" + )