Skip to content

Commit 6616c6f

Browse files
Juntian Liufacebook-github-bot
authored andcommitted
Add SNR numerical comparator
Summary: This PR introduces the SNR (Signal-to-Noise Ratio) comparator class, which extends the numerical comparison framework to evaluate model accuracy by comparing the SNR between two inputs. The comparator calculates the signal power and noise power from the two inputs, then computes the SNR using the formula SNR = 10 * log10(original_power / error_power). Differential Revision: D77159515
1 parent 7f2fcb0 commit 6616c6f

File tree

8 files changed

+128
-7
lines changed

8 files changed

+128
-7
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,4 +720,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
720720
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
721721
)
722722
input_tensor = input_tensor.detach().cpu().double()
723+
724+
# Convert NaN to 0.0
725+
if torch.isnan(input_tensor).any():
726+
input_tensor = torch.nan_to_num(input_tensor)
723727
return input_tensor

devtools/inspector/numerical_comparator/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,21 @@ python_library(
2727
],
2828
)
2929

30+
python_library(
31+
name = "snr_numerical_comparator",
32+
srcs = ["snr_numerical_comparator.py"],
33+
deps = [
34+
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
35+
"//executorch/devtools/inspector:inspector_utils",
36+
],
37+
)
38+
3039
python_library(
3140
name = "lib",
3241
srcs = ["__init__.py"],
3342
deps = [
3443
":l1_numerical_comparator",
3544
":mse_numerical_comparator",
45+
":snr_numerical_comparator",
3646
],
3747
)

devtools/inspector/numerical_comparator/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,9 @@
1313
MSEComparator,
1414
)
1515

16+
from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import (
17+
SNRComparator,
18+
)
19+
1620

17-
__all__ = ["L1Comparator", "MSEComparator"]
21+
__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"]

devtools/inspector/numerical_comparator/l1_numerical_comparator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ def compare(self, a: Any, b: Any) -> float:
1919

2020
t_a = convert_to_float_tensor(a)
2121
t_b = convert_to_float_tensor(b)
22-
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23-
t_a = torch.nan_to_num(t_a)
24-
t_b = torch.nan_to_num(t_b)
2522

2623
try:
2724
res = torch.abs(t_a - t_b).sum().item()

devtools/inspector/numerical_comparator/mse_numerical_comparator.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ def compare(self, a: Any, b: Any) -> float:
1919

2020
t_a = convert_to_float_tensor(a)
2121
t_b = convert_to_float_tensor(b)
22-
if torch.isnan(t_a).any() or torch.isnan(t_b).any():
23-
t_a = torch.nan_to_num(t_a)
24-
t_b = torch.nan_to_num(t_b)
2522

2623
try:
2724
res = float(torch.mean(torch.square(t_a - t_b)))
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from typing import Any
9+
10+
import torch
11+
from executorch.devtools.inspector._inspector_utils import convert_to_float_tensor
12+
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
13+
NumericalComparatorBase,
14+
)
15+
16+
17+
class SNRComparator(NumericalComparatorBase):
18+
def compare(self, a: Any, b: Any) -> float:
19+
"""
20+
Compare the Signal-to-Noise Ratio (SNR) between two inputs
21+
Formula: SNR = 10 * log10(original_power / error_power)
22+
"""
23+
24+
t_a = convert_to_float_tensor(a)
25+
t_b = convert_to_float_tensor(b)
26+
27+
# Calculate the signal power and noise power
28+
original_power = torch.mean(torch.pow(t_a, 2))
29+
try:
30+
error = t_a - t_b
31+
error_power = torch.mean(torch.pow(error, 2))
32+
except Exception as e:
33+
raise ValueError(
34+
f"Error computing SNR difference between tensors: {str(e)}"
35+
)
36+
37+
# Calculate SNR
38+
snr = 10 * torch.log10(original_power / error_power)
39+
return snr.item()

devtools/inspector/tests/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ python_unittest(
7070
],
7171
)
7272

73+
python_unittest(
74+
name = "snr_comparator_test",
75+
srcs = ["snr_comparator_test.py"],
76+
deps = [
77+
"//executorch/devtools/inspector/numerical_comparator:lib",
78+
],
79+
)
80+
7381
python_library(
7482
name = "inspector_test_utils",
7583
srcs = [
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import math
8+
import unittest
9+
10+
import torch
11+
12+
from executorch.devtools.inspector.numerical_comparator import SNRComparator
13+
14+
15+
class TestSNRComparator(unittest.TestCase):
16+
snr_comparator = SNRComparator()
17+
18+
def test_identical_tensors(self):
19+
# identical tensors --> error_power == 0 --> SNR is inf
20+
a = torch.tensor([[10, 4], [3, 4]])
21+
b = torch.tensor([[10, 4], [3, 4]])
22+
result = self.snr_comparator.compare(a, b)
23+
self.assertTrue(math.isinf(result) and result > 0)
24+
25+
def test_scalar(self):
26+
# original_power == 1, error_power == 1 --> SNR = 10 * log10(1/1) = 0
27+
a = 1
28+
b = 2
29+
result = self.snr_comparator.compare(a, b)
30+
self.assertAlmostEqual(result, 0.0)
31+
32+
def test_with_nans_replaced_with_zero(self):
33+
a = torch.tensor([float("nan"), 1.0])
34+
b = torch.tensor([0.0, 1.0])
35+
result = self.snr_comparator.compare(a, b)
36+
self.assertTrue(math.isinf(result) and result > 0)
37+
38+
def test_shape_mismatch_raises_exception(self):
39+
a = torch.tensor([1, 2, -1])
40+
b = torch.tensor([1, 1, -3, 4])
41+
with self.assertRaises(ValueError):
42+
self.snr_comparator.compare(a, b)
43+
44+
def test_2D_tensors(self):
45+
# original_power = mean([16, 81, 36, 16]) = 37.25
46+
# error = a - b = [3, 7, 3, -1] squared = [9, 49, 9, 1] mean = 68/4 = 17.0
47+
# SNR = 10 * log10(37.25/17.0)
48+
a = torch.tensor([[4, 9], [6, 4]])
49+
b = torch.tensor([[1, 2], [3, 5]])
50+
expected = 10 * math.log10(37.25 / 17.0)
51+
result = self.snr_comparator.compare(a, b)
52+
self.assertAlmostEqual(result, expected)
53+
54+
def test_list_of_tensors(self):
55+
# original_power = mean(4, 16, 25, 4]) = 12.25
56+
# error = a - b = [1, 2, 2, -3] squared = [1, 4, 4, 9] mean = 18/4 = 4.5
57+
# SNR = 10 * log10(37.25/17.0)
58+
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
59+
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
60+
expected = 10 * math.log10(12.25 / 4.5)
61+
result = self.snr_comparator.compare(a, b)
62+
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)