
Commit 2744287

Andrew Grebenisan authored and facebook-github-bot committed
Utility function for numerical correctness of edge dialect graphs and reference implementations (#14036)
Summary: Pull Request resolved: #14036

Created two utility functions:
1. Converts an edge dialect graph into one where custom Cadence op nodes are replaced with Python references.
2. Validates the outputs (and optionally intermediates) of the graphs.

Updated two tests in test_replace_ops_passes to utilize these utility functions.

Differential Revision: D81843001
1 parent bd63826 commit 2744287
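
As a rough usage sketch of the new validate_pass decorator added in pass_utils.py below (the pass class name here is hypothetical; the replace_ops.py change applies the same decorator to ReplaceConvWithChannelLastConvPass):

from executorch.backends.cadence.aot.pass_utils import validate_pass
from executorch.exir.pass_base import ExportPass


@validate_pass()  # runs the graph before and after the pass and compares outputs
class MyReplacementPass(ExportPass):  # hypothetical pass, used only for illustration
    ...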

File tree

4 files changed: +180 −35 lines


backends/cadence/aot/TARGETS

Lines changed: 2 additions & 0 deletions
@@ -82,6 +82,8 @@ python_library(
     ],
     deps = [
         ":utils",
+        ":ops_registrations",
+        ":ref_implementations",
         "//caffe2:torch",
         "//executorch/exir:pass_base",
         "//executorch/exir/dialects:lib",

backends/cadence/aot/pass_utils.py

Lines changed: 129 additions & 2 deletions
@@ -5,9 +5,15 @@
 # LICENSE file in the root directory of this source tree.
 
 # pyre-strict
-
 from dataclasses import dataclass
-from typing import Callable, List, Optional, Set, Type, Union
+from functools import partial
+from operator import attrgetter
+from torch.utils._python_dispatch import _disable_current_modes
+
+from typing import Any, Callable, cast, List, Optional, Set, Type, Union
+
+import executorch.backends.cadence.aot.ops_registrations # noqa
+import executorch.backends.cadence.aot.ref_implementations # noqa
 
 import torch
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -16,6 +22,8 @@
 from executorch.exir.pass_base import PassBase, PassResult
 
 from torch._ops import OpOverloadPacket
+from torch.fx import GraphModule
+from torch.utils._pytree import PyTree
 
 
 # Is an overlap in tensor lifetime and storage allowed at the current opt level?
@@ -114,6 +122,125 @@ def op_counts_match(
             return False
     return True
 
+def validate_pass(
+
+) -> Callable[[type[PassBase]], type[PassBase]]:
+    tolerance = 1e-5
+    log_differences = False
+    fail_on_mismatch = True
+
+    def decorator(pass_class: type[PassBase]) -> type[PassBase]:
+        class WrappedPass(pass_class):
+            def call(self, graph_module: GraphModule) -> PassResult:
+                # Ensure we're not in fake tensor mode for actual execution
+                with _disable_current_modes():
+                    # Get inputs for the graph module
+                    original_inputs = self._get_concrete_inputs(graph_module)
+
+                    if original_inputs is None:
+                        raise RuntimeError(f"Could not extract concrete inputs for {pass_class.__name__}")
+
+                    # Run original graph and collect outputs
+                    with torch.no_grad():
+                        original_outputs = graph_module(*original_inputs)
+
+                    # Apply the transformation
+                    result = super().call(graph_module)
+
+                    # Run transformed graph and collect outputs
+                    with torch.no_grad():
+                        transformed_outputs = result.graph_module(*original_inputs)
+
+                    # Compare outputs
+                    self._compare_outputs(
+                        original_outputs,
+                        transformed_outputs,
+                        pass_class.__name__,
+                        tolerance,
+                        log_differences,
+                        fail_on_mismatch,
+                    )
+
+                    return result
+
+            def _get_concrete_inputs(self, graph_module: GraphModule) -> Optional[List[torch.Tensor]]:
+                """Extract concrete tensor inputs from the graph module metadata."""
+                inputs = []
+                for node in graph_module.graph.nodes:
+                    if node.op == "placeholder":
+                        if "val" in node.meta:
+                            val = node.meta["val"]
+                            if hasattr(val, "constant") and val.constant is not None:
+                                inputs.append(val.constant.detach().clone())
+                            elif isinstance(val, torch.Tensor):
+                                # Create a concrete tensor with the same properties
+                                concrete_tensor = torch.testing.make_tensor(val.shape, dtype=val.dtype, device="cpu")
+                                # concrete_tensor = torch.randn(val.shape, dtype=val.dtype)
+                                if hasattr(val, "device"):
+                                    concrete_tensor = concrete_tensor.to(val.device)
+                                inputs.append(concrete_tensor)
+                            else:
+                                raise ValueError(f"Unsupported type for {node.name}: {type(val)}")
+                        else:
+                            raise ValueError(f"Missing 'val' in node metadata for {node.name}")
+                return inputs
+
+            def _compare_outputs(
+                self,
+                original: Any,
+                transformed: Any,
+                pass_name: str,
+                tolerance: float,
+                log_differences: bool,
+                fail_on_mismatch: bool,
+            ) -> None:
+                """Compare outputs and optionally log/fail on differences."""
+                if isinstance(original, torch.Tensor) and isinstance(transformed, torch.Tensor):
+                    if not torch.allclose(original, transformed, atol=tolerance, rtol=tolerance):
+                        max_diff = torch.max(torch.abs(original - transformed)).item()
+                        message = f"{pass_name}: Output mismatch detected. Max difference: {max_diff}"
+
+                        if log_differences:
+                            pass
+                            # logging.warning(message)
+                            # logging.warning(f"Original shape: {original.shape}, Transformed shape: {transformed.shape}")
+
+                        if fail_on_mismatch:
+                            raise ValueError(message)
+                    else:
+                        if log_differences:
+                            pass
+                            # logging.info(f"{pass_name}: Outputs match within tolerance {tolerance}")
+
+                elif isinstance(original, (list, tuple)) and isinstance(transformed, (list, tuple)):
+                    if len(original) != len(transformed):
+                        message = f"{pass_name}: Output count mismatch. Original: {len(original)}, Transformed: {len(transformed)}"
+                        if log_differences:
+                            # logging.warning(message)
+                            pass
+                        if fail_on_mismatch:
+                            raise ValueError(message)
+                    else:
+                        for i, (orig_item, trans_item) in enumerate(zip(original, transformed)):
+                            self._compare_outputs(
+                                orig_item, trans_item, f"{pass_name}[{i}]",
+                                tolerance, log_differences, fail_on_mismatch
+                            )
+                else:
+                    if log_differences:
+                        pass
+                        # logging.info(f"{pass_name}: Non-tensor outputs, skipping numerical comparison")
+
+        # Preserve the original class name and documentation
+        WrappedPass.__name__ = pass_class.__name__
+        WrappedPass.__qualname__ = pass_class.__qualname__
+        WrappedPass.__doc__ = pass_class.__doc__
+
+        return cast(type[PassBase], WrappedPass)  # type: ignore[return-value]
+
+    return decorator
+
+
 
 # Testing utils
 # Return the compute/function nodes in the graph
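
For context, a minimal standalone sketch of the flow the wrapper above implements: build concrete tensors from placeholder shape/dtype metadata, run the module before and after the transform, and compare with torch.allclose at the same tolerance. The relu/clamp functions below are hypothetical stand-ins for the original and transformed graph modules:

import torch

def original_graph(x: torch.Tensor) -> torch.Tensor:  # stand-in for graph_module
    return torch.nn.functional.relu(x)

def transformed_graph(x: torch.Tensor) -> torch.Tensor:  # stand-in for result.graph_module
    return torch.clamp(x, min=0.0)

# Mirrors _get_concrete_inputs: synthesize a real tensor matching placeholder metadata.
x = torch.testing.make_tensor((1, 3, 8, 8), dtype=torch.float32, device="cpu")

with torch.no_grad():
    before = original_graph(x)
    after = transformed_graph(x)

# Mirrors the single-tensor branch of _compare_outputs.
assert torch.allclose(before, after, atol=1e-5, rtol=1e-5)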

backends/cadence/aot/replace_ops.py

Lines changed: 5 additions & 4 deletions
@@ -34,6 +34,7 @@
     CadencePassAttribute,
     none_throws,
     register_cadence_pass,
+    validate_pass,
 )
 from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
 from executorch.backends.cadence.aot.utils import get_edge_overload_packet
@@ -947,7 +948,7 @@ def transpose_dims(
             exir_ops.edge.aten.transpose_copy.int, (proxy, dim0, dim1), {}, meta
         )
 
-
+@validate_pass()
 @register_cadence_pass(CadencePassAttribute(opt_level=3))
 class ReplaceConvWithChannelLastConvPass(ExportPassWithTransposeHelper):
     def change_nchw_to_nhwc(self, proxy: ProxyValue, meta: NodeMetadata) -> ProxyValue:
@@ -979,18 +980,18 @@ def call_operator(
     ) -> ProxyValue:
         if op not in {
             exir_ops.edge.cadence.convolution.default,
-            exir_ops.edge.cadence.quantized_conv_nchw.default,
+            exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
         }:
             return super().call_operator(op, args, kwargs, meta)
 
-        quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.default
+        quantized_op = op == exir_ops.edge.cadence.quantized_conv_nchw.per_tensor
 
         if not quantized_op and len(args) == 8 and args[-1] is True:
             # Already in NHWC layout.
             return super().call_operator(op, args, kwargs, meta)
 
         new_op = (
-            exir_ops.edge.cadence.quantized_conv_nhwc.default
+            exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor
             if quantized_op
             else exir_ops.edge.cadence.convolution.default
         )

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 44 additions & 29 deletions
@@ -15,7 +15,11 @@
     GraphBuilder,
     single_op_builder,
 )
-from executorch.backends.cadence.aot.pass_utils import count_node, op_counts_match
+from executorch.backends.cadence.aot.pass_utils import (
+    count_node,
+    op_counts_match,
+    validate_pass,
+)
 from executorch.backends.cadence.aot.replace_ops import (
     MakeSliceAndCatDimOutermostPass,
     ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
@@ -1612,7 +1616,7 @@ def test_no_transpose_if_already_channel_last(self) -> None:
 
     def create_quantized_convolution_graph_module(
         self, channels_last: Optional[bool] = None
-    ) -> torch.fx.GraphModule:
+    ) -> tuple[torch.fx.GraphModule, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
         """Helper to create a quantized conv node.
 
         quantized_conv(
@@ -1622,23 +1626,32 @@ def create_quantized_convolution_graph_module(
             Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
         """
         if channels_last:
-            x = torch.randn(1, 224, 56, 3)
-            w = torch.randn(16, 16, 16, 3)
+            x = torch.randint(
+                low=-128, high=127, size=(1, 224, 56, 3), dtype=torch.int8
+            )
+            w = torch.randint(
+                low=-128, high=127, size=(16, 16, 16, 3), dtype=torch.int8
+            )
         else:
-            x = torch.randn(1, 3, 224, 56)
-            w = torch.randn(16, 3, 16, 16)
-            b = torch.randn(16)
+            x = torch.randint(
+                low=-128, high=127, size=(1, 3, 224, 56), dtype=torch.int8
+            )
+            w = torch.randint(
+                low=-128, high=127, size=(16, 3, 16, 16), dtype=torch.int8
+            )
+
+        b = torch.randint(low=-128, high=127, size=(16,), dtype=torch.int32)
         stride = (2, 2)
         padding = (0, 0)
         dilation = (1, 1)
         groups = 1
         input_zero_point = 0
-        w_zero_point = torch.randn(1)
-        b_scale = torch.randn(1)
+        w_zero_point = 1
+        b_scale = 0.8
         out_scale = 1
         out_zero_point = 0
-        out_multiplier = torch.randn(1)
-        out_shift = torch.randn(1)
+        out_multiplier = 0
+        out_shift = 0
         args = (
             x,
             w,
@@ -1661,44 +1674,39 @@ def create_quantized_convolution_graph_module(
                     x,
                     w,
                     b,
-                    w_zero_point,
-                    b_scale,
-                    out_multiplier,
-                    out_shift,
                 ),
-                op=exir_ops.edge.cadence.quantized_conv_nhwc.default,
+                op=exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
                 args=args,
-            )
+            ), (x, w, b)
         else:
             return single_op_builder(
                 placeholders=(
                     x,
                     w,
                     b,
-                    w_zero_point,
-                    b_scale,
-                    out_multiplier,
-                    out_shift,
                 ),
-                op=exir_ops.edge.cadence.quantized_conv_nchw.default,
+                op=exir_ops.edge.cadence.quantized_conv_nchw.per_tensor,
                 args=args,
-            )
+            ), (x, w, b)
 
     def test_quantized_convolution_default_channel_last(self) -> None:
         # Create a graph with a single convolution node.
-        gm = self.create_quantized_convolution_graph_module()
+        gm, (x, w, b) = self.create_quantized_convolution_graph_module()
         self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.default), 1
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nchw.per_tensor), 1
         )
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
 
+        # self.assertTrue(numerically_equivalent(gm, (x, w, b), True))
+
         # Apply replacement pass.
         p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
             count_node(
-                gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
             ),
             1,
         )
@@ -1708,26 +1716,33 @@ def test_quantized_convolution_default_channel_last(self) -> None:
             3,
         )
 
+        # self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True))
+
     def test_no_transpose_if_already_quantized_conv_channel_last(self) -> None:
         # Create a graph with a single im2row node.
-        gm = self.create_quantized_convolution_graph_module(channels_last=True)
+        gm, (x, w, b) = self.create_quantized_convolution_graph_module(
+            channels_last=True
+        )
         # Check if graph module is valid by running exportpass on it.
         gm = ExportPass().call(gm).graph_module
         self.assertEqual(
-            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.default), 1
+            count_node(gm, exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor), 1
        )
+        # self.assertTrue(numerically_equivalent(gm, (x, w, b), True))
 
         # Apply replacement pass.
         p = ReplaceConvWithChannelLastConvPass()
         gm_after_replacement = p.call(gm).graph_module
         # Check that no replacement was made.
         self.assertEqual(
             count_node(
-                gm_after_replacement, exir_ops.edge.cadence.quantized_conv_nhwc.default
+                gm_after_replacement,
+                exir_ops.edge.cadence.quantized_conv_nhwc.per_tensor,
             ),
             1,
         )
         self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
+        # self.assertTrue(numerically_equivalent(gm_after_replacement, (x, w, b), True))
 
 
 class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):