Skip to content

Commit 5eddd77

Browse files
committed
fix type promotion for div in RemoveMixedTypeOperators
The promotion strategy is dependent on the rounding mode (see the div decomp in PyTorch https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L1214 and then the promotion annotation on each of the true_divide/trunc_divide/floor_divide functions itcalls). I had to restructure the test a bit more so that lint didn't complain it was too complex. ghstack-source-id: b75ace2 ghstack-comment-id: 3026116637 Pull-Request-resolved: #12157
1 parent f6bb143 commit 5eddd77

File tree

2 files changed

+97
-29
lines changed

2 files changed

+97
-29
lines changed

exir/passes/remove_mixed_type_operators.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ def call_operator(self, op, args, kwargs, meta: NodeMetadata): # noqa: C901
2323
promotion_type_allow_list = {
2424
torch.ops.aten.add.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2525
torch.ops.aten.mul.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
26-
torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
26+
# The correct promotion for div depends on the mode! If there is no mode,
27+
# it's INT_TO_FLOAT, otherwise it's default.
28+
torch.ops.aten.div.Tensor: ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
29+
torch.ops.aten.div.Tensor_mode: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2730
torch.ops.aten.minimum.default: ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
2831
}
2932

3033
if op in promotion_type_allow_list:
3134
promotion_kind = promotion_type_allow_list[op]
35+
if (
36+
op == torch.ops.aten.div.Tensor_mode
37+
and kwargs.get("rounding_mode") is None
38+
):
39+
promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
3240
else:
3341
# Not in allow list, do nothing
3442
return super().call_operator(op, args, kwargs, meta)

exir/tests/test_passes.py

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
import tempfile
1111
import unittest
12-
from typing import List, Optional, Tuple
12+
from typing import Callable, List, Optional, Tuple
1313

1414
import executorch.exir as exir
1515

@@ -71,6 +71,7 @@
7171
from functorch.experimental import control_flow
7272

7373
from torch import nn
74+
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
7475
from torch.export import export
7576
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
7677
from torch.fx import GraphModule, subgraph_rewriter
@@ -121,39 +122,97 @@ def foo_out(
121122
return a + 1, None
122123

123124

125+
def simple_promote_dtype(
126+
dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
127+
) -> torch.dtype:
128+
if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
129+
return dtype
130+
if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
131+
return dtype if dtype.is_floating_point else torch.float
132+
else:
133+
raise Exception(f"Unsupported promotion kind {promotion_kind}")
134+
135+
136+
def count_nodes_with_target_asserting_arguments_have_dtype(
137+
self, module, target, arg_dtype
138+
) -> int:
139+
count = 0
140+
for node in module.graph.nodes:
141+
if node.op == "call_function" and node.target == target:
142+
count += 1
143+
for arg in node.args:
144+
self.assertEqual(arg.meta["val"].dtype, arg_dtype)
145+
return count
146+
147+
124148
class TestPasses(unittest.TestCase):
125149
@classmethod
126150
def setUpClass(cls) -> None:
127151
register_additional_test_aten_ops()
128152

129153
def test_remove_mixed_type_operators(self) -> None:
130-
def count_nodes_with_target_asserting_arguments_have_dtype(
131-
new_graph_module, target, arg_dtype
132-
):
133-
count = 0
134-
for node in new_graph_module.graph.nodes:
135-
if node.op == "call_function" and node.target == target:
136-
count += 1
137-
for arg in node.args:
138-
self.assertEqual(arg.meta["val"].dtype, arg_dtype)
139-
return count
140-
141-
class Add(torch.nn.Module):
142-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
143-
return (x + y) + x
144-
145-
class Mult(torch.nn.Module):
146-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
147-
return x * y
148-
149-
class Minimum(torch.nn.Module):
150-
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
151-
return torch.minimum(x, y)
154+
def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
155+
class Module(torch.nn.Module):
156+
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
157+
return fwd(x, y)
158+
159+
return Module
160+
161+
Add = make_module(lambda x, y: (x + y) + x)
162+
Mult = make_module(lambda x, y: x * y)
163+
Minimum = make_module(torch.minimum)
164+
DivWithoutMode = make_module(torch.div)
165+
DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
166+
DivWithTruncMode = make_module(
167+
lambda x, y: torch.div(x, y, rounding_mode="trunc")
168+
)
169+
DivWithFloorMode = make_module(
170+
lambda x, y: torch.div(x, y, rounding_mode="floor")
171+
)
152172

153-
for module, op, expected_count in (
154-
(Add, exir_ops.edge.aten.add.Tensor, 2),
155-
(Mult, exir_ops.edge.aten.mul.Tensor, 1),
156-
(Minimum, exir_ops.edge.aten.minimum.default, 1),
173+
for module, op, expected_count, promotion_kind in (
174+
(
175+
Add,
176+
exir_ops.edge.aten.add.Tensor,
177+
2,
178+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
179+
),
180+
(
181+
Mult,
182+
exir_ops.edge.aten.mul.Tensor,
183+
1,
184+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
185+
),
186+
(
187+
Minimum,
188+
exir_ops.edge.aten.minimum.default,
189+
1,
190+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
191+
),
192+
(
193+
DivWithoutMode,
194+
exir_ops.edge.aten.div.Tensor,
195+
1,
196+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
197+
),
198+
(
199+
DivWithNoneMode,
200+
exir_ops.edge.aten.div.Tensor_mode,
201+
1,
202+
ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
203+
),
204+
(
205+
DivWithTruncMode,
206+
exir_ops.edge.aten.div.Tensor_mode,
207+
1,
208+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
209+
),
210+
(
211+
DivWithFloorMode,
212+
exir_ops.edge.aten.div.Tensor_mode,
213+
1,
214+
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
215+
),
157216
):
158217
for second_arg_dtype in (torch.int64, torch.float, torch.double):
159218
int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
@@ -166,8 +225,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
166225
new_graph_module = new_prog.exported_program().graph_module
167226
self.assertIsNotNone(new_graph_module)
168227

228+
promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
169229
count = count_nodes_with_target_asserting_arguments_have_dtype(
170-
new_graph_module, op, second_arg_dtype
230+
self, new_graph_module, op, promoted_type
171231
)
172232
self.assertEqual(count, expected_count)
173233

0 commit comments

Comments
 (0)