Commit 574b66a

Merge branch 'main' into conv_and_hardswish_fusion
2 parents 2f2204d + c219dce

18 files changed: +946 -95 lines

18 files changed

+946
-95
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 35 additions & 23 deletions
@@ -36,6 +36,7 @@
     graph,
     ir,
 )
+from onnxscript._internal import version_utils
 from onnxscript.function_libs.torch_lib.ops import common as common_ops
 from onnxscript.function_libs.torch_lib.registration import torch_op
 from onnxscript.function_libs.torch_lib.tensor_typing import (
@@ -1647,29 +1648,40 @@ def aten_choose_qparams_optimized(
     raise NotImplementedError()
 
 
-@torch_op("aten::chunk")
-def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
-    """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
-    # This will create a Sequence of tensors
-    neg_1 = op.Constant(value_ints=[-1])
-    # Get size of specified dim
-    self_shape = op.Shape(self)
-    dim_size = op.Gather(self_shape, dim, axis=0)
-    # Compute size/chunk to get the number of data in one chunk
-    num_per_chunk = op.Div(dim_size, chunks)
-    num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]
-
-    # Compute real chunk number
-    num_chunk = op.Div(dim_size, num_per_chunk)
-    # Get something like [n, n, n, n, ...], total num_chunk
-    list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
-
-    remainder = op.Mod(dim_size, num_per_chunk)
-    if remainder > 0:  # type: ignore[operator]
-        # Append the remainder to the [n, n, n, n, ..., r]
-        list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
-
-    return op.SplitToSequence(self, list_split, axis=dim)
+if version_utils.torch_older_than("2.7.0"):
+    # PyTorch <2.7 does not support determining the number of outputs for the Split op
+    # https://github.com/pytorch/pytorch/commit/9a1eac6704671c72a2e85c9138db57eb3a80bfb6
+    @torch_op("aten::chunk")
+    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+        # This will create a Sequence of tensors
+        neg_1 = op.Constant(value_ints=[-1])
+        # Get size of specified dim
+        self_shape = op.Shape(self)
+        dim_size = op.Gather(self_shape, dim, axis=0)
+        # Compute size/chunk to get the number of data in one chunk
+        num_per_chunk = op.Div(dim_size, chunks)
+        num_per_chunk = op.Cast(op.Mod(dim_size, chunks) > 0, to=INT64.dtype) + num_per_chunk  # type: ignore[operator]
+
+        # Compute real chunk number
+        num_chunk = op.Div(dim_size, num_per_chunk)
+        # Get something like [n, n, n, n, ...], total num_chunk
+        list_split = op.Expand(num_per_chunk, op.Reshape(num_chunk, neg_1))
+
+        remainder = op.Mod(dim_size, num_per_chunk)
+        if remainder > 0:  # type: ignore[operator]
+            # Append the remainder to the [n, n, n, n, ..., r]
+            list_split = op.Concat(list_split, op.Reshape(remainder, neg_1), axis=0)
+
+        return op.SplitToSequence(self, list_split, axis=dim)
+else:
+
+    @torch_op("aten::chunk", trace_only=True)
+    def aten_chunk(self: TTensor, chunks: int, dim: int = 0) -> Sequence[TTensor]:
+        """chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]"""
+        if chunks == 1:
+            return op.Identity(self)
+        return op.Split(self, axis=dim, num_outputs=chunks)
 
 
 @torch_op("aten::clamp", trace_only=True)
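
Note: the pre-2.7 fallback's split-size computation is just ceiling division plus a trailing remainder. A minimal plain-Python sketch of that arithmetic (the helper name `chunk_split_sizes` is hypothetical; the real code emits the ONNX ops shown above):

```python
def chunk_split_sizes(dim_size: int, chunks: int) -> list[int]:
    # ceil(dim_size / chunks): the number of elements in each full chunk
    num_per_chunk = dim_size // chunks + (1 if dim_size % chunks > 0 else 0)
    num_chunk = dim_size // num_per_chunk  # number of full chunks
    sizes = [num_per_chunk] * num_chunk    # [n, n, n, ...]
    remainder = dim_size % num_per_chunk
    if remainder > 0:
        sizes.append(remainder)            # [n, n, ..., r]
    return sizes

# Matches torch.chunk: a dim of size 10 split into 3 chunks gives 4, 4, 2.
assert chunk_split_sizes(10, 3) == [4, 4, 2]
assert chunk_split_sizes(9, 3) == [3, 3, 3]
```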

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 12 additions & 3 deletions
@@ -2037,7 +2037,8 @@ def _aten_scaled_dot_product_attention_no_mask_onnx(
         op.MatMul(query_scaled, key_transposed_scaled),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
@@ -2076,7 +2077,14 @@ def _aten_scaled_dot_product_attention_bool_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    # When using scaled dot product attention with a boolean mask, the softmax operation might return NaN values
+    # due to the presence of -inf in an entire row (padding tokens), resulting in 0/0 (NaN) in the softmax output.
+    # This is because there's no safe/masked softmax implementation in ONNX, so we need to handle NaN values
+    # explicitly to match the behavior of PyTorch with boolean masks.
+    # Reference: https://github.com/pytorch/pytorch/issues/103749
+    attn_weight = op.Where(op.IsNaN(attn_weight), zero, attn_weight)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
@@ -2111,7 +2119,8 @@ def _aten_scaled_dot_product_attention_float_mask_onnx(
         op.Add(op.MatMul(query_scaled, key_transposed_scaled), attn_mask),
         axis=-1,
     )
-    attn_weight, _ = op.Dropout(attn_weight, dropout_p)
+    if dropout_p != 0:
+        attn_weight, _ = op.Dropout(attn_weight, dropout_p)
     return op.MatMul(attn_weight, value)
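
Note: the NaN handling added to the bool-mask variant guards against rows where every position is masked. A NumPy sketch of the failure mode and of the rewrite the new `op.Where(op.IsNaN(...))` line performs (illustrative only):

```python
import numpy as np

# A fully masked row (all -inf, e.g. padding tokens) breaks softmax:
row = np.array([-np.inf, -np.inf, -np.inf])
weights = np.exp(row)              # exp(-inf) == 0 elementwise
softmax = weights / weights.sum()  # 0 / 0 -> NaN (NumPy emits a RuntimeWarning)
print(softmax)                     # [nan nan nan]

# Equivalent of op.Where(op.IsNaN(attn_weight), zero, attn_weight),
# which matches PyTorch's behavior with boolean masks:
safe = np.where(np.isnan(softmax), 0.0, softmax)
print(safe)                        # [0. 0. 0.]
```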

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,7 @@
     broadcast_to_matmul,
     cast_constant_of_shape,
     collapse_slices,
+    fuse_pad_into_conv,
     fuse_relus_clips,
     no_op,
     pattern,
@@ -49,6 +50,7 @@
     *fuse_relus_clips.fuse_relus_clips_rules().rules,
     *basic_rules.basic_optimization_rules().rules,
     *redundant_scatter_nd.rules.rules,
+    *fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
 )
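
Note: because the pad-into-conv rules are added to the default rule set, callers pick them up without opting in. A hedged usage sketch (assuming the package's public `rewrite()` entry point; `model.onnx` is a placeholder path):

```python
import onnx
from onnxscript import rewriter

# The default rules now include fuse_pad_into_conv_rule_set(),
# so a plain rewrite() call applies the fusion where it matches.
model = onnx.load("model.onnx")
rewritten = rewriter.rewrite(model)
onnx.save(rewritten, "model_fused.onnx")
```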

onnxscript/rewriter/_pattern_ir.py

Lines changed: 20 additions & 6 deletions
@@ -76,20 +76,33 @@ def __str__(self) -> str:
 class AttrPattern(Pattern[ir.Attr]):
     """Base class for an attribute pattern. Matches any attribute value by default."""
 
-    def __init__(self, name: str | None):
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
         self._name = name
+        self._can_match_none = can_match_none
 
     @property
     def name(self) -> str | None:
         return self._name
 
+    @property
+    def can_match_none(self) -> bool:
+        """Indicates whether this pattern can match a None attribute."""
+        return self._can_match_none
+
     def matches(self, attr: ir.Attr) -> bool:
         return True
 
     def __str__(self) -> str:
         return self._name if self._name is not None else "anonymous:" + str(id(self))
 
 
+class AttrVar(AttrPattern):
+    """Represents a pattern variable used to match against attribute values."""
+
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
+        super().__init__(name, can_match_none=can_match_none)
+
+
 # TODO: Support tensors. Align with usage elsewhere.
 SupportedAttrTypes = Union[
     int,
@@ -129,11 +142,11 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
         # annotations to distinguish between ValuePattern and AttrPattern, but forces users to
         # use these type annotations.
         # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
-        if value.can_match_none or value.check_method is not None:
+        if value.check_method is not None:
             raise ValueError(
-                "Pattern variables used in attributes must not have can_match_none or check_method set."
+                "Pattern variables used in attributes must not have check_method set."
             )
-        return AttrPattern(value.name)
+        return AttrVar(value.name, can_match_none=value.can_match_none)
     if isinstance(value, (int, float, str)):
         return AttrConstantPattern(value)
     if isinstance(value, Sequence):
@@ -493,8 +506,9 @@ def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchResult:
         for name, attr_pattern in self.attributes.items():
             attr_value = node.attributes.get(name)
             if attr_value is None:
-                return match.fail(f"Attribute {name} not found in node.", node)
-            if not attr_pattern.matches(attr_value):
+                if not attr_pattern.can_match_none:
+                    return match.fail(f"Attribute {name} not found in node.", node)
+            elif not attr_pattern.matches(attr_value):
                 return match.fail(
                     f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.",
                     node,
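
Note: a short sketch of what `can_match_none` changes for attribute matching, based only on the classes in this diff (the attribute name "axis" is an arbitrary example):

```python
from onnxscript.rewriter._pattern_ir import AttrVar

# Default: a missing attribute fails the node match
# ("Attribute axis not found in node.").
strict_axis = AttrVar("axis")
assert not strict_axis.can_match_none

# With can_match_none=True, the matcher skips that failure when the node
# omits the attribute, instead of rejecting the candidate node outright.
optional_axis = AttrVar("axis", can_match_none=True)
assert optional_axis.can_match_none
```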
