diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index 5b32c2fce5b..500dc527164 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -150,3 +150,20 @@ python_library( "//executorch/exir/passes:spec_prop_pass", ], ) + +python_unittest( + name = "test_graph_builder", + srcs = [ + "tests/test_graph_builder.py", + ], + typing = True, + deps = [ + "//caffe2:torch", + "//executorch/backends/cadence/aot:graph_builder", + "//executorch/backends/cadence/aot:pass_utils", + "//executorch/exir:pass_base", + "//executorch/exir/dialects:lib", + "//later:lib", + ":ops_registrations" + ], +) diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 12a2f622389..ed56a1b85fb 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -89,3 +89,12 @@ def get_node_names_list_from_gm( continue graph_nodes.append(node.name) return graph_nodes + + +def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int: + """Count the number of nodes with target `target` in the graph.""" + total = 0 + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == target: + total += 1 + return total diff --git a/backends/cadence/aot/tests/test_graph_builder.py b/backends/cadence/aot/tests/test_graph_builder.py new file mode 100644 index 00000000000..04097c17255 --- /dev/null +++ b/backends/cadence/aot/tests/test_graph_builder.py @@ -0,0 +1,70 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + + +import executorch.backends.cadence.aot.ops_registrations # noqa +import torch +from executorch.backends.cadence.aot.graph_builder import ( + GraphBuilder, + single_op_builder, +) +from executorch.backends.cadence.aot.pass_utils import count_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass +from later.unittest import TestCase + + +class TestGraphBuilder(TestCase): + def test_graph_with_single_im2row(self) -> None: + # Create a graph with a single im2row node. + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(1, 3, 224, 224)) + pad_value = builder.placeholder("pad", torch.randn(1)) + channels_last = False + im2row = builder.call_operator( + exir_ops.edge.cadence.im2row.default, + # pyre-ignore + ( + x, + (2, 2), + (1, 1), + (0, 0), + (1, 1), + pad_value, + channels_last, + ), + ) + builder.output([im2row]) + gm = builder.get_graph_module() + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + + # Check graph has a single im2row node. + self.assertEqual(len([gm.graph.nodes]), 1) + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + + +class TestSingleOpBuilderUtility(TestCase): + def test_graph_with_single_im2row(self) -> None: + # Create a graph with a single im2row node. + x = torch.randn(1, 3, 224, 224) + pad_value = torch.randn(1) + channels_last = False + gm = single_op_builder( + (x, pad_value), + exir_ops.edge.cadence.im2row.default, + ( + x, + (2, 2), + (1, 1), + (0, 0), + (1, 1), + pad_value, + channels_last, + ), + ) + # Check if graph module is valid by running exportpass on it. + gm = ExportPass().call(gm).graph_module + + # Check graph has a single im2row node. + self.assertEqual(len([gm.graph.nodes]), 1) + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)