
Commit ef9f31f

Andrew Grebenisan authored and facebook-github-bot committed
Utility function for numerical correctness of edge dialect graphs and reference implementations (pytorch#14036)
Summary: Created two utility functions:

1. Converts an edge dialect graph into one where custom cadence op nodes are replaced with python references.
2. Validates the outputs (and optionally intermediates) of the graphs.

Updated two tests in test_replace_ops_passes to utilize these utility functions.

Differential Revision: D81843001
1 parent 54d5b0f commit ef9f31f
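
A minimal usage sketch of the two utilities added in this commit (hedged: `gm` and `example_inputs` are hypothetical placeholders standing in for an edge-dialect GraphModule and its inputs, not part of the commit):

    from executorch.backends.cadence.aot.pass_utils import (
        construct_reference_graph_module,
        numerically_equivalent,
    )

    # 1. Swap cadence custom op nodes for their python reference implementations.
    ref_gm = construct_reference_graph_module(gm)  # gm: edge-dialect GraphModule

    # 2. Compare outputs (and optionally intermediates) of the two graphs.
    assert numerically_equivalent(
        gm, example_inputs, exact_match=False, validate_intermediates=True
    )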

File tree: 4 files changed (+204, −33 lines)

backends/cadence/aot/TARGETS
backends/cadence/aot/pass_utils.py
backends/cadence/aot/replace_ops.py
backends/cadence/aot/tests/test_replace_ops_passes.py


backends/cadence/aot/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ python_library(
     ],
     deps = [
         ":utils",
+        ":ops_registrations",
+        ":ref_implementations",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",

backends/cadence/aot/pass_utils.py

Lines changed: 155 additions & 1 deletion
@@ -5,17 +5,22 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-
 from dataclasses import dataclass
+from functools import partial
+from operator import attrgetter
 from typing import Callable, List, Optional, Set, Type, Union
 
+import executorch.backends.cadence.aot.ops_registrations  # noqa
+import executorch.backends.cadence.aot.ref_implementations  # noqa
+
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
 
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import PassBase, PassResult
 
 from torch._ops import OpOverloadPacket
+from torch.utils._pytree import PyTree
 
 
 # Is an overlap in tensor lifetime and storage allowed at the current opt level?
@@ -115,6 +120,155 @@ def op_counts_match(
     return True
 
 
+def construct_reference_graph_module(
+    graph_module: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    """
+    Given a graph module in edge dialect, construct a new graph module with the same
+    structure as the input graph module, but with all cadence custom op nodes
+    replaced with their corresponding reference implementations in torch.ops.cadence.<name>.
+    """
+    new_graph = torch.fx.Graph()
+    val_map = {}
+
+    def _get_cadence_op_with_overload(node: torch.fx.Node) -> Optional[str]:
+        """Get full cadence operation name with overload."""
+        if not (node.op == "call_function" and isinstance(node.target, EdgeOpOverload)):
+            return None
+
+        schema_name = node.target._schema.name
+        if not schema_name.startswith("cadence::"):
+            return None
+
+        base_op_name = schema_name.split("::", 1)[1]
+        prefix = f"cadence_{base_op_name}_"
+
+        return (
+            f"{base_op_name}.{node.name[len(prefix):]}"
+            if node.name.startswith(prefix)
+            else base_op_name
+        )
+
+    for node in graph_module.graph.nodes:
+        if node.op == "call_function" and isinstance(node.target, EdgeOpOverload):
+            # Schema name format: "namespace::operation_name"
+            op = _get_cadence_op_with_overload(node)
+            if op is None:  # Copy the nodes as-is
+                new_node = new_graph.node_copy(node, lambda n: val_map[n])
+                val_map[node] = new_node
+                continue
+
+            try:
+                ref_op = attrgetter(op)(torch.ops.cadence)
+            except AttributeError:
+                raise RuntimeError(
+                    f"Could not find reference implementation for {op} in {torch.ops.cadence}"
+                )
+            new_node = new_graph.create_node(
+                node.op,
+                ref_op,
+                args=tuple(
+                    val_map[arg] if isinstance(arg, torch.fx.Node) else arg
+                    for arg in node.args
+                ),
+                kwargs={
+                    k: val_map[v] if isinstance(v, torch.fx.Node) else v
+                    for k, v in node.kwargs.items()
+                },
+                name=node.name,
+            )
+            val_map[node] = new_node
+        else:
+            # Copy all other nodes as-is
+            new_node = new_graph.node_copy(node, lambda n: val_map[n])
+            val_map[node] = new_node
+
+    # Create a new GraphModule with the new graph and the same code as the original
+    return torch.fx.GraphModule(graph_module, new_graph)
+
+
+def numerically_equivalent(
+    graph_module: torch.fx.GraphModule,
+    example_inputs: tuple[torch.Tensor, ...],
+    exact_match: bool,
+    rtol: float = 1e-3,
+    atol: float = 1e-3,
+    validate_intermediates: bool = False,
+) -> Union[bool, tuple[bool, dict[str, torch.Tensor], dict[str, torch.Tensor]]]:
+    """
+    Constructs a new GraphModule from the input graph_module, replacing all cadence EdgeOpOverload
+    nodes with their corresponding reference implementations in
+    executorch.backends.cadence.aot.ref_implementations (i.e., torch.ops.cadence.<name>).
+    All aten nodes are left unchanged.
+
+    Args:
+        graph_module: The input graph module to be checked for numerical equivalence.
+        example_inputs: Example inputs to the graph module.
+        exact_match: If True, the outputs of the original and transformed graph modules must be exactly equal.
+        rtol: Relative tolerance for torch.allclose. Unused if exact_match is True.
+        atol: Absolute tolerance for torch.allclose. Unused if exact_match is True.
+        validate_intermediates: If True, also check that the intermediate values of the original and transformed
+            graph modules are numerically equivalent. If False, only check that the final outputs are equivalent.
+
+    Returns:
+        True if the original and transformed graph modules are numerically equivalent, False otherwise. Raises
+        an error if the cadence reference implementation does not exist.
+    """
+
+    # Create a new GraphModule with the new graph and the same code as the original
+    new_graph_module = construct_reference_graph_module(graph_module)
+
+    # Add forward hooks to capture all intermediates from both original and new GraphModules
+    orig_intermediates: list[PyTree] = []
+    ref_intermediates: list[PyTree] = []
+
+    def get_orig_intermediate(
+        module: torch.fx.GraphModule, input: PyTree, output: PyTree
+    ) -> None:
+        nonlocal orig_intermediates
+        orig_intermediates.append(output)
+
+    def get_new_intermediate(
+        module: torch.fx.GraphModule, input: PyTree, output: PyTree
+    ) -> None:
+        nonlocal ref_intermediates
+        ref_intermediates.append(output)
+
+    hooks = []
+    if validate_intermediates:
+        for module in graph_module.modules():
+            hooks.append(module.register_forward_hook(get_orig_intermediate))
+
+        for module in new_graph_module.modules():
+            # Don't bother saving hooks for the new graph module since we're
+            # throwing out the new graph after this function call
+            module.register_forward_hook(get_new_intermediate)
+
+    orig_outs = graph_module(*example_inputs)
+    new_outs = new_graph_module(*example_inputs)
+    for hook in hooks:
+        hook.remove()
+
+    if not validate_intermediates:
+        orig_intermediates = [orig_outs]
+        ref_intermediates = [new_outs]
+
+    assert (
+        len(orig_intermediates) == len(ref_intermediates)
+        and len(orig_intermediates) > 0
+    )
+    if exact_match:
+        comparison_func = torch.equal
+    else:
+        comparison_func = partial(torch.allclose, rtol=rtol, atol=atol, equal_nan=False)
+
+    close_tree = torch.utils._pytree.tree_map(
+        comparison_func, orig_intermediates, ref_intermediates
+    )
+    close_leaves, _ = torch.utils._pytree.tree_flatten(close_tree)
+    return all(close_leaves)
+
+
 # Testing utils
 # Return the compute/function nodes in the graph
 def get_compute_nodes_in_gm(graph_module: torch.fx.GraphModule) -> List[torch.fx.Node]:
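
For reference, the leaf-by-leaf check that `numerically_equivalent` performs via `tree_map` looks like this in isolation (a standalone sketch with toy tensors, not part of the commit):

    from functools import partial

    import torch
    from torch.utils import _pytree

    # Toy nested outputs standing in for the captured intermediates.
    a = {"out": torch.tensor([1.0, 2.0]), "aux": (torch.tensor(3.0),)}
    b = {"out": torch.tensor([1.0, 2.0001]), "aux": (torch.tensor(3.0),)}

    # exact_match=False selects torch.allclose with the given tolerances.
    cmp = partial(torch.allclose, rtol=1e-3, atol=1e-3, equal_nan=False)
    close_tree = _pytree.tree_map(cmp, a, b)  # same structure, bool leaves
    close_leaves, _ = _pytree.tree_flatten(close_tree)
    print(all(close_leaves))  # True: 2.0001 is within rtol/atol of 2.0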

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 3 deletions
@@ -979,18 +979,18 @@ def call_operator(
     ) -> ProxyValue:
         if op not in {
             exir_ops.edge.cadence.convolution.default,
-            exir_ops.edge.cadence.quantized_conv_nchw.default,
+            exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
         }:
             return super().call_operator(op, args, kwargs, meta)
 
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default
+        quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.per_tensor
 
         if not quantized_op and len(args) == 8 and args[-1] is True:
            # Already in NHWC layout.
            return super().call_operator(op, args, kwargs, meta)
 
        new_op = (
-            exir_ops.edge.cadence.quantized_conv_nhwc.default
+            exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor
            if quantized_op
            else exir_ops.edge.cadence.convolution.default
        )
backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 44 additions & 29 deletions
@@ -15,7 +15,11 @@
     GraphBuilder,
     single_op_builder,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
+from executorch.backends.cadence.aot.pass_utils import (
+    count_node,
+    numerically_equivalent,
+    op_counts_match,
+)
 from executorch.backends.cadence.aot.replace_ops import (
     MakeSliceAndCatDimOutermostPass,
     ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
@@ -1611,7 +1615,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
 
     def create_quantized_convolution_graph_module(
         self, channels_last: Optional[bool] = None
-    ) -> torch.fx.GraphModule:
+    ) -> tuple[torch.fx.GraphModule, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """Helper to create a quantized conv node.
 
         quantized_conv(
@@ -1621,23 +1625,32 @@ def create_quantized_convolution_graph_module(
             Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
         """
         if channels_last:
-            x = torch.randn(1, 224, 56, 3)
-            w = torch.randn(16, 16, 16, 3)
+            x = torch.randint(
+                low=-128, high=127, size=(1, 224, 56, 3), dtype=torch.int8
+            )
+            w = torch.randint(
+                low=-128, high=127, size=(16, 16, 16, 3), dtype=torch.int8
+            )
         else:
-            x = torch.randn(1, 3, 224, 56)
-            w = torch.randn(16, 3, 16, 16)
-        b = torch.randn(16)
+            x = torch.randint(
+                low=-128, high=127, size=(1, 3, 224, 56), dtype=torch.int8
+            )
+            w = torch.randint(
+                low=-128, high=127, size=(16, 3, 16, 16), dtype=torch.int8
+            )
+
+        b = torch.randint(low=-128, high=127, size=(16,), dtype=torch.int32)
         stride = (2, 2)
         padding = (0, 0)
         dilation = (1, 1)
         groups = 1
         input_zero_point = 0
-        w_zero_point = torch.randn(1)
-        b_scale = torch.randn(1)
+        w_zero_point = 1
+        b_scale = 0.8
         out_scale = 1
         out_zero_point = 0
-        out_multiplier = torch.randn(1)
-        out_shift = torch.randn(1)
+        out_multiplier = 0
+        out_shift = 0
         args = (
             x,
             w,
@@ -1660,44 +1673,39 @@ def create_quantized_convolution_graph_module(
                     x,
                     w,
                     b,
-                    w_zero_point,
-                    b_scale,
-                    out_multiplier,
-                    out_shift,
                 ),
-                op=exir_ops.edge.cadence.quantized_conv_nhwc.default,
+                op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 args=args,
-            )
+            ), (x, w, b)
         else:
             return single_op_builder(
                 placeholders=(
                     x,
                     w,
                     b,
-                    w_zero_point,
-                    b_scale,
-                    out_multiplier,
-                    out_shift,
                 ),
-                op=exir_ops.edge.cadence.quantized_conv_nchw.default,
+                op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 args=args,
-            )
+            ), (x, w, b)
 
     def test_quantized_convolution_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
-        gm = self.create_quantized_convolution_graph_module()
+        gm, (x, w, b) = self.create_quantized_convolution_graph_module()
         self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.default), 1
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), 1
         )
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
 
+        self.assertTrue(numerically_equivalent(gm, (x, w, b), True))
+
         # Apply replacement pass.
         p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
             count_node(
-                gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
             ),
             1,
         )
@@ -1707,26 +1715,33 @@ def test_quantized_convolution_default_channel_last(self) -> None:
             3,
         )
 
+        self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True))
+
     def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         # Create a graph with a single quantized conv node.
-        gm = self.create_quantized_convolution_graph_module(channels_last=True)
+        gm, (x, w, b) = self.create_quantized_convolution_graph_module(
+            channels_last=True
+        )
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.default), 1
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), 1
         )
+        self.assertTrue(numerically_equivalent(gm, (x, w, b), True))
 
         # Apply replacement pass.
         p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
             count_node(
-                gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
             ),
             1,
         )
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+        self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True))
 
 
 class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
