Skip to content

Commit 45a2a40

Browse files
authored
[Relax][PyTorch] Add decomposed operator support for Binary (#18458)
## Related Issue - #18401 ## Why - When `run_ep_decomposition=True` is enabled, PyTorch decomposes binary operators into lower-level operations and some of them are not supported, which cause error ## How - Added support for `bitwise_and.Tensor`, `bitwise_and.Scalar`, `bitwise_xor.Tensor` and `bitwise_xor.Scalar` - Updated `test_binary` to use `run_ep_decomposition=True`
1 parent b6ac072 commit 45a2a40

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,8 +898,12 @@ def create_convert_map(
898898
# binary
899899
"add.Tensor": self._binary_op(relax.op.add, operator.add),
900900
"add_.Tensor": self._binary_op(relax.op.add, operator.add),
901+
"bitwise_and.Tensor": self._binary_op(relax.op.bitwise_and, operator.and_),
902+
"bitwise_and.Scalar": self._binary_op(relax.op.bitwise_and, operator.and_),
901903
"bitwise_or_.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
902904
"bitwise_or.Scalar": self._binary_op(relax.op.bitwise_or, operator.or_),
905+
"bitwise_xor.Tensor": self._binary_op(relax.op.bitwise_xor, operator.xor),
906+
"bitwise_xor.Scalar": self._binary_op(relax.op.bitwise_xor, operator.xor),
903907
"bitwise_or_.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
904908
"bitwise_or.Tensor": self._binary_op(relax.op.bitwise_or, operator.or_),
905909
"div.Scalar": self._binary_op(relax.op.divide, operator.truediv),
@@ -929,6 +933,8 @@ def create_convert_map(
929933
"min.other": self._binary_op(relax.op.minimum, min),
930934
"max.default": self._unary_op(relax.op.max),
931935
"min.default": self._unary_op(relax.op.min),
936+
"maximum.default": self._binary_op(relax.op.maximum, torch.maximum),
937+
"minimum.default": self._binary_op(relax.op.minimum, torch.minimum),
932938
"remainder.Tensor": self._binary_op(relax.op.floor_mod, operator.mod),
933939
"remainder.Scalar": self._binary_op(relax.op.floor_mod, operator.mod),
934940
"mul.Tensor": self._binary_op(relax.op.multiply, operator.mul),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,21 @@ def main(
12911291
R.output(gv)
12921292
return gv
12931293

1294+
@tvm.script.ir_module
1295+
class expected_binary1_inplace:
1296+
@R.function
1297+
def main(
1298+
lhs: R.Tensor((10, 10), dtype="float32"),
1299+
rhs: R.Tensor((10, 10), dtype="float32"),
1300+
) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")):
1301+
with R.dataflow():
1302+
lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, rhs)
1303+
gv: R.Tuple(
1304+
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")
1305+
) = (lv, lv)
1306+
R.output(gv)
1307+
return gv
1308+
12941309
class Binary2(Module):
12951310
def __init__(self, op):
12961311
super().__init__()
@@ -1311,8 +1326,30 @@ def main(
13111326
R.output(gv)
13121327
return gv
13131328

1314-
verify_model(Binary1(op), example_args1, {}, expected_binary1)
1315-
verify_model(Binary2(op), example_args2, {}, expected_binary2)
1329+
@tvm.script.ir_module
1330+
class expected_binary2_inplace:
1331+
@R.function
1332+
def main(
1333+
lhs: R.Tensor((10, 10), dtype="float32"),
1334+
) -> R.Tuple(R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")):
1335+
with R.dataflow():
1336+
lv: R.Tensor((10, 10), dtype="float32") = relax_op(lhs, R.const(1.0))
1337+
gv: R.Tuple(
1338+
R.Tensor((10, 10), dtype="float32"), R.Tensor((10, 10), dtype="float32")
1339+
) = (lv, lv)
1340+
R.output(gv)
1341+
return gv
1342+
1343+
inplace_ops = [
1344+
torch.ops.aten.add_,
1345+
torch.ops.aten.bitwise_or_,
1346+
torch.ops.aten.mul_,
1347+
]
1348+
1349+
expected1 = expected_binary1_inplace if op in inplace_ops else expected_binary1
1350+
expected2 = expected_binary2_inplace if op in inplace_ops else expected_binary2
1351+
verify_model(Binary1(op), example_args1, {}, expected1, run_ep_decomposition=True)
1352+
verify_model(Binary2(op), example_args2, {}, expected2, run_ep_decomposition=True)
13161353

13171354

13181355
operator_binary_2 = [
@@ -1374,8 +1411,8 @@ def main(
13741411
R.output(gv)
13751412
return gv
13761413

1377-
verify_model(Binary1(op), example_args1, {}, expected_binary1)
1378-
verify_model(Binary2(op), example_args2, {}, expected_binary2)
1414+
verify_model(Binary1(op), example_args1, {}, expected_binary1, run_ep_decomposition=True)
1415+
verify_model(Binary2(op), example_args2, {}, expected_binary2, run_ep_decomposition=True)
13791416

13801417

13811418
def test_binary3():
@@ -1403,7 +1440,7 @@ def main(
14031440
R.output(gv)
14041441
return gv
14051442

1406-
verify_model(Max1(), example_args1, {}, expected_max1)
1443+
verify_model(Max1(), example_args1, {}, expected_max1, run_ep_decomposition=True)
14071444

14081445
# Min
14091446
class Min1(Module):
@@ -1423,7 +1460,7 @@ def main(
14231460
R.output(gv)
14241461
return gv
14251462

1426-
verify_model(Min1(), example_args1, {}, expected_min1)
1463+
verify_model(Min1(), example_args1, {}, expected_min1, run_ep_decomposition=True)
14271464

14281465
# RSub
14291466
class RSub1(Module):
@@ -1458,8 +1495,8 @@ def main(
14581495
R.output(gv)
14591496
return gv
14601497

1461-
verify_model(RSub1(), example_args1, {}, expected_rsub1)
1462-
verify_model(RSub2(), example_args2, {}, expected_rsub2)
1498+
verify_model(RSub1(), example_args1, {}, expected_rsub1, run_ep_decomposition=True)
1499+
verify_model(RSub2(), example_args2, {}, expected_rsub2, run_ep_decomposition=True)
14631500

14641501

14651502
# IsIn

0 commit comments

Comments
 (0)