97 changes: 28 additions & 69 deletions exir/tests/test_passes.py
@@ -127,85 +127,44 @@ def setUpClass(cls) -> None:
         register_additional_test_aten_ops()
 
     def test_remove_mixed_type_operators(self) -> None:
+        def count_nodes_with_target_asserting_arguments_have_dtype(
+            new_graph_module, target, arg_dtype
+        ):
+            count = 0
+            for node in new_graph_module.graph.nodes:
+                if node.op == "call_function" and node.target == target:
+                    count += 1
+                    for arg in node.args:
+                        self.assertEqual(arg.meta["val"].dtype, arg_dtype)
+            return count
+
         class Add(torch.nn.Module):
             def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 return (x + y) + x
 
-        add = Add()
-
-        int_tensor = torch.tensor([[1, 2, 3]])
-        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        edge_prog = to_edge(export(add, (int_tensor, float_tensor), strict=True))
-
-        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module = new_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module)
-
-        add_count = 0
-
-        for node in new_graph_module.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(add_count, 2)
-
-        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
-        double_tensor = double_tensor.to(torch.double)
-
-        double_prog = to_edge(export(add, (int_tensor, double_tensor), strict=True))
-
-        double_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_double = double_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_double)
-
-        add_count_double = 0
-
-        for node in new_graph_module_double.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.add.Tensor
-            ):
-                add_count_double += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.double)
-
-        self.assertEqual(add_count_double, 2)
-
         class Mult(torch.nn.Module):
             def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                 return x * y
 
-        mult = Mult()
-
-        float_tensor_vert = float_tensor.T
-        mult_prog = to_edge(export(mult, (int_tensor, float_tensor_vert), strict=True))
-
-        # graph_module_mult.graph.print_tabular()
-
-        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
-        new_graph_module_mult = mult_prog.exported_program().graph_module
-        self.assertIsNotNone(new_graph_module_mult)
+        for module, op, expected_count in (
+            (Add, exir_ops.edge.aten.add.Tensor, 2),
+            (Mult, exir_ops.edge.aten.mul.Tensor, 1),
+        ):
+            for second_arg_dtype in (torch.int64, torch.float, torch.double):
+                int_tensor = torch.tensor([[1, 2, 3]], dtype=torch.int64)
+                float_tensor = torch.tensor([[1.0, 2.0, 3.0]], dtype=second_arg_dtype)
+                edge_prog = to_edge(
+                    export(module(), (int_tensor, float_tensor), strict=True)
+                )
 
-        mult_count = 0
+                new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
+                new_graph_module = new_prog.exported_program().graph_module
+                self.assertIsNotNone(new_graph_module)
 
-        for node in new_graph_module_mult.graph.nodes:
-            if (
-                node.op == "call_function"
-                and node.target == exir_ops.edge.aten.mul.Tensor
-            ):
-                mult_count += 1
-                node_args = node.args
-                for arg in node_args:
-                    self.assertEqual(arg.meta["val"].dtype, torch.float)
-
-        self.assertEqual(mult_count, 1)
+                count = count_nodes_with_target_asserting_arguments_have_dtype(
+                    new_graph_module, op, second_arg_dtype
+                )
+                self.assertEqual(count, expected_count)
 
     def test_remove_noop_pass(self) -> None:
         class Foo(torch.nn.Module):
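
For context on the refactor above: the new count_nodes_with_target_asserting_arguments_have_dtype helper walks the FX graph, counts call_function nodes matching an operator target, and asserts each argument's traced dtype via node.meta["val"]. The sketch below reproduces that pattern with plain torch.export rather than EXIR's to_edge pipeline, so it runs with just PyTorch; the Add module mirrors the test, everything else is illustrative. It also shows why expected_count is 2 for Add: its forward, (x + y) + x, lowers to two aten.add.Tensor nodes.

import torch
from torch.export import export


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return (x + y) + x


# Both inputs are float32, so no type promotion is involved here.
ep = export(Add(), (torch.ones(3), torch.ones(3)), strict=True)

count = 0
for node in ep.graph.nodes:
    if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
        count += 1
        for arg in node.args:
            # meta["val"] holds the FakeTensor recorded during tracing,
            # which carries the dtype the test inspects.
            assert arg.meta["val"].dtype == torch.float32

assert count == 2  # (x + y) + x produces two aten.add.Tensor calls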
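The dtype assertion in the parametrized loop is worth spelling out: for each second_arg_dtype, the test expects every argument of the surviving add/mul nodes to carry exactly that dtype. That matches PyTorch's type-promotion rules, since torch.int64 promotes against each of the three dtypes in the loop to second_arg_dtype itself. A standalone check (nothing here is EXIR-specific):

import torch

for second_arg_dtype in (torch.int64, torch.float, torch.double):
    # int64 + int64 stays int64; int64 + float32/float64 promotes to the float.
    promoted = torch.promote_types(torch.int64, second_arg_dtype)
    assert promoted == second_arg_dtype

So after RemoveMixedTypeOperators rewrites a mixed-type add or mul, both arguments are expected to share the promoted dtype, which is why the helper takes a single arg_dtype.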