Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 214 additions & 1 deletion olive/passes/onnx/graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numpy as np
import onnx
import onnxscript
from onnx import ModelProto, TensorProto
from onnx import ModelProto, TensorProto, numpy_helper
from onnx.helper import make_tensor
from onnx_ir.passes.common import DeduplicateHashedInitializersPass, InlinePass, RemoveUnusedOpsetsPass
from onnxscript import ir, rewriter
Expand Down Expand Up @@ -850,6 +850,219 @@ def get_rmsnorm_nodes(pow_node: str, dag: OnnxDAG) -> list[str] | None:
return rmsnorm_nodes if len(rmsnorm_nodes) >= (len(pattern) - 1) else []


class SimplifiedLayerNormToRMSNorm(ProtoSurgeon):
    """Replace SimplifiedLayerNormalization or SkipSimplifiedLayerNormalization with an RMSNorm subgraph built from elementwise ops.

    RMS(x) = sqrt(mean(x^2, axis=-1, keepdims=1) + eps)
    y = (x / RMS(x)) * gamma

    For SkipSimplifiedLayerNormalization, we first do:
    s = input + skip
    and use 's' as x for RMSNorm. If the original node exposes a second output
    (residual sum), we rewire its consumers to 's' to preserve graph behavior.

    IMPORTANT: ReduceMean schema change across opsets:
    - opset < 18: axes is an ATTRIBUTE
    - opset >=18: axes is an INPUT tensor (int64), keepdims remains an attribute.
    """

    def __call__(self, model: onnx.ModelProto):
        """Rewrite all matching nodes in *model* and return the modified ModelProto."""
        dag = OnnxDAG(model)

        # Determine the default ONNX opset for the main domain ("", "ai.onnx").
        # We'll use this to decide how to build ReduceMean.
        default_opset = None
        for imp in model.opset_import:
            if imp.domain in ("", "ai.onnx"):
                default_opset = imp.version
                break
        if default_opset is None:
            # Fall back defensively; most models have a default import.
            default_opset = 13

        use_axes_input_for_reduce_mean = default_opset >= 18

        modified = 0

        for node_name in dag.get_node_names():
            op_type = dag.get_node_op_type(node_name)
            if op_type not in {"SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"}:
                continue

            graph_idx = dag.get_graph_idx(node_name)
            inputs = dag.get_node_inputs(node_name, True)
            outputs = dag.get_node_outputs(node_name, True)

            # FIX: honor the node's own "epsilon" attribute instead of hardcoding
            # 1e-6 — both contrib ops declare it, and replacing a node whose
            # epsilon differs would silently change numerics. 1e-6 stays as the
            # fallback when the attribute is absent.
            eps_value = 1e-06
            for attr in dag.get_node_proto(node_name).attribute:
                if attr.name == "epsilon":
                    eps_value = onnx.helper.get_attribute_value(attr)
                    break

            # ---------------------------
            # Build the input to be normalized: ln_input
            # ---------------------------
            if op_type == "SkipSimplifiedLayerNormalization":
                # Expect inputs: [input, skip, gamma]. Nodes with a trailing
                # bias input (4 inputs) are left untouched.
                if len(inputs) != 3:
                    continue
                root1, root2, gamma = inputs

                # Add(input, skip) => skip_add_out
                skip_add_name = self.create_new_name(node_name, op_type, "Add")
                skip_add_out = f"{skip_add_name}_out"
                skip_add_node = onnx.helper.make_node(
                    "Add",
                    inputs=[root1, root2],
                    outputs=[skip_add_out],
                    name=skip_add_name,
                )
                dag.add_node(skip_add_node, graph_idx)

                ln_input = skip_add_out
            else:
                # SimplifiedLayerNormalization: inputs = [x, gamma]
                if len(inputs) != 2:
                    continue
                ln_input, gamma = inputs

            # The original primary output (normalized tensor)
            ln_output = outputs[0]

            # ---------------------------
            # Step 1: Pow(x, 2)
            # ---------------------------
            # NOTE(review): the exponent and epsilon constants below are created
            # as float32; for a float16 model the Add/Div would type-mismatch.
            # Consider deriving the dtype from ln_input — TODO confirm.
            pow_name = self.create_new_name(node_name, op_type, "Pow")
            pow_out = f"{pow_name}_out"
            pow_const = numpy_helper.from_array(np.array([2.0], dtype=np.float32), name=f"{pow_name}_const")
            dag.add_initializer(pow_const, graph_idx)
            pow_node = onnx.helper.make_node(
                "Pow",
                inputs=[ln_input, pow_const.name],
                outputs=[pow_out],
                name=pow_name,
            )
            dag.add_node(pow_node, graph_idx)

            # ---------------------------
            # Step 2: ReduceMean over last dim, keepdims=1
            #   - opset < 18 : axes is an attribute
            #   - opset >= 18: axes is an input tensor (INT64)
            # ---------------------------
            mean_name = self.create_new_name(node_name, op_type, "ReduceMean")
            mean_out = f"{mean_name}_out"

            if use_axes_input_for_reduce_mean:
                axes_init = numpy_helper.from_array(np.array([-1], dtype=np.int64), name=f"{mean_name}_axes")
                dag.add_initializer(axes_init, graph_idx)

                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out, axes_init.name],
                    outputs=[mean_out],
                    name=mean_name,
                    keepdims=1,
                )
            else:
                # Older schema: axes is an attribute
                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out],
                    outputs=[mean_out],
                    name=mean_name,
                    axes=[-1],
                    keepdims=1,
                )
            dag.add_node(mean_node, graph_idx)

            # ---------------------------
            # Step 3: Add epsilon (taken from the original node's attribute)
            # ---------------------------
            add_eps_name = self.create_new_name(node_name, op_type, "AddEps")
            add_eps_out = f"{add_eps_name}_out"

            eps_const = numpy_helper.from_array(np.array([eps_value], dtype=np.float32), name=f"{add_eps_name}_const")
            dag.add_initializer(eps_const, graph_idx)

            add_eps_node = onnx.helper.make_node(
                "Add",
                inputs=[mean_out, eps_const.name],
                outputs=[add_eps_out],
                name=add_eps_name,
            )
            dag.add_node(add_eps_node, graph_idx)

            # ---------------------------
            # Step 4: Sqrt
            # ---------------------------
            sqrt_name = self.create_new_name(node_name, op_type, "Sqrt")
            sqrt_out = f"{sqrt_name}_out"
            sqrt_node = onnx.helper.make_node(
                "Sqrt",
                inputs=[add_eps_out],
                outputs=[sqrt_out],
                name=sqrt_name,
            )
            dag.add_node(sqrt_node, graph_idx)

            # ---------------------------
            # Step 5: Div (x / sqrt(...))
            # ---------------------------
            div_name = self.create_new_name(node_name, op_type, "Div")
            div_out = f"{div_name}_out"
            div_node = onnx.helper.make_node(
                "Div",
                inputs=[ln_input, sqrt_out],
                outputs=[div_out],
                name=div_name,
            )
            dag.add_node(div_node, graph_idx)

            # ---------------------------
            # Step 6: Mul with gamma
            # ---------------------------
            mul_name = self.create_new_name(node_name, op_type, "Mul")
            mul_out = f"{mul_name}_out"
            mul_node = onnx.helper.make_node(
                "Mul",
                inputs=[div_out, gamma],
                outputs=[mul_out],
                name=mul_name,
            )
            dag.add_node(mul_node, graph_idx)

            # ---------------------------
            # Rewire consumers of the original main output
            # ---------------------------
            for consumer in dag.get_consumers(ln_output):
                dag.replace_node_input(consumer, ln_output, mul_out)

            # ---------------------------
            # For SkipSimplifiedLayerNormalization that had two outputs:
            # - Output 1 is typically residual sum (input_skip_bias_sum)
            # - Redirect its consumers to the skip-sum Add output
            # ---------------------------
            if op_type == "SkipSimplifiedLayerNormalization" and len(outputs) == 2:
                second_output = outputs[1]

                # Carry the original value_info (shape/type) over to the new
                # Add output so downstream shape inference is unaffected.
                second_vi = dag.get_value_info_proto(second_output)
                if second_vi is not None:
                    new_vi = onnx.ValueInfoProto()
                    new_vi.CopyFrom(second_vi)
                    new_vi.name = skip_add_out
                    dag.add_value_info(new_vi, graph_idx)

                # Redirect all consumers of the second output
                for consumer in dag.get_consumers(second_output):
                    dag.replace_node_input(consumer, second_output, skip_add_out)

            dag.remove_node(node_name)
            modified += 1

        if modified > 0:
            logger.debug(
                "Replaced %d Simplified/SkipSimplifiedLayerNormalization nodes with RMSNorm subgraphs", modified
            )

        dag.update()
        return dag.model


class SimplifiedLayerNormToL2Norm(ProtoSurgeon):
"""Replace Skip/SimplifiedLayerNormalization node with L2Norm subgraph.

Expand Down
Loading