Skip to content

Commit 051b30b

Browse files
committed
Simplify pattern
Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 53869c7 commit 051b30b

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

onnxscript/rewriter/rules/fusion/_rotary_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ def pattern(self, op, x, freqs, start1, end1, start2, end2, one1, one2):
4343
def check(self, op, x, start1, end1, start2, end2, one1, one2, **_) -> pattern.MatchResult: # type: ignore[name-defined]
4444
check_result = pattern.MatchResult()
4545

46-
if not _ir_utils.is_singleton_value(one1, 1, rank=(0, 1)):
47-
return check_result.fail("Unsqueeze axes is not [1] or 1", one1)
48-
if not _ir_utils.is_singleton_value(one2, 1, rank=(0, 1)):
49-
return check_result.fail("Unsqueeze axes is not [1] or 1", one2)
46+
if not _ir_utils.is_singleton_value(one1, 1):
47+
return check_result.fail("Unsqueeze axes is not [1]", one1)
48+
if not _ir_utils.is_singleton_value(one2, 1):
49+
return check_result.fail("Unsqueeze axes is not [1]", one2)
5050

5151
# x needs to be a 4D tensor with known last dimension size (== head_size) and known second dimension (num_heads)
5252
if x is None or x.shape is None or len(x.shape) != 4:

0 commit comments

Comments
 (0)