Skip to content

Commit dd14682

Browse files
authored
[torchlib] Fix operator add (#2630)
Operator add may take in python scalar sometimes (and doesn't have dtype). This PR splits the implementation out so that it is handled differently. Signed-off-by: Justin Chu <[email protected]>
1 parent 32a61f4 commit dd14682

File tree

1 file changed

+11
-1
lines changed
  • onnxscript/function_libs/torch_lib/ops

1 file changed

+11
-1
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
148148
return op.Add(self, other)
149149

150150

151+
@torch_op(("_operator::add"), trace_only=True)
152+
def operator_add(self: TTensor, other: TTensor) -> TTensor:
153+
return op.Add(self, other)
154+
155+
151156
@torch_op(("aten::add.Tensor", "aten::add.Scalar"), trace_only=True, complex=True)
152157
def aten_add_complex(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
153158
"""add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
@@ -5567,7 +5572,7 @@ def aten_msort(self: TensorType) -> TensorType:
55675572

55685573

55695574
@torch_op(
5570-
("aten::mul", "aten::mul.Tensor", "_operator::mul", "aten::multiply.Tensor"),
5575+
("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
55715576
trace_only=True,
55725577
)
55735578
def aten_mul(self: TTensor, other: TTensor) -> TTensor:
@@ -5579,6 +5584,11 @@ def aten_mul(self: TTensor, other: TTensor) -> TTensor:
55795584
return op.Mul(self, other)
55805585

55815586

5587+
@torch_op("_operator::mul", trace_only=True)
5588+
def operator_mul(self: TTensor, other: TTensor) -> TTensor:
5589+
return op.Mul(self, other)
5590+
5591+
55825592
@torch_op(
55835593
("aten::mul", "aten::mul.Tensor", "aten::multiply.Tensor"),
55845594
trace_only=True,

0 commit comments

Comments
 (0)