@@ -9,7 +9,7 @@
 import os
 import tempfile
 import unittest
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple

 import executorch.exir as exir

@@ -71,6 +71,7 @@
 from functorch.experimental import control_flow

 from torch import nn
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
 from torch.export import export
 from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
 from torch.fx import GraphModule, subgraph_rewriter
@@ -121,91 +122,114 @@ def foo_out(
     return a + 1, None


+def simple_promote_dtype(
+    dtype: torch.dtype, promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
+) -> torch.dtype:
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
+        return dtype
+    if promotion_kind == ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
+        return dtype if dtype.is_floating_point else torch.float
+    else:
+        raise Exception(f"Unsupported promotion kind {promotion_kind}")
+
+
+def count_nodes_with_target_asserting_arguments_have_dtype(
+    self, module, target, arg_dtype
+) -> int:
+    count = 0
+    for node in module.graph.nodes:
+        if node.op == "call_function" and node.target == target:
+            count += 1
+            for arg in node.args:
+                self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+    return count
+
+
 class TestPasses(unittest.TestCase):
     @classmethod
     def setUpClass(cls) -> None:
         register_additional_test_aten_ops()

     def test_remove_mixed_type_operators(self) -> None:
-        class Add(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return (x + y) + x
-
-        add = Add()
-
-        int_tensor = torch.tensor([[1, 2, 3]])
-        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True))
-
-        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module = new_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module)
-
-        add_count = 0
-
-        for node in new_graph_module.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(add_count, 2)
-
-        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        double_tensor = double_tensor.to(torch.double)
-
-        double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True))
-
-        double_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_double = double_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_double)
-
-        add_count_double = 0
-
-        for node in new_graph_module_double.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count_double += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.double)
-
-        self.assertEqual(add_count_double, 2)
-
-        class Mult(torch.nn.Module):
-            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
-                return x * y
-
-        mult = Mult()
-
-        float_tensor_vert = float_tensor.T
-        mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True))
-
-        # graph_module_mult.graph.print_tabular()
+        def make_module(fwd: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
+            class Module(torch.nn.Module):
+                def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+                    return fwd(x, y)
+
+            return Module
+
+        Add = make_module(lambda x, y: (x + y) + x)
+        Mult = make_module(lambda x, y: x * y)
+        Minimum = make_module(torch.minimum)
+        DivWithoutMode = make_module(torch.div)
+        DivWithNoneMode = make_module(lambda x, y: torch.div(x, y, rounding_mode=None))
+        DivWithTruncMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="trunc")
+        )
+        DivWithFloorMode = make_module(
+            lambda x, y: torch.div(x, y, rounding_mode="floor")
+        )

-        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_mult = mult_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_mult)
+        for module, op, expected_count, promotion_kind in (
+            (
+                Add,
+                exir_ops.edge.aten.add.Tensor,
+                2,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Mult,
+                exir_ops.edge.aten.mul.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                Minimum,
+                exir_ops.edge.aten.minimum.default,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithoutMode,
+                exir_ops.edge.aten.div.Tensor,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithNoneMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+            ),
+            (
+                DivWithTruncMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+            (
+                DivWithFloorMode,
+                exir_ops.edge.aten.div.Tensor_mode,
+                1,
+                ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+            ),
+        ):
+            for second_arg_dtype in (torch.int64, torch.float, torch.double):
+                int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
+                float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
+                edge_prog = to_edge(
+                    export(module(), (int_tensor, float_tensor), strict=True)
+                )

-        mult_count = 0
+                new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
+                new_graph_module = new_prog.exported_program().graph_module
+                self.assertIsNotNone(new_graph_module)

-        for node in new_graph_module_mult.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.mul.Tensor
-            ):
-                mult_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(mult_count, 1)
+                promoted_type = simple_promote_dtype(second_arg_dtype, promotion_kind)
+                count = count_nodes_with_target_asserting_arguments_have_dtype(
+                    self, new_graph_module, op, promoted_type
+                )
+                self.assertEqual(count, expected_count)

     def test_remove_noop_pass(self) -> None:
         class Foo(torch.nn.Module):
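For reference, the dtype expectations asserted by the rewritten test mirror ordinary PyTorch elementwise type promotion; simple_promote_dtype is a deliberately simplified stand-in for it. Below is a minimal eager-mode sketch of that behavior using only stock torch calls (the variable names are illustrative, not part of the change):

import torch

int_t = torch.tensor([[1, 2, 3]], dtype=torch.int64)
float_t = torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32)

# DEFAULT promotion: mixed int64/float32 operands promote to float32.
assert (int_t + float_t).dtype == torch.float32
assert torch.minimum(int_t, float_t).dtype == torch.float32

# INT_TO_FLOAT promotion: true division of all-integer operands yields float32.
assert torch.div(int_t, int_t).dtype == torch.float32

# With an explicit rounding mode, div follows DEFAULT promotion,
# so all-integer operands stay int64.
assert torch.div(int_t, int_t, rounding_mode="trunc").dtype == torch.int64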