Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,20 @@ python_library(
"//executorch/exir/passes:spec_prop_pass",
],
)

python_unittest(
name = "test_graph_builder",
srcs = [
"tests/test_graph_builder.py",
],
typing = True,
deps = [
"//caffe2:torch",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/exir:pass_base",
"//executorch/exir/dialects:lib",
"//later:lib",
":ops_registrations"
],
)
9 changes: 9 additions & 0 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,12 @@ def get_node_names_list_from_gm(
continue
graph_nodes.append(node.name)
return graph_nodes


def count_node(graph_module: torch.fx.GraphModule, target: torch.fx.node.Target) -> int:
"""Count the number of nodes with target `target` in the graph."""
total = 0
for node in graph_module.graph.nodes:
if node.op == "call_function" and node.target == target:
total += 1
return total
70 changes: 70 additions & 0 deletions backends/cadence/aot/tests/test_graph_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import executorch.backends.cadence.aot.ops_registrations # noqa
import torch
from executorch.backends.cadence.aot.graph_builder import (
GraphBuilder,
single_op_builder,
)
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from later.unittest import TestCase


class TestGraphBuilder(TestCase):
def test_graph_with_single_im2row(self) -> None:
# Create a graph with a single im2row node.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 3, 224, 224))
pad_value = builder.placeholder("pad", torch.randn(1))
channels_last = False
im2row = builder.call_operator(
exir_ops.edge.cadence.im2row.default,
# pyre-ignore
(
x,
(2, 2),
(1, 1),
(0, 0),
(1, 1),
pad_value,
channels_last,
),
)
builder.output([im2row])
gm = builder.get_graph_module()
# Check if graph module is valid by running exportpass on it.
gm = ExportPass().call(gm).graph_module

# Check graph has a single im2row node.
self.assertEqual(len([gm.graph.nodes]), 1)
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)


class TestSingleOpBuilderUtility(TestCase):
def test_graph_with_single_im2row(self) -> None:
# Create a graph with a single im2row node.
x = torch.randn(1, 3, 224, 224)
pad_value = torch.randn(1)
channels_last = False
gm = single_op_builder(
(x, pad_value),
exir_ops.edge.cadence.im2row.default,
(
x,
(2, 2),
(1, 1),
(0, 0),
(1, 1),
pad_value,
channels_last,
),
)
# Check if graph module is valid by running exportpass on it.
gm = ExportPass().call(gm).graph_module

# Check graph has a single im2row node.
self.assertEqual(len([gm.graph.nodes]), 1)
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
Loading