diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 49441416314..3faf62e008a 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str): def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput): def mean_squared_error(a: torch.Tensor, b: torch.Tensor): - return round((torch.pow((a - b).to(torch.float32), 2)).mean().item(), 2) + return round((torch.pow((a - b), 2)).mean().item(), 2) results = [] for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - results.append(mean_squared_error(ref_value, value)) + results.append( + mean_squared_error(ref_value.to(torch.float32), value.to(torch.float32)) + ) else: results.append(None) @@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor): def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput): def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor): - signal = signal.type(torch.float32) - noise = noise.type(torch.float32) signal_power = torch.mean(torch.pow(signal, 2)) noise_power = torch.mean(torch.pow(noise, 2)) snr = 10 * torch.log10(signal_power / noise_power) @@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor): for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - diff = ref_value - value - snr = signal_to_noise(ref_value, diff) + ref_value_fp = ref_value.to(torch.float32) + value_fp = value.to(torch.float32) + diff = ref_value_fp - value_fp + snr = signal_to_noise(ref_value_fp, diff) results.append(snr) else: results.append(None) @@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor): for ref_value, value in zip(ref_values, values): # TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor): - results.append(cosine_similarity(ref_value, value)) + results.append( + cosine_similarity(ref_value.to(torch.float32), value.to(torch.float32)) + ) else: results.append(None) diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index 73511f5fcd7..5e224415bb6 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -25,6 +25,9 @@ from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector._inspector_utils import ( + calculate_cosine_similarity, + calculate_mse, + calculate_snr, calculate_time_scale_factor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, @@ -188,6 +191,32 @@ def test_calculate_time_scale_factor_cycles(self): calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1 ) + def test_compare_results(self): + a = torch.rand(4, 4) + + # Create tensor b which has very close value to tensor a + b = a.clone() + b[0, 0] += 1e-2 + b[1, 0] += 1e-2 + b[1, 3] -= 1e-2 + + self.assertLess(calculate_mse([a], [b])[0], 0.5) + self.assertGreater(calculate_snr([a], [b])[0], 30.0) + self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0) + + def test_compare_results_uint8(self): + a = torch.randint(0, 255, (4, 4), dtype=torch.uint8) + + # Create tensor b which has very close value to tensor a + b = a.clone() + b[0, 0] += 1 + b[1, 0] += 1 + b[1, 3] -= 1 + + self.assertLess(calculate_mse([a], [b])[0], 0.5) + self.assertGreater(calculate_snr([a], [b])[0], 30.0) + self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]]