class SimplifiedLayerNormToRMSNorm(ProtoSurgeon):
    """Replace SimplifiedLayerNormalization / SkipSimplifiedLayerNormalization
    nodes with an equivalent RMSNorm subgraph built from standard elementwise ops.

        RMS(x) = Sqrt(ReduceMean(x * x, axes=[-1], keepdims=1) + epsilon)
        y      = (x / RMS(x)) * gamma

    For SkipSimplifiedLayerNormalization the normalized input is first formed as

        s = input + skip [+ bias]

    and, when the original node exposes its residual-sum output, consumers of
    that output are rewired to ``s`` so graph behavior is preserved.

    ReduceMean schema note: ``axes`` is an ATTRIBUTE for opset < 18 and an int64
    INPUT tensor for opset >= 18 (``keepdims`` stays an attribute), so the
    subgraph is built to match the model's default-domain opset.
    """

    # Fallback when a node carries no `epsilon` attribute (matches the value the
    # fused contrib ops commonly default to).
    _DEFAULT_EPSILON = 1e-06

    @staticmethod
    def _get_epsilon(node_proto) -> float:
        """Return the node's `epsilon` attribute value, or the fallback default."""
        for attr in node_proto.attribute:
            if attr.name == "epsilon":
                return onnx.helper.get_attribute_value(attr)
        return SimplifiedLayerNormToRMSNorm._DEFAULT_EPSILON

    def __call__(self, model: onnx.ModelProto):
        dag = OnnxDAG(model)

        # The default ONNX-domain ("" / "ai.onnx") opset decides which ReduceMean
        # schema (axes-as-attribute vs axes-as-input) we must emit.
        default_opset = next(
            (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")), None
        )
        if default_opset is None:
            # Defensive fallback; nearly every model imports the default domain.
            default_opset = 13
        use_axes_input_for_reduce_mean = default_opset >= 18

        modified = 0

        # Snapshot the node names: nodes are added and removed while we iterate.
        for node_name in list(dag.get_node_names()):
            op_type = dag.get_node_op_type(node_name)
            if op_type not in {"SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization"}:
                continue

            graph_idx = dag.get_graph_idx(node_name)
            inputs = dag.get_node_inputs(node_name, True)
            outputs = dag.get_node_outputs(node_name, True)

            # ---------------------------------------------------------------
            # Build the tensor to normalize (ln_input).
            # ---------------------------------------------------------------
            if op_type == "SkipSimplifiedLayerNormalization":
                # Expected inputs: [input, skip, gamma] or [input, skip, gamma, bias].
                if len(inputs) not in (3, 4):
                    continue
                root1, root2, gamma = inputs[0], inputs[1], inputs[2]

                # Add(input, skip) => skip_add_out
                skip_add_name = self.create_new_name(node_name, op_type, "Add")
                skip_add_out = f"{skip_add_name}_out"
                dag.add_node(
                    onnx.helper.make_node(
                        "Add", inputs=[root1, root2], outputs=[skip_add_out], name=skip_add_name
                    ),
                    graph_idx,
                )

                if len(inputs) == 4:
                    # Optional bias input: residual sum is input + skip + bias.
                    bias_add_name = self.create_new_name(node_name, op_type, "AddBias")
                    bias_add_out = f"{bias_add_name}_out"
                    dag.add_node(
                        onnx.helper.make_node(
                            "Add",
                            inputs=[skip_add_out, inputs[3]],
                            outputs=[bias_add_out],
                            name=bias_add_name,
                        ),
                        graph_idx,
                    )
                    ln_input = bias_add_out
                else:
                    ln_input = skip_add_out
            else:
                # SimplifiedLayerNormalization: inputs = [x, gamma].
                if len(inputs) != 2:
                    continue
                ln_input, gamma = inputs

            # The original primary output (the normalized tensor).
            ln_output = outputs[0]

            # Epsilon comes from the node itself rather than a hard-coded 1e-6.
            eps_value = self._get_epsilon(dag.get_node_proto(node_name))

            # NOTE(review): the constants below are created as float32; for
            # fp16/bf16 models they should match the activation dtype — confirm
            # inputs are float32 (or cast) upstream of this surgery.

            # ---------------------------------------------------------------
            # Step 1: Pow(x, 2)
            # ---------------------------------------------------------------
            pow_name = self.create_new_name(node_name, op_type, "Pow")
            pow_out = f"{pow_name}_out"
            pow_const = onnx.numpy_helper.from_array(
                np.array([2.0], dtype=np.float32), name=f"{pow_name}_const"
            )
            dag.add_initializer(pow_const, graph_idx)
            dag.add_node(
                onnx.helper.make_node(
                    "Pow", inputs=[ln_input, pow_const.name], outputs=[pow_out], name=pow_name
                ),
                graph_idx,
            )

            # ---------------------------------------------------------------
            # Step 2: ReduceMean over the last dim, keepdims=1.
            # ---------------------------------------------------------------
            mean_name = self.create_new_name(node_name, op_type, "ReduceMean")
            mean_out = f"{mean_name}_out"
            if use_axes_input_for_reduce_mean:
                # opset >= 18: axes is an int64 input tensor.
                axes_init = onnx.numpy_helper.from_array(
                    np.array([-1], dtype=np.int64), name=f"{mean_name}_axes"
                )
                dag.add_initializer(axes_init, graph_idx)
                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out, axes_init.name],
                    outputs=[mean_out],
                    name=mean_name,
                    keepdims=1,
                )
            else:
                # opset < 18: axes is an attribute.
                mean_node = onnx.helper.make_node(
                    "ReduceMean",
                    inputs=[pow_out],
                    outputs=[mean_out],
                    name=mean_name,
                    axes=[-1],
                    keepdims=1,
                )
            dag.add_node(mean_node, graph_idx)

            # ---------------------------------------------------------------
            # Step 3: Add epsilon.
            # ---------------------------------------------------------------
            add_eps_name = self.create_new_name(node_name, op_type, "AddEps")
            add_eps_out = f"{add_eps_name}_out"
            eps_const = onnx.numpy_helper.from_array(
                np.array([eps_value], dtype=np.float32), name=f"{add_eps_name}_const"
            )
            dag.add_initializer(eps_const, graph_idx)
            dag.add_node(
                onnx.helper.make_node(
                    "Add", inputs=[mean_out, eps_const.name], outputs=[add_eps_out], name=add_eps_name
                ),
                graph_idx,
            )

            # ---------------------------------------------------------------
            # Step 4: Sqrt => RMS(x).
            # ---------------------------------------------------------------
            sqrt_name = self.create_new_name(node_name, op_type, "Sqrt")
            sqrt_out = f"{sqrt_name}_out"
            dag.add_node(
                onnx.helper.make_node("Sqrt", inputs=[add_eps_out], outputs=[sqrt_out], name=sqrt_name),
                graph_idx,
            )

            # ---------------------------------------------------------------
            # Step 5: Div (x / RMS(x)).
            # ---------------------------------------------------------------
            div_name = self.create_new_name(node_name, op_type, "Div")
            div_out = f"{div_name}_out"
            dag.add_node(
                onnx.helper.make_node("Div", inputs=[ln_input, sqrt_out], outputs=[div_out], name=div_name),
                graph_idx,
            )

            # ---------------------------------------------------------------
            # Step 6: Mul with gamma.
            # ---------------------------------------------------------------
            mul_name = self.create_new_name(node_name, op_type, "Mul")
            mul_out = f"{mul_name}_out"
            dag.add_node(
                onnx.helper.make_node("Mul", inputs=[div_out, gamma], outputs=[mul_out], name=mul_name),
                graph_idx,
            )

            # ---------------------------------------------------------------
            # Rewire consumers of the original main output. Snapshot the
            # consumer list: replace_node_input mutates the edge set.
            # ---------------------------------------------------------------
            for consumer in list(dag.get_consumers(ln_output)):
                dag.replace_node_input(consumer, ln_output, mul_out)

            # ---------------------------------------------------------------
            # For SkipSimplifiedLayerNormalization with a second (residual-sum)
            # output: carry over its value info and redirect its consumers to
            # the rebuilt sum (ln_input == input + skip [+ bias]).
            # ---------------------------------------------------------------
            if op_type == "SkipSimplifiedLayerNormalization" and len(outputs) == 2:
                second_output = outputs[1]

                second_vi = dag.get_value_info_proto(second_output)
                if second_vi is not None:
                    new_vi = onnx.ValueInfoProto()
                    new_vi.CopyFrom(second_vi)
                    new_vi.name = ln_input
                    dag.add_value_info(new_vi, graph_idx)

                for consumer in list(dag.get_consumers(second_output)):
                    dag.replace_node_input(consumer, second_output, ln_input)

            dag.remove_node(node_name)
            modified += 1

        if modified > 0:
            logger.debug(
                "Replaced %d Simplified/SkipSimplifiedLayerNormalization nodes with RMSNorm subgraphs", modified
            )

        dag.update()
        return dag.model