Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
aa0582e
Add Qwen3-VL / Qwen2.5-VL ONNX export support
hanbitmyths Feb 26, 2026
514362d
Fix ModelBuilder sys.path for ort-genai builders package import
hanbitmyths Feb 27, 2026
cb1987b
Expose real ModelBuilder import error for debugging
hanbitmyths Feb 27, 2026
2c2269e
Clean up ModelBuilder import fix (expose chain, not debug print)
hanbitmyths Feb 27, 2026
e77864f
Remove sys.path hack for onnxruntime-genai builder import
hanbitmyths Feb 27, 2026
af5983f
Add 8-bit Gather quantization support, ByteSize crash fix, and graph …
hanbitmyths Mar 3, 2026
4d5283e
Add unit tests for Qwen3-VL graph surgery and quantization passes
hanbitmyths Mar 4, 2026
9fc9bd3
Fix lintrunner warnings: rename uppercase variables (N806), add TODO …
hanbitmyths Mar 4, 2026
32cc2ce
Merge branch 'main' into sunghcho/qwen3-vl
hanbitmyths Mar 4, 2026
74b257c
Fix ruff formatting, int4 packing bug, and test assertion
hanbitmyths Mar 4, 2026
62544da
Add linkcheck_ignore for broken intel/neural-compressor URL
hanbitmyths Mar 4, 2026
efe845f
Merge branch 'main' into sunghcho/qwen3-vl
hanbitmyths Mar 6, 2026
3d0029c
Remove neural-compressor linkcheck_ignore (fixed upstream in #2351)
hanbitmyths Mar 6, 2026
5ad0fa4
Merge branch 'main' into sunghcho/qwen3-vl
hanbitmyths Mar 12, 2026
448e8a2
Trigger CI rebuild
hanbitmyths Mar 12, 2026
b41c25f
Trigger CI rebuild (lint)
hanbitmyths Mar 12, 2026
a35f6e9
Trigger CI rebuild (all green)
hanbitmyths Mar 12, 2026
9846f31
Trigger CI rebuild (CodeQL)
hanbitmyths Mar 12, 2026
d5d1e58
Replace ORT-based cast chain elimination with onnxscript optimizer
hanbitmyths Mar 13, 2026
f8146c5
Merge origin/main into sunghcho/qwen3-vl
hanbitmyths Mar 13, 2026
15975c8
Replace onnxscript optimizer with targeted rewrite rule for cast chai…
hanbitmyths Mar 16, 2026
4ecba49
Fix lint: move onnxscript imports to top level (PLC0415)
hanbitmyths Mar 16, 2026
054bd7c
Fix lint: use functional RewriteRule API to avoid pylint W0221 (argum…
hanbitmyths Mar 16, 2026
9c54059
Merge cast chain elimination into OnnxPeepholeOptimizer
hanbitmyths Mar 17, 2026
9578497
Move _get_cast_chain_rewrite_rules into ModelOptimizer as static method
hanbitmyths Mar 17, 2026
7a2e634
Make all ModelOptimizer steps configurable in OnnxPeepholeOptimizer
hanbitmyths Mar 17, 2026
f50743d
Fix lint: remove duplicate numpy import in test (W0621/W0404)
hanbitmyths Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion olive/passes/onnx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
import onnx
from onnx import external_data_helper
from onnxscript import ir
from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY

# TODO(sunghcho): Remove try/except once onnxscript >= 0.2.0 (which exports FOLDED_FROM_KEY) is the minimum
# required version. After that, replace with: from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY
try:
from onnxscript.optimizer._constant_folding import FOLDED_FROM_KEY
except ImportError:
FOLDED_FROM_KEY = "pkg.onnxscript.optimizer.constant_folding.folded_from"

from olive.common.utils import StrEnumBase, hardlink_copy_file
from olive.model import CompositeModelHandler, ONNXModelHandler
Expand Down
244 changes: 244 additions & 0 deletions olive/passes/onnx/graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,6 +1233,93 @@ def __call__(self, model: ModelProto):
return dag.model


class GemmToMatMulAdd(ProtoSurgeon):
    """Replace Gemm with MatMul (+ Add) for INT4 quantization compatibility.

    The INT4 RTN quantizer only recognizes MatMul nodes. This surgeon converts
    Gemm nodes back to MatMul+Add so that the weight matrices become eligible
    for block-wise quantization.

    Handles transB by transposing constant weights in-place when this Gemm is
    the weight's sole consumer, or by inserting a Transpose node otherwise
    (in-place transposition of a shared weight would corrupt every other
    consumer). Skips Gemm nodes whose alpha/beta are not 1.0 or whose transA
    is set, since those do not map onto a plain MatMul(+Add).
    """

    def __call__(self, model: ModelProto):
        from onnx import helper, numpy_helper

        graph = model.graph
        initializer_map = {init.name: init for init in graph.initializer}

        # Count node-input references per tensor name so we can tell whether an
        # initializer may be transposed in-place without affecting other consumers.
        consumer_counts: dict = {}
        for node in graph.node:
            for inp in node.input:
                consumer_counts[inp] = consumer_counts.get(inp, 0) + 1

        nodes_to_remove = []
        nodes_to_add = []

        for idx, node in enumerate(graph.node):
            if node.op_type != "Gemm":
                continue

            # Gemm computes Y = alpha * A' * B' + beta * C; only the
            # alpha=beta=1, transA=0 case converts without extra scaling nodes.
            alpha = beta = 1.0
            trans_a = trans_b = 0
            for attr in node.attribute:
                if attr.name == "alpha":
                    alpha = attr.f
                elif attr.name == "beta":
                    beta = attr.f
                elif attr.name == "transA":
                    trans_a = attr.i
                elif attr.name == "transB":
                    trans_b = attr.i

            if alpha != 1.0 or beta != 1.0 or trans_a != 0:
                continue

            inp_a, inp_b = node.input[0], node.input[1]
            inp_c = node.input[2] if len(node.input) > 2 else None
            out_y = node.output[0]
            # Unique per-node prefix for generated names, even when node.name is
            # empty (unnamed Gemm nodes would otherwise all collide on "_MatMul").
            prefix = node.name or f"{out_y}_gemm{idx}"

            if trans_b:
                if inp_b in initializer_map and consumer_counts.get(inp_b, 0) == 1:
                    # Sole consumer: safe to transpose the constant weight in-place.
                    # initializer_map holds the same proto object that lives in
                    # graph.initializer, so CopyFrom updates the graph directly.
                    init = initializer_map[inp_b]
                    w_t = numpy_helper.to_array(init).T.copy()
                    init.CopyFrom(numpy_helper.from_array(w_t, name=inp_b))
                    matmul_rhs = inp_b
                else:
                    # Non-constant or shared weight: insert an explicit Transpose.
                    transpose_out = f"{prefix}_transpose_B"
                    nodes_to_add.append(
                        helper.make_node(
                            "Transpose", [inp_b], [transpose_out], name=f"{prefix}_Transpose", perm=[1, 0]
                        )
                    )
                    matmul_rhs = transpose_out
            else:
                matmul_rhs = inp_b

            if inp_c:
                matmul_out = f"{prefix}_matmul_out"
                nodes_to_add.append(
                    helper.make_node("MatMul", [inp_a, matmul_rhs], [matmul_out], name=f"{prefix}_MatMul")
                )
                nodes_to_add.append(helper.make_node("Add", [matmul_out, inp_c], [out_y], name=f"{prefix}_Add"))
            else:
                nodes_to_add.append(
                    helper.make_node("MatMul", [inp_a, matmul_rhs], [out_y], name=f"{prefix}_MatMul")
                )

            nodes_to_remove.append(node)

        for node in nodes_to_remove:
            graph.node.remove(node)
        graph.node.extend(nodes_to_add)

        if nodes_to_remove:
            logger.debug("Replaced %d Gemm nodes with MatMul + Add nodes", len(nodes_to_remove))

        return model


class RemoveRopeMultiCache(ProtoSurgeon):
"""Remove the multi rope cache from the model."""

Expand Down Expand Up @@ -2041,6 +2128,163 @@ def equal_weights(self, dag: OnnxDAG, init0: str, init1: str, transpose: bool =
return np.array_equal(arr0.ravel(), arr1.ravel())


class ReciprocalMulToDiv(ProtoSurgeon):
    """Replace Reciprocal(x) * a with Div(a, x).

    Before:
        [x] --> Reciprocal --> Mul --> [out]
                                ^
                                |
                               [a]

    After:
        [a] --> Div --> [out]
                 ^
                 |
                [x]

    Why this is needed:
        PyTorch's ``torch.rsqrt()`` (used by Qwen2.5-VL's ``Qwen2RMSNorm``) decomposes to
        ``Sqrt -> Reciprocal -> Mul`` in ONNX. ORT's ``SimplifiedLayerNormFusion`` only
        matches the pattern ``Pow -> ReduceMean -> Add -> Sqrt -> Div -> Mul`` — it does
        **not** recognize the ``Reciprocal -> Mul`` variant (confirmed on ORT main as of
        2025-06). This pass canonicalizes the graph so that the fusion fires, replacing
        decomposed RMSNorm with a single ``SimplifiedLayerNormalization`` op.

    When to use:
        Run **before** ``OrtTransformersOptimization`` on models whose normalization layers
        export ``rsqrt`` as ``Reciprocal`` (e.g. HuggingFace models using ``torch.rsqrt``).
    """

    def __call__(self, model: ModelProto):
        modified = 0
        nodes_to_remove = []
        # A Reciprocal whose output is also a graph output must keep its
        # producer, even if no node consumes it anymore.
        graph_outputs = {output.name for output in model.graph.output}

        for node in model.graph.node:
            if node.op_type != "Reciprocal":
                continue

            recip_input = node.input[0]  # x
            recip_output = node.output[0]  # 1/x

            # Find Mul consumers of this Reciprocal
            mul_nodes = [n for n in model.graph.node if n.op_type == "Mul" and recip_output in n.input]

            for mul_node in mul_nodes:
                # Identify the other operand (not from Reciprocal)
                if mul_node.input[0] == recip_output:
                    other_input = mul_node.input[1]
                else:
                    other_input = mul_node.input[0]

                # Convert Mul(a, Reciprocal(x)) to Div(a, x) in-place
                mul_node.op_type = "Div"
                mul_node.input[0] = other_input
                mul_node.input[1] = recip_input
                if mul_node.name:
                    mul_node.name = self.create_new_name(mul_node.name, "Mul", "Div")
                modified += 1

            # Remove the Reciprocal only if nothing still needs its output:
            # no remaining consumer nodes and not a graph output.
            remaining = [n for n in model.graph.node if n != node and recip_output in n.input]
            if not remaining and recip_output not in graph_outputs:
                nodes_to_remove.append(node)

        for node in nodes_to_remove:
            model.graph.node.remove(node)

        if modified > 0:
            logger.debug("Replaced %d Reciprocal+Mul patterns with Div", modified)

        return model


class DeduplicateSubgraphInitializers(ProtoSurgeon):
    """Remove duplicate initializers in Loop / If / Scan subgraphs.

    Why this is needed:
        ORT's graph optimizer (constant folding, shape inference, etc.) may copy
        initializers into subgraphs that already contain them, creating entries with
        identical names. ORT's ``ConstantSharing`` pass explicitly skips subgraph
        usage (``constant_sharing.cc``: "If usage is from subgraph, skip it now"),
        so these duplicates are never cleaned up. Duplicate initializers violate
        the ONNX spec's unique-name requirement and can cause validation failures
        or silent data corruption.

    What it does:
        For every subgraph attribute (``Loop`` / ``If`` / ``Scan`` bodies, including
        subgraphs nested inside other subgraphs), keeps the first initializer with a
        given name and removes all subsequent duplicates.

    When to use:
        Run **after** ``OrtTransformersOptimization`` (which introduces the duplicates)
        and **before** any pass that serializes or validates the model.
    """

    def __call__(self, model: ModelProto):
        removed = self._dedup_subgraphs(model.graph)
        if removed > 0:
            logger.debug("Removed %d duplicate subgraph initializers", removed)
        return model

    def _dedup_subgraphs(self, graph) -> int:
        """Deduplicate initializers in all subgraphs reachable from *graph*'s nodes.

        Returns the number of removed duplicate initializers.
        """
        from onnx import AttributeProto

        removed = 0
        for node in graph.node:
            for attr in node.attribute:
                # Use the attribute type tag rather than truthiness: non-graph
                # attributes expose a default (empty) GraphProto via attr.g.
                if attr.type == AttributeProto.GRAPH:
                    subgraphs = [attr.g]
                elif attr.type == AttributeProto.GRAPHS:
                    subgraphs = list(attr.graphs)
                else:
                    continue
                for subgraph in subgraphs:
                    seen = set()
                    duplicates = []
                    for init in subgraph.initializer:
                        if init.name in seen:
                            duplicates.append(init)
                        else:
                            seen.add(init.name)
                    for init in duplicates:
                        subgraph.initializer.remove(init)
                        removed += 1
                    # Subgraphs can nest (e.g. an If inside a Loop body).
                    removed += self._dedup_subgraphs(subgraph)
        return removed


class DeduplicateNodes(ProtoSurgeon):
    """Remove nodes whose output tensors are already produced by an earlier node.

    Before (invalid — two nodes define the same tensor ``/Cast_output_0``):
        NodeA --> Cast --> /Cast_output_0
        NodeB --> Cast --> /Cast_output_0   (duplicate, removed)

    After:
        NodeA --> Cast --> /Cast_output_0

    Why this is needed:
        ORT's ``convert_float_to_float16`` (``float16.py``) may insert identical
        ``Cast`` nodes in parallel branches that each declare the same output tensor
        name. The ONNX spec requires every tensor to have a unique producer; loading
        a model with duplicate producers causes ``onnxruntime.InferenceSession`` to
        fail with a duplicate-definition error.

    What it does:
        Scans nodes in graph order and records each output tensor name. If a later
        node produces a tensor name that was already seen, the entire node is removed.
        A removed node's outputs are NOT recorded, so a subsequent node that
        legitimately produces one of those names is kept as the sole producer.

    When to use:
        Run **after** ``OnnxFloatToFloat16`` as a cleanup step.
    """

    def __call__(self, model: ModelProto):
        output_seen: set[str] = set()
        indices_to_remove: list[int] = []
        for i, node in enumerate(model.graph.node):
            # Duplicate if any (non-empty) output name already has a producer.
            if any(o and o in output_seen for o in node.output):
                indices_to_remove.append(i)
                # Skip registration: this node is being removed, so its outputs
                # are not produced here and must not shadow a later producer.
                continue
            output_seen.update(o for o in node.output if o)
        # Delete back-to-front so earlier indices stay valid.
        for i in reversed(indices_to_remove):
            del model.graph.node[i]
        if indices_to_remove:
            logger.debug("Removed %d duplicate nodes", len(indices_to_remove))
        return model


class PackedAttentionToLoopMHA(Surgeon):
"""Replace custom::PackedAttention with a loop calling com.microsoft::MultiHeadAttention.

Expand Down
Loading
Loading