Skip to content

Commit 91c1921

Browse files
authored
[Relax][PyTorch] Add binary operation dtype promotion following PyTorch rules in ExportedProgram frontend (apache#18497)
As per title. ref: https://docs.pytorch.org/docs/stable/generated/torch.promote_types.html
1 parent 0bd6f9c commit 91c1921

File tree

2 files changed

+102
-0
lines changed

2 files changed

+102
-0
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@ def shape_of(tensor):
8888
return tensor.shape
8989
raise ValueError("Unsupported type: {}".format(type(tensor)))
9090

91+
@staticmethod
92+
def _promote_common_dtype(lhs_dtype: Optional[str], rhs_dtype: Optional[str]) -> Optional[str]:
93+
"""Return the promoted dtype following PyTorch rules, or None if unsupported."""
94+
import torch # type: ignore
95+
96+
if lhs_dtype is None or rhs_dtype is None or lhs_dtype == rhs_dtype:
97+
return None
98+
99+
tvm_to_torch = {
100+
"float64": torch.float64,
101+
"float32": torch.float32,
102+
"float16": torch.float16,
103+
"bfloat16": torch.bfloat16,
104+
"int64": torch.int64,
105+
"int32": torch.int32,
106+
"int16": torch.int16,
107+
"int8": torch.int8,
108+
"uint8": torch.uint8,
109+
"bool": torch.bool,
110+
}
111+
torch_to_tvm = {v: k for k, v in tvm_to_torch.items()}
112+
113+
lhs_torch = tvm_to_torch.get(lhs_dtype)
114+
rhs_torch = tvm_to_torch.get(rhs_dtype)
115+
if lhs_torch is None or rhs_torch is None:
116+
return None
117+
118+
promoted = torch.promote_types(lhs_torch, rhs_torch)
119+
return torch_to_tvm.get(promoted, None)
120+
91121
@staticmethod
92122
def _is_no_bias(bias):
93123
"""Check if bias represents 'no bias' condition.
@@ -408,6 +438,17 @@ def _binary_op(self, relax_op: Callable, intrinsic_op: Callable) -> Callable:
408438
def convert(node: fx.Node) -> relax.Var:
409439
def promote_binary_op_args(lhs, rhs):
410440
if isinstance(lhs, relax.Expr) and isinstance(rhs, relax.Expr):
441+
lhs_si = getattr(lhs, "struct_info", None)
442+
rhs_si = getattr(rhs, "struct_info", None)
443+
if isinstance(lhs_si, relax.TensorStructInfo) and isinstance(
444+
rhs_si, relax.TensorStructInfo
445+
):
446+
target_dtype = self._promote_common_dtype(lhs_si.dtype, rhs_si.dtype)
447+
if target_dtype is not None:
448+
if lhs_si.dtype != target_dtype:
449+
lhs = self.block_builder.emit(relax.op.astype(lhs, target_dtype))
450+
if rhs_si.dtype != target_dtype:
451+
rhs = self.block_builder.emit(relax.op.astype(rhs, target_dtype))
411452
return lhs, rhs
412453
elif isinstance(lhs, relax.Expr):
413454
assert isinstance(lhs.struct_info, relax.TensorStructInfo)

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1383,6 +1383,67 @@ def main(
13831383
verify_model(Binary2(op), example_args2, {}, expected2)
13841384

13851385

1386+
operator_binary_promote = [
1387+
(operator.add, R.add),
1388+
(operator.sub, R.subtract),
1389+
(operator.mul, R.multiply),
1390+
(operator.truediv, R.divide),
1391+
(operator.pow, R.power),
1392+
(operator.mod, R.floor_mod),
1393+
]
1394+
1395+
1396+
@pytest.mark.parametrize("op, relax_op", operator_binary_promote)
1397+
def test_binary_dtype_promotion(op, relax_op):
1398+
"""Ensure binary ops promote differing dtypes following PyTorch rules."""
1399+
1400+
class BinaryPromoteLHS(Module):
1401+
def forward(self, x):
1402+
arange_val = torch.arange(x.shape[1]) # int64 by default
1403+
return op(x, arange_val)
1404+
1405+
@tvm.script.ir_module
1406+
class expected_promote_lhs:
1407+
@R.function
1408+
def main(
1409+
x: R.Tensor((2, 3), dtype="float32")
1410+
) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
1411+
with R.dataflow():
1412+
lv: R.Tensor((3,), dtype="int64") = R.arange(
1413+
R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
1414+
)
1415+
lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
1416+
lv2: R.Tensor((2, 3), dtype="float32") = relax_op(x, lv1)
1417+
gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
1418+
R.output(gv)
1419+
return gv
1420+
1421+
class BinaryPromoteRHS(Module):
1422+
def forward(self, x):
1423+
arange_val = torch.arange(x.shape[1]) # int64 by default
1424+
return op(arange_val, x)
1425+
1426+
@tvm.script.ir_module
1427+
class expected_promote_rhs:
1428+
@R.function
1429+
def main(
1430+
x: R.Tensor((2, 3), dtype="float32")
1431+
) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
1432+
with R.dataflow():
1433+
lv: R.Tensor((3,), dtype="int64") = R.arange(
1434+
R.prim_value(0), R.prim_value(3), R.prim_value(1), dtype="int64"
1435+
)
1436+
lv1: R.Tensor((3,), dtype="float32") = R.astype(lv, dtype="float32")
1437+
lv2: R.Tensor((2, 3), dtype="float32") = relax_op(lv1, x)
1438+
gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv2,)
1439+
R.output(gv)
1440+
return gv
1441+
1442+
example_args = (torch.randn(2, 3, dtype=torch.float32),)
1443+
verify_model(BinaryPromoteLHS(), example_args, {}, expected_promote_lhs)
1444+
verify_model(BinaryPromoteRHS(), example_args, {}, expected_promote_rhs)
1445+
1446+
13861447
operator_binary_2 = [
13871448
(operator.eq, R.equal),
13881449
(operator.ne, R.not_equal),

0 commit comments

Comments
 (0)