
Commit df8f706

justinchuby and Copilot authored
[torchlib] Support integers in logical_and/or ops and update other logical ops (#2582)
This PR:

1. Consolidates the logic for the `bitwise_*` functions so that the `logical_*` functions no longer handle the bool overloads of the bitwise ops.
2. Adds support for integer inputs in the `logical_*` implementations.

Replaces #2579.

Signed-off-by: Justin Chu <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent dddf0c2 commit df8f706
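As context for item 2 of the commit message (eager-PyTorch behavior, not part of the diff): PyTorch's `logical_*` ops already accept non-bool inputs, treat any nonzero element as True, and return a bool tensor. That is the semantics the updated torchlib implementations have to reproduce:

```python
import torch

# Eager PyTorch semantics the torchlib lowering must match:
# nonzero elements count as True and the result dtype is bool.
a = torch.tensor([0, 1, 2])
b = torch.tensor([0, 0, 3])
print(torch.logical_and(a, b))  # tensor([False, False,  True])
print(torch.logical_or(a, b))   # tensor([False,  True,  True])
print(torch.logical_not(a))     # tensor([ True, False, False])
```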

File tree

2 files changed: +67 -54 lines changed


onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 63 additions & 54 deletions
```diff
@@ -162,9 +162,15 @@ def aten_acosh(self: TFloat) -> TFloat:
 
 
 @torch_op(("aten::add.Tensor", "aten::add.Scalar", "_operator::add"), trace_only=True)
-def aten_add(self: TReal, other: TReal, alpha: float = 1.0) -> TReal:
+def aten_add(self: TTensor, other: TTensor, alpha: float = 1.0) -> TTensor:
     """add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"""
-    # TODO(microsoft/onnxruntime#15977): Improve fp16 precision
+
+    if self.dtype == ir.DataType.BOOL:
+        # alpha can also be bool
+        if alpha == 0:
+            return op.Identity(self)
+        return op.Or(self, other)
+
     if alpha != 1.0:
         alpha = op.CastLike(alpha, other)
         other = op.Mul(other, alpha)
```
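A note on the new bool branch in `aten_add` (an observation about eager PyTorch, not text from the commit): addition of bool tensors saturates at True, i.e. it behaves as logical or, so lowering that case to `op.Or` matches eager results; the `alpha == 0` shortcut returns `self` unchanged because `self + 0 * other == self`.

```python
import torch

a = torch.tensor([True, True, False, False])
b = torch.tensor([True, False, True, False])
# Bool addition saturates at True, so it agrees with logical or.
print(a + b)  # tensor([ True,  True,  True, False])
print(a | b)  # same result
```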
```diff
@@ -1233,15 +1239,19 @@ def aten_binomial(
         "aten::bitwise_and.Tensor",
         "aten::bitwise_and.Scalar",
         "aten::bitwise_and.Scalar_Tensor",
-        "_operator::and_",
     ),
     trace_only=True,
 )
-def aten_bitwise_and(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_and(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_and implements the BOOL variant
 
-    return op.BitwiseAnd(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseAnd(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
 
 
 @torch_op(
```
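Why the dispatch above is needed (my gloss on the diff): ONNX defines `BitwiseAnd` only for integer tensors and `And` only for bool tensors, so a single `aten::bitwise_and` must choose the ONNX op by dtype; the same dispatch repeats in the `bitwise_not`, `bitwise_or`, and `bitwise_xor` hunks below. NumPy mirrors the numerics of both branches:

```python
import numpy as np

# Integer branch: true bitwise semantics (BitwiseAnd).
print(np.bitwise_and(np.array([6, 5], dtype=np.int32),
                     np.array([3, 1], dtype=np.int32)))  # [2 1]
# Bool branch: logical semantics (And).
print(np.logical_and(np.array([True, False]),
                     np.array([True, True])))            # [ True False]
```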
```diff
@@ -1329,27 +1339,34 @@ def aten_bitwise_left_shift_int8(self: INT8, other: INT8) -> INT8:
 
 
 @torch_op("aten::bitwise_not", trace_only=True)
-def aten_bitwise_not(self: TInt) -> TInt:
+def aten_bitwise_not(self: TTensor) -> TTensor:
     """bitwise_not(Tensor self) -> Tensor"""
-    # logical_not implements the BOOL variant
 
-    return op.BitwiseNot(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    if self.dtype.is_integer():
+        return op.BitwiseNot(self)
+    raise NotImplementedError(f"Not implemented for type {self.dtype}")
 
 
 @torch_op(
     (
         "aten::bitwise_or.Tensor",
         "aten::bitwise_or.Scalar",
         "aten::bitwise_or.Scalar_Tensor",
-        "_operator::or_",
     ),
     trace_only=True,
 )
-def aten_bitwise_or(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_or(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_or implements the BOOL variant
 
-    return op.BitwiseOr(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype.is_integer():
+        return op.BitwiseOr(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
 
 
 @torch_op(
```
```diff
@@ -1487,11 +1504,15 @@ def aten_bitwise_right_shift_int8(self: INT8, other: INT8) -> INT8:
     ),
     trace_only=True,
 )
-def aten_bitwise_xor(self: TInt, other: TInt) -> TInt:
+def aten_bitwise_xor(self: TTensor, other: TTensor) -> TTensor:
     """bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor"""
-    # logical_xor implements the BOOL variant
+    assert self.dtype == other.dtype
 
-    return op.BitwiseXor(self, other)
+    if self.dtype.is_integer():
+        return op.BitwiseXor(self, other)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    raise NotImplementedError(f"Not implemented for types {self.dtype} and {other.dtype}")
 
 
 @torch_op("aten::blackman_window", trace_only=True)
```
```diff
@@ -5010,58 +5031,46 @@ def aten_logdet(self: TFloat) -> TFloat:
     return op.Log(op.Det(self))
 
 
-@torch_op(
-    (
-        "aten::logical_and",
-        "aten::bitwise_and.Tensor",
-        "aten::bitwise_and.Scalar",
-        "aten::bitwise_and.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_and(self: BOOL, other: BOOL) -> BOOL:
+@torch_op("aten::logical_and", trace_only=True)
+def aten_logical_and(self: TTensor, other: TTensor) -> BOOL:
     """logical_and(Tensor self, Tensor other) -> Tensor"""
 
-    return op.And(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.And(self, other)
+    return op.And(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
 
-@torch_op(("aten::logical_not", "aten::bitwise_not"), trace_only=True)
-def aten_logical_not(self: BOOL) -> BOOL:
+@torch_op("aten::logical_not", trace_only=True)
+def aten_logical_not(self: TTensor) -> BOOL:
     """logical_not(Tensor self) -> Tensor"""
 
-    return op.Not(self)
+    if self.dtype == ir.DataType.BOOL:
+        return op.Not(self)
+    return op.Not(op.Cast(self, to=BOOL.dtype))
 
 
-@torch_op(
-    (
-        "aten::logical_or",
-        "aten::bitwise_or.Tensor",
-        "aten::bitwise_or.Scalar",
-        "aten::bitwise_or.Scalar_Tensor",
-        "aten::add.Tensor",
-        "aten::add.Scalar",
-    ),
-    trace_only=True,
-)
-def aten_logical_or(self: BOOL, other: BOOL) -> BOOL:
+@torch_op(("aten::logical_or"), trace_only=True)
+def aten_logical_or(self: TTensor, other: TTensor) -> BOOL:
     """logical_or(Tensor self, Tensor other) -> Tensor"""
 
-    return op.Or(self, other)
+    assert self.dtype == other.dtype
 
+    if self.dtype == ir.DataType.BOOL:
+        return op.Or(self, other)
+    return op.Or(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
-@torch_op(
-    (
-        "aten::logical_xor",
-        "aten::bitwise_xor.Tensor",
-        "aten::bitwise_xor.Scalar",
-        "aten::bitwise_xor.Scalar_Tensor",
-    ),
-    trace_only=True,
-)
-def aten_logical_xor(self: BOOL, other: BOOL) -> BOOL:
+
+@torch_op("aten::logical_xor", trace_only=True)
+def aten_logical_xor(self: TTensor, other: TTensor) -> BOOL:
     """logical_xor(Tensor self, Tensor other) -> Tensor"""
 
-    return op.Xor(self, other)
+    assert self.dtype == other.dtype
+
+    if self.dtype == ir.DataType.BOOL:
+        return op.Xor(self, other)
+    return op.Xor(op.Cast(self, to=BOOL.dtype), op.Cast(other, to=BOOL.dtype))
 
 
 @torch_op("aten::logit", private=True)
```

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1631,6 +1631,10 @@ def _where_input_wrangler(
         dtypes=(torch.float32 if sys.platform != "linux" else torch.complex64,),
         reason="fixme: test is unstable on macosx, windows",
     ),
+    TorchLibOpInfo("logical_and", core_ops.aten_logical_and),
+    TorchLibOpInfo("logical_not", core_ops.aten_logical_not),
+    TorchLibOpInfo("logical_or", core_ops.aten_logical_or),
+    TorchLibOpInfo("logical_xor", core_ops.aten_logical_xor),
     TorchLibOpInfo("logit", core_ops.aten_logit, tolerance={torch.float16: (1e-1, 7e-4)}),
     TorchLibOpInfo("max_dim", core_ops.aten_max_dim)
     .xfail(
```
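For readers new to this table, a reduced sketch of what an entry carries (an assumption based on a simplified reading of `ops_test_data.py`; the real class has more fields such as tolerances, skips, and input wranglers): each `TorchLibOpInfo` pairs the name of a PyTorch `OpInfo` with the torchlib function under test, and the shared harness compares the two over the OpInfo's sampled inputs.

```python
# Simplified stand-in for the real TorchLibOpInfo (hypothetical reduction;
# see ops_test_data.py for the full definition).
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class TorchLibOpInfo:
    op_info_name: str        # which torch OpInfo supplies sample inputs
    op: Callable[..., Any]   # the torchlib implementation under test

entry = TorchLibOpInfo("logical_and", lambda a, b: a and b)  # placeholder op
print(entry.op_info_name)  # logical_and
```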

0 commit comments
