Skip to content

Commit 263da9e

Browse files
authored
Merge branch 'main' into titaiwang/bump_version
2 parents 122c426 + 8089bc7 commit 263da9e

File tree

3 files changed

+30
-18
lines changed

3 files changed

+30
-18
lines changed

onnxscript/optimizer/_constant_folding.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,6 @@ def cast(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
496496
if input is None or output is None:
497497
return None
498498

499-
# TODO(rama): Parts of the following logic (implementing type/shape inference
500-
# for Cast op) should be unnecessary. Generic incremental shape-inference
501-
# should handle this. Only the optimization to eliminate redundant Cast ops
502-
# should be needed here.
503-
504-
output.shape = _merge_shapes(output.shape, input.shape)
505-
506499
input_dtype = _get_input_element_type(node, 0)
507500
output_dtype = _get_int_attribute(node, "to", None)
508501
if output_dtype is not None:
@@ -608,6 +601,7 @@ def identity(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
608601
input = node.inputs[0]
609602
output = node.outputs[0]
610603
if input is not None and output is not None:
604+
# NOTE: backward shape inference
611605
input.shape = _merge_shapes(input.shape, output.shape)
612606
if input.type is None:
613607
input.type = output.type
@@ -904,7 +898,11 @@ def sequence_at(node: ir.Node, op, state: OptimizerState) -> ReturnValue:
904898
return None
905899

906900

907-
def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
901+
def _merge_shapes(
902+
preferred_shape: ir.Shape | None, other_shape: ir.Shape | None
903+
) -> ir.Shape | None:
904+
"""Merge two shapes, preferring dimensions from preferred_shape."""
905+
908906
def merge_dims(dim1, dim2):
909907
if dim1 == dim2:
910908
return dim1
@@ -916,13 +914,15 @@ def merge_dims(dim1, dim2):
916914
return dim2
917915
return dim1
918916

919-
if shape1 is None:
920-
return shape2
921-
if shape2 is None:
922-
return shape1
923-
if len(shape1) != len(shape2):
917+
if preferred_shape is None:
918+
return other_shape
919+
if other_shape is None:
920+
return preferred_shape
921+
if len(preferred_shape) != len(other_shape):
924922
raise ValueError("Shapes must have the same rank.")
925-
return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
923+
return ir.Shape(
924+
[merge_dims(dim1, dim2) for dim1, dim2 in zip(preferred_shape, other_shape)]
925+
)
926926

927927

928928
def _record_contributing_values(original_node: ir.Node, replacement: Replacement) -> None:
@@ -1029,6 +1029,7 @@ def get_type(value: ir.Value) -> onnx.TypeProto | None:
10291029
inferred_shape = ir.serde.deserialize_type_proto_for_shape(
10301030
inferred_type
10311031
)
1032+
# NOTE: forward shape inference
10321033
output.shape = _merge_shapes(output.shape, inferred_shape)
10331034
output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)
10341035
except Exception as e:

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131

3232

3333
class RmsNormFusion(pattern.RewriteRuleClassBase):
34+
def __init__(self, name: str, _mul_order: bool):
35+
super().__init__(name)
36+
self._mul_order = _mul_order
37+
3438
def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
3539
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
3640
x_square = op.Pow(x, 2.0)
@@ -42,7 +46,11 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
4246
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
4347
# To support float16, we need to handle whether the scale is cast or not.
4448
scale = pattern.OrValue([op.Cast(scale, to=compute_dtype), scale])
45-
return op.Mul(scale, normalized)
49+
# Workaround: can't use OrValue for final (returned) value
50+
if self._mul_order:
51+
return op.Mul(normalized, scale)
52+
else:
53+
return op.Mul(scale, normalized)
4654

4755
def check(
4856
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
@@ -77,8 +85,10 @@ def rewrite(self, op, x, scale, epsilon, **_):
7785
)
7886

7987

80-
_rule = RmsNormFusion.rule()
81-
rms_normalization_rules = [_rule]
88+
_rule1 = RmsNormFusion.rule("RmsNormFusion1", _mul_order=False)
89+
_rule2 = RmsNormFusion.rule("RmsNormFusion2", _mul_order=True)
90+
91+
rms_normalization_rules = [_rule1, _rule2]
8292
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
8393

8494

onnxscript/rewriter/rules/fusion/_gqa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def pattern(
5252
_outputs=["attention_BHSDh"],
5353
)
5454

55-
return attention_BHSDh
55+
return attention_BHSDh, present_key_BHkvStD, present_value_BHkvStD
5656

5757
def check(
5858
self,
@@ -103,6 +103,7 @@ def rewrite(
103103
past_key_BHkvSpD,
104104
past_value_BHkvSpD,
105105
**original_attrs,
106+
_outputs=3,
106107
)
107108

108109

0 commit comments

Comments
 (0)