Merged
10 changes: 9 additions & 1 deletion exir/passes/remove_mixed_type_operators.py
@@ -23,12 +23,20 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata): # noqa: C901
         promotion_type_allow_list = {
             torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
-            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            # The correct promotion for div depends on the mode! If there is no mode,
+            # it's INT_TO_FLOAT, otherwise it's default.
+            torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            torch.ops.aten.div.Tensor_mode: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
             torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
         }
 
         if op in promotion_type_allow_list:
             promotion_kind = promotion_type_allow_list[op]
+            if (
+                op == torch.ops.aten.div.Tensor_mode
+                and kwargs.get("rounding_mode") is None
+            ):
+                promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
Comment on lines +35 to +39

Contributor: Are you saying div.Tensor_mode without rounding_mode specified is equivalent to div.Tensor?
         else:
             # Not in allow list, do nothing
             return super().call_operator(op, args, kwargs, meta)
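
For context on the promotion split above (and on the review question): per the torch.div documentation, rounding_mode=None performs true division and promotes integer inputs to float, exactly like the mode-less overload, while "trunc" and "floor" keep integer inputs integral. A minimal eager-mode sketch, assuming only a stock PyTorch install:

```python
import torch

a = torch.tensor([1, 2, 3])  # int64
b = torch.tensor([2, 2, 2])  # int64

# No mode: true division, so int inputs are promoted to float (INT_TO_FLOAT).
print(torch.div(a, b).dtype)  # torch.float32

# rounding_mode=None is documented as equivalent to passing no mode at all.
print(torch.div(a, b, rounding_mode=None).dtype)  # torch.float32

# "trunc"/"floor" round the quotient: default promotion, dtype stays int64.
print(torch.div(a, b, rounding_mode="trunc").dtype)  # torch.int64
print(torch.div(a, b, rounding_mode="floor").dtype)  # torch.int64
```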
178 changes: 101 additions & 77 deletions exir/tests/test_passes.py
@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import executorch.exir as exir
 
@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow
 
 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,91 +122,114 @@ def foo_out(
     return a + 1, None
 
 
+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        add = Add()
-
-        int_tensor = torch.tensor([[1, 2, 3]])
-        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True))
-
-        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module = new_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module)
-
-        add_count = 0
-
-        for node in new_graph_module.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(add_count, 2)
-
-        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        double_tensor = double_tensor.to(torch.double)
-
-        double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True))
-
-        double_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_double = double_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_double)
-
-        add_count_double = 0
-
-        for node in new_graph_module_double.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count_double += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.double)
-
-        self.assertEqual(add_count_double, 2)
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        mult = Mult()
-
-        float_tensor_vert = float_tensor.T
-        mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True))
-
-        # graph_module_mult.graph.print_tabular()
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )
 
-        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_mult = mult_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_mult)
+        ETPK = ELEMENTWISE_TYPE_PROMOTION_KIND
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+        ):
+            for second_arg_dtype in (torch.int64, torch.float, torch.double):
+                int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
+                float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
+                edge_prog = to_edge(
+                    export(module(), (int_tensor, float_tensor), strict=True)
+                )

-        mult_count = 0
+                new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
+                new_graph_module = new_prog.exported_program().graph_module
+                self.assertIsNotNone(new_graph_module)
 
-        for node in new_graph_module_mult.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.mul.Tensor
-            ):
-                mult_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(mult_count, 1)
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
+                count = count_nodes_with_target_asserting_arguments_have_dtype(
+                    self, new_graph_module, op, promoted_type
+                )
+                self.assertEqual(count, expected_count)

     def test_remove_noop_pass(self) -> None:
         class Foo(torch.nn.Module):
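
As a closing sanity check (not part of the PR), the new simple_promote_dtype helper is meant to track eager PyTorch's result dtypes for the cases the loop above exercises. A hedged spot-check sketch; the helper is restated so the snippet runs standalone:

```python
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND


def simple_promote_dtype(dtype, promotion_kind):
    # Restated from the diff above so this sketch is self-contained.
    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
        return dtype
    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
        return dtype if dtype.is_floating_point else torch.float
    raise Exception(f"Unsupported promotion kind {promotion_kind}")


int_t = torch.tensor([[1, 2, 3]], dtype=torch.int64)

# (second arg dtype, promotion kind, eager result dtype) triples mirroring the test loop.
checks = (
    # div.Tensor on two int64 tensors: INT_TO_FLOAT, eager result is float32.
    (torch.int64, ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
     torch.div(int_t, int_t).dtype),
    # div.Tensor_mode with "trunc": DEFAULT, eager result stays int64.
    (torch.int64, ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     torch.div(int_t, int_t, rounding_mode="trunc").dtype),
    # double second arg: a floating dtype passes through INT_TO_FLOAT unchanged.
    (torch.double, ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
     torch.div(int_t, int_t.to(torch.double)).dtype),
)

for second_arg_dtype, kind, eager_dtype in checks:
    assert simple_promote_dtype(second_arg_dtype, kind) == eager_dtype
```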