diff --git a/exir/passes/remove_mixed_type_operators.py b/exir/passes/remove_mixed_type_operators.py
index 701a8269f10..d0e48a277c0 100644
--- a/exir/passes/remove_mixed_type_operators.py
+++ b/exir/passes/remove_mixed_type_operators.py
@@ -23,12 +23,20 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata):  # noqa: C901
         promotion_type_allow_list = {
             torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            # The correct promotion for div depends on the mode! If there is no mode,
+            # it's INT_TO_FLOAT, otherwise it's DEFAULT.
+            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            torch.ops.aten.div.Tensor_mode: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
         }
 
         if op in promotion_type_allow_list:
             promotion_kind = promotion_type_allow_list[op]
+            if (
+                op == torch.ops.aten.div.Tensor_mode
+                and kwargs.get("rounding_mode") is None
+            ):
+                promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
         else:  # Not in allow list, do nothing
             return super().call_operator(op, args, kwargs, meta)
 
diff --git a/exir/tests/test_passes.py b/exir/tests/test_passes.py
index e83a4cb7e50..a9dabad6234 100644
--- a/exir/tests/test_passes.py
+++ b/exir/tests/test_passes.py
@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import executorch.exir as exir
 
@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,39 +122,97 @@ def foo_out(
     return a + 1, None
 
 
+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
-        def count_nodes_with_target_asserting_arguments_have_dtype(
-            new_graph_module, target, arg_dtype
-        ):
-            count = 0
-            for node in new_graph_module.graph.nodes:
-                if node.op == "call_function" and node.target == target:
-                    count += 1
-                    for arg in node.args:
-                        self.assertEqual(arg.meta["val"].dtype, arg_dtype)
-            return count
-
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        class Minimum(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return torch.minimum(x, y)
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )
 
-        for module, op, expected_count in (
-            (Add, exir_ops.edge.aten.add.Tensor, 2),
-            (Mult, exir_ops.edge.aten.mul.Tensor, 1),
-            (Minimum, exir_ops.edge.aten.minimum.default, 1),
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
         ):
             for second_arg_dtype in (torch.int64, torch.float, torch.double):
                 int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
@@ -166,8 +225,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
                 new_graph_module = new_prog.exported_program().graph_module
                 self.assertIsNotNone(new_graph_module)
 
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
                 count = count_nodes_with_target_asserting_arguments_have_dtype(
-                    new_graph_module, op, second_arg_dtype
+                    self, new_graph_module, op, promoted_type
                 )
                 self.assertEqual(count, expected_count)
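For reference, the eager-mode dtype behavior that the updated promotion table encodes can be checked with plain `torch` — a minimal sketch, independent of ExecuTorch:

```python
import torch

a = torch.tensor([1, 2, 3], dtype=torch.int64)
b = torch.tensor([2, 2, 2], dtype=torch.int64)

# aten.div.Tensor (no rounding mode) is true division: integer inputs are
# promoted to floating point (INT_TO_FLOAT).
assert torch.div(a, b).dtype == torch.float32

# Passing rounding_mode=None behaves like passing no mode at all, which is
# why the pass special-cases div.Tensor_mode back to INT_TO_FLOAT.
assert torch.div(a, b, rounding_mode=None).dtype == torch.float32

# With "trunc" or "floor", the result keeps the promoted input dtype
# (DEFAULT), so int64 inputs stay int64.
assert torch.div(a, b, rounding_mode="trunc").dtype == torch.int64
assert torch.div(a, b, rounding_mode="floor").dtype == torch.int64
```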