Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str):

def calculate_mse(ref_values: ProgramOutput, values: ProgramOutput):
def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
return round((torch.pow((a - b).to(torch.float32), 2)).mean().item(), 2)
return round((torch.pow((a - b), 2)).mean().item(), 2)

results = []
for ref_value, value in zip(ref_values, values):
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
results.append(mean_squared_error(ref_value, value))
results.append(
mean_squared_error(ref_value.to(torch.float32), value.to(torch.float32))
)
else:
results.append(None)

Expand All @@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):

def calculate_snr(ref_values: ProgramOutput, values: ProgramOutput):
def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
signal = signal.type(torch.float32)
noise = noise.type(torch.float32)
signal_power = torch.mean(torch.pow(signal, 2))
noise_power = torch.mean(torch.pow(noise, 2))
snr = 10 * torch.log10(signal_power / noise_power)
Expand All @@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
for ref_value, value in zip(ref_values, values):
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
diff = ref_value - value
snr = signal_to_noise(ref_value, diff)
ref_value_fp = ref_value.to(torch.float32)
value_fp = value.to(torch.float32)
diff = ref_value_fp - value_fp
snr = signal_to_noise(ref_value_fp, diff)
results.append(snr)
else:
results.append(None)
Expand Down Expand Up @@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
for ref_value, value in zip(ref_values, values):
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
if isinstance(ref_value, torch.Tensor) and isinstance(value, torch.Tensor):
results.append(cosine_similarity(ref_value, value))
results.append(
cosine_similarity(ref_value.to(torch.float32), value.to(torch.float32))
)
else:
results.append(None)

Expand Down
29 changes: 29 additions & 0 deletions devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord
from executorch.devtools.inspector._inspector_utils import (
calculate_cosine_similarity,
calculate_mse,
calculate_snr,
calculate_time_scale_factor,
create_debug_handle_to_op_node_mapping,
EDGE_DIALECT_GRAPH_KEY,
Expand Down Expand Up @@ -188,6 +191,32 @@ def test_calculate_time_scale_factor_cycles(self):
calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1
)

def test_compare_results(self):
a = torch.rand(4, 4)

# Create tensor b which has very close value to tensor a
b = a.clone()
b[0, 0] += 1e-2
b[1, 0] += 1e-2
b[1, 3] -= 1e-2

self.assertLess(calculate_mse([a], [b])[0], 0.5)
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)

def test_compare_results_uint8(self):
a = torch.randint(0, 255, (4, 4), dtype=torch.uint8)

# Create tensor b which has very close value to tensor a
b = a.clone()
b[0, 0] += 1
b[1, 0] += 1
b[1, 3] -= 1

self.assertLess(calculate_mse([a], [b])[0], 0.5)
self.assertGreater(calculate_snr([a], [b])[0], 30.0)
self.assertAlmostEqual(calculate_cosine_similarity([a], [b])[0], 1.0)


def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]
Expand Down
Loading