Skip to content

Commit 9893577

Browse files
YIWENX14facebook-github-bot
authored andcommitted
Fix errors in comparison functions when dtype is uint8
Summary: Fixed error in comparison functions that were producing wrong results with uint8 dtype. Differential Revision: D67163600
1 parent 61b9e1b commit 9893577

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

devtools/inspector/_inspector_utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -372,13 +372,13 @@ def plot_metric(result: List[float], metric_name: str):
372372

373373
def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput):
374374
def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
375-
return round((torch.pow((a - b).to(torch.float32), 2)).mean().item(), 2)
375+
return round((torch.pow((a - b), 2)).mean().item(), 2)
376376

377377
results = []
378378
for ref_value, value in zip(ref_values, values):
379379
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
380380
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
381-
results.append(mean_squared_error(ref_value, value))
381+
results.append(mean_squared_error(ref_value.to(torch.float32), value.to(torch.float32)))
382382
else:
383383
results.append(None)
384384

@@ -387,8 +387,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387387

388388
def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput):
389389
def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
390-
signal = signal.type(torch.float32)
391-
noise = noise.type(torch.float32)
392390
signal_power = torch.mean(torch.pow(signal, 2))
393391
noise_power = torch.mean(torch.pow(noise, 2))
394392
snr = 10 * torch.log10(signal_power / noise_power)
@@ -398,8 +396,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398396
for ref_value, value in zip(ref_values, values):
399397
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
400398
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
401-
diff = ref_value - value
402-
snr = signal_to_noise(ref_value, diff)
399+
ref_value_fp = ref_value.to(torch.float32)
400+
value_fp = value.to(torch.float32)
401+
diff = ref_value_fp - value_fp
402+
snr = signal_to_noise(ref_value_fp, diff)
403403
results.append(snr)
404404
else:
405405
results.append(None)
@@ -429,7 +429,7 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429429
for ref_value, value in zip(ref_values, values):
430430
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
431431
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
432-
results.append(cosine_similarity(ref_value, value))
432+
results.append(cosine_similarity(ref_value.to(torch.float32), value.to(torch.float32)))
433433
else:
434434
results.append(None)
435435

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
OperatorNode,
2020
ValueNode,
2121
)
22+
from executorch.devtools.inspector._inspector_utils import calculate_mse, calculate_snr, calculate_cosine_similarity
2223

2324
from executorch.devtools.debug_format.et_schema import FXOperatorGraph
2425
from executorch.devtools.etdump import schema_flatcc as flatcc
@@ -189,6 +190,33 @@ def test_calculate_time_scale_factor_cycles(self):
189190
)
190191

191192

193+
def test_compare_results(self):
194+
a = torch.rand(4, 4)
195+
196+
# Create tensor b which has very close value to tensor a
197+
b = a.clone()
198+
b[0,0] += 1e-2
199+
b[1,0] += 1e-2
200+
b[1,3] -= 1e-2
201+
202+
self.assertLess(calculate_mse([a], [b])[0], 0.5)
203+
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
204+
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
205+
206+
def test_compare_results_uint8(self):
207+
a = torch.randint(0, 255, (4, 4), dtype=torch.uint8)
208+
209+
# Create tensor b which has very close value to tensor a
210+
b = a.clone()
211+
b[0,0] += 1
212+
b[1,0] += 1
213+
b[1,3] -= 1
214+
215+
self.assertLess(calculate_mse([a], [b])[0], 0.5)
216+
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
217+
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)
218+
219+
192220
def gen_mock_operator_graph_with_expected_map() -> (
193221
Tuple[OperatorGraph, Dict[int, OperatorNode]]
194222
):

0 commit comments

Comments
 (0)