Skip to content

Commit 3713a3f

Browse files
author
ssjia
committed
Update
[ghstack-poisoned]
2 parents 2048d61 + 76a4062 commit 3713a3f

File tree

18 files changed

+587
-32
lines changed

18 files changed

+587
-32
lines changed

.github/workflows/pull.yml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,8 @@ jobs:
801801
id-token: write
802802
contents: read
803803
strategy:
804+
matrix:
805+
enable-etdump: ['', '--enable-etdump']
804806
fail-fast: false
805807
with:
806808
runner: linux.2xlarge
@@ -820,7 +822,7 @@ jobs:
820822
source .ci/scripts/setup-emscripten.sh
821823
822824
# Test selective build
823-
bash scripts/build_wasm_tests.sh
825+
bash scripts/build_wasm_tests.sh ${{ matrix.enable-etdump }}
824826
825827
# Install Jest
826828
cd cmake-out-wasm/extension/wasm/test
@@ -892,12 +894,10 @@ jobs:
892894
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build
893895
894896
# Test models serially
895-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh mv2
896-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh mv3
897-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh edsr
898-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh resnet18
899-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh resnet50
900-
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh dl3
897+
models="mv2 mv3 edsr resnet18 resnet50 dl3"
898+
for model in $models; do
899+
python -m examples.vulkan.export --model_name=$model --test
900+
done
901901
902902
903903

backends/qualcomm/tests/tester.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def __init__(
5252
default_partitioner_cls=QnnPartitioner,
5353
)
5454

55-
def run(self, artifact: ExportedProgram, inputs=None) -> None:
55+
def run(
56+
self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False
57+
) -> None:
5658
ep = QnnPassManager().transform_for_export_pipeline(artifact)
5759
transform_passes = QnnPassManager().get_to_edge_transform_passes(ep)
5860

@@ -61,6 +63,7 @@ def run(self, artifact: ExportedProgram, inputs=None) -> None:
6163
transform_passes=transform_passes,
6264
partitioner=self.partitioners,
6365
compile_config=self.edge_compile_conf,
66+
generate_etrecord=generate_etrecord,
6467
)
6568

6669

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from dataclasses import dataclass
2+
3+
import torch
4+
from torch.ao.ns.fx.utils import compute_sqnr
5+
6+
7+
@dataclass
8+
class TensorStatistics:
9+
"""Contains summary statistics for a tensor."""
10+
11+
shape: torch.Size
12+
""" The shape of the tensor. """
13+
14+
numel: int
15+
""" The number of elements in the tensor. """
16+
17+
median: float
18+
""" The median of the tensor. """
19+
20+
mean: float
21+
""" The mean of the tensor. """
22+
23+
max: torch.types.Number
24+
""" The maximum element of the tensor. """
25+
26+
min: torch.types.Number
27+
""" The minimum element of the tensor. """
28+
29+
@classmethod
30+
def from_tensor(cls, tensor: torch.Tensor) -> "TensorStatistics":
31+
"""Creates a TensorStatistics object from a tensor."""
32+
flattened = torch.flatten(tensor)
33+
return cls(
34+
shape=tensor.shape,
35+
numel=tensor.numel(),
36+
median=torch.quantile(flattened, q=0.5).item(),
37+
mean=flattened.mean().item(),
38+
max=flattened.max().item(),
39+
min=flattened.min().item(),
40+
)
41+
42+
43+
@dataclass
44+
class ErrorStatistics:
45+
"""Contains statistics derived from the difference of two tensors."""
46+
47+
reference_stats: TensorStatistics
48+
""" Statistics for the reference tensor. """
49+
50+
actual_stats: TensorStatistics
51+
""" Statistics for the actual tensor. """
52+
53+
error_l2_norm: float | None
54+
""" The L2 norm of the error between the actual and reference tensor. """
55+
56+
error_mae: float | None
57+
""" The mean absolute error between the actual and reference tensor. """
58+
59+
error_max: float | None
60+
""" The maximum absolute elementwise error between the actual and reference tensor. """
61+
62+
error_msd: float | None
63+
""" The mean signed deviation between the actual and reference tensor. """
64+
65+
sqnr: float | None
66+
""" The signal-to-quantization-noise ratio between the actual and reference tensor. """
67+
68+
@classmethod
69+
def from_tensors(
70+
cls, actual: torch.Tensor, reference: torch.Tensor
71+
) -> "ErrorStatistics":
72+
"""Creates an ErrorStatistics object from two tensors."""
73+
actual = actual.to(torch.float64)
74+
reference = reference.to(torch.float64)
75+
76+
if actual.shape != reference.shape:
77+
return cls(
78+
reference_stats=TensorStatistics.from_tensor(reference),
79+
actual_stats=TensorStatistics.from_tensor(actual),
80+
error_l2_norm=None,
81+
error_mae=None,
82+
error_max=None,
83+
error_msd=None,
84+
sqnr=None,
85+
)
86+
87+
error = actual - reference
88+
flat_error = torch.flatten(error)
89+
90+
return cls(
91+
reference_stats=TensorStatistics.from_tensor(reference),
92+
actual_stats=TensorStatistics.from_tensor(actual),
93+
error_l2_norm=torch.linalg.norm(flat_error).item(),
94+
error_mae=torch.mean(torch.abs(flat_error)).item(),
95+
error_max=torch.max(torch.abs(flat_error)).item(),
96+
error_msd=torch.mean(flat_error).item(),
97+
# Torch sqnr implementation requires float32 due to decorator logic
98+
sqnr=compute_sqnr(actual.to(torch.float), reference.to(torch.float)).item(),
99+
)

backends/test/harness/stages/to_edge_transform_and_lower.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
to_edge_transform_and_lower,
88
)
99
from executorch.exir.backend.partitioner import Partitioner
10+
1011
from torch.export import ExportedProgram
1112

1213

@@ -24,11 +25,14 @@ def __init__(
2425
def stage_type(self) -> StageType:
2526
return StageType.TO_EDGE_TRANSFORM_AND_LOWER
2627

27-
def run(self, artifact: ExportedProgram, inputs=None) -> None:
28+
def run(
29+
self, artifact: ExportedProgram, inputs=None, generate_etrecord: bool = False
30+
) -> None:
2831
self.edge_dialect_program = to_edge_transform_and_lower(
2932
artifact,
3033
compile_config=self.edge_compile_conf,
3134
partitioner=self.partitioners,
35+
generate_etrecord=generate_etrecord,
3236
)
3337

3438
@property

backends/test/harness/tester.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import torch
66

7+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
78
from executorch.backends.test.harness.stages import (
89
Export,
910
Partition,
@@ -182,10 +183,10 @@ def _post(self, stage):
182183
assert stage_type in self.stages
183184
self.stages[stage_type] = stage
184185

185-
def _run_stage(self, stage_instance, inputs=None):
186+
def _run_stage(self, stage_instance, inputs=None, *args, **kwargs):
186187
assert isinstance(stage_instance, Stage)
187188
prev_stage_artifact = self._pre(stage_instance)
188-
stage_instance.run(prev_stage_artifact, inputs=inputs)
189+
stage_instance.run(prev_stage_artifact, inputs=inputs, *args, **kwargs) # noqa
189190
self._post(stage_instance)
190191
return self
191192

@@ -212,11 +213,14 @@ def to_edge(self, to_edge_stage: Optional[ToEdge] = None):
212213
return res
213214

214215
def to_edge_transform_and_lower(
215-
self, to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None
216+
self,
217+
to_edge_and_transform_stage: Optional[ToEdgeTransformAndLower] = None,
218+
generate_etrecord: bool = False,
216219
):
217220
return self._run_stage(
218221
to_edge_and_transform_stage
219-
or self._get_default_stage(StageType.TO_EDGE_TRANSFORM_AND_LOWER)
222+
or self._get_default_stage(StageType.TO_EDGE_TRANSFORM_AND_LOWER),
223+
generate_etrecord=generate_etrecord,
220224
)
221225

222226
def run_passes(self, run_passes_stage: Optional[RunPasses] = None):
@@ -302,20 +306,15 @@ def run_method_and_compare_outputs(
302306
atol=1e-03,
303307
rtol=1e-03,
304308
qtol=0,
309+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
305310
):
306311
number_of_runs = 1 if inputs is not None else num_runs
307312
reference_stage = self.stages[StageType.EXPORT]
308313

309314
stage = stage or self.cur
310315

311-
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312-
for run_iteration in range(number_of_runs):
316+
for _ in range(number_of_runs):
313317
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314-
input_shapes = [
315-
generated_input.shape if hasattr(generated_input, "shape") else None
316-
for generated_input in inputs_to_run
317-
]
318-
print(f"Run {run_iteration} with input shapes: {input_shapes}")
319318

320319
# Reference output (and quantization scale)
321320
(
@@ -328,13 +327,25 @@ def run_method_and_compare_outputs(
328327
# Output from running artifact at stage
329328
stage_output = self.stages[stage].run_artifact(inputs_to_run)
330329
self._compare_outputs(
331-
reference_output, stage_output, quantization_scale, atol, rtol, qtol
330+
reference_output,
331+
stage_output,
332+
quantization_scale,
333+
atol,
334+
rtol,
335+
qtol,
336+
statistics_callback,
332337
)
333338

334339
return self
335340

336341
@staticmethod
337-
def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
342+
def _assert_outputs_equal(
343+
model_output,
344+
ref_output,
345+
atol=1e-03,
346+
rtol=1e-03,
347+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
348+
):
338349
"""
339350
Helper testing function that asserts that the model output and the reference output
340351
are equal with some tolerance. Due to numerical differences between eager mode and
@@ -349,6 +360,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
349360
for i in range(len(model_output)):
350361
model = model_output[i]
351362
ref = ref_output[i]
363+
364+
error_stats = ErrorStatistics.from_tensors(model, ref)
365+
if statistics_callback is not None:
366+
statistics_callback(error_stats)
367+
352368
assert (
353369
ref.shape == model.shape
354370
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
@@ -386,6 +402,7 @@ def _compare_outputs(
386402
atol=1e-03,
387403
rtol=1e-03,
388404
qtol=0,
405+
statistics_callback: Callable[[ErrorStatistics], None] | None = None,
389406
):
390407
"""
391408
Compares the original of the original nn module with the output of the generated artifact.
@@ -408,6 +425,7 @@ def _compare_outputs(
408425
reference_output,
409426
atol=atol,
410427
rtol=rtol,
428+
statistics_callback=statistics_callback,
411429
)
412430

413431
@staticmethod
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.test.harness.error_statistics import ErrorStatistics
5+
6+
7+
class ErrorStatisticsTests(unittest.TestCase):
8+
def test_error_stats_simple(self):
9+
tensor1 = torch.tensor([1, 2, 3, 4])
10+
tensor2 = torch.tensor([2, 2, 2, 5])
11+
12+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
13+
14+
# Check actual tensor statistics
15+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
16+
self.assertEqual(error_stats.actual_stats.numel, 4)
17+
self.assertEqual(error_stats.actual_stats.median, 2.5)
18+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
19+
self.assertEqual(error_stats.actual_stats.max, 4)
20+
self.assertEqual(error_stats.actual_stats.min, 1)
21+
22+
# Check reference tensor statistics
23+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([4]))
24+
self.assertEqual(error_stats.reference_stats.numel, 4)
25+
self.assertEqual(error_stats.reference_stats.median, 2.0)
26+
self.assertEqual(error_stats.reference_stats.mean, 2.75)
27+
self.assertEqual(error_stats.reference_stats.max, 5)
28+
self.assertEqual(error_stats.reference_stats.min, 2)
29+
30+
# Check error statistics
31+
self.assertAlmostEqual(error_stats.error_l2_norm, 1.732, places=3)
32+
self.assertEqual(error_stats.error_mae, 0.75)
33+
self.assertEqual(error_stats.error_max, 1.0)
34+
self.assertEqual(error_stats.error_msd, -0.25)
35+
self.assertAlmostEqual(error_stats.sqnr, 10.0, places=3)
36+
37+
def test_error_stats_different_shapes(self):
38+
# Create tensors with different shapes
39+
tensor1 = torch.tensor([1, 2, 3, 4])
40+
tensor2 = torch.tensor([[2, 3], [4, 5]])
41+
42+
error_stats = ErrorStatistics.from_tensors(tensor1, tensor2)
43+
44+
# Check actual tensor statistics
45+
self.assertEqual(error_stats.actual_stats.shape, torch.Size([4]))
46+
self.assertEqual(error_stats.actual_stats.numel, 4)
47+
self.assertEqual(error_stats.actual_stats.median, 2.5)
48+
self.assertEqual(error_stats.actual_stats.mean, 2.5)
49+
self.assertEqual(error_stats.actual_stats.max, 4)
50+
self.assertEqual(error_stats.actual_stats.min, 1)
51+
52+
# Check reference tensor statistics
53+
self.assertEqual(error_stats.reference_stats.shape, torch.Size([2, 2]))
54+
self.assertEqual(error_stats.reference_stats.numel, 4)
55+
self.assertEqual(error_stats.reference_stats.median, 3.5)
56+
self.assertEqual(error_stats.reference_stats.mean, 3.5)
57+
self.assertEqual(error_stats.reference_stats.max, 5)
58+
self.assertEqual(error_stats.reference_stats.min, 2)
59+
60+
# Check that all error values are None when shapes differ
61+
self.assertIsNone(error_stats.error_l2_norm)
62+
self.assertIsNone(error_stats.error_mae)
63+
self.assertIsNone(error_stats.error_max)
64+
self.assertIsNone(error_stats.error_msd)
65+
self.assertIsNone(error_stats.sqnr)

0 commit comments

Comments
 (0)