Skip to content

Commit 8611b23

Browse files
author
Juntian Liu
authored
Updated the comparison logic to handle sequences separately
Differential Revision: D77893628 Pull Request resolved: #12251
1 parent f08445d commit 8611b23

File tree

7 files changed

+67
-52
lines changed

7 files changed

+67
-52
lines changed

devtools/inspector/_inspector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from executorch.devtools.etrecord import ETRecord, parse_etrecord
4343
from executorch.devtools.inspector._inspector_utils import (
4444
calculate_time_scale_factor,
45+
compare_intermediate_outputs,
4546
create_debug_handle_to_op_node_mapping,
4647
DebugHandle,
4748
display_or_print_df,
@@ -1422,8 +1423,8 @@ def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
14221423
runtime_debug_handle, runtime_debug_handle_to_op_name
14231424
),
14241425
"runtime_intermediate_output": runtime_intermediate_output,
1425-
"gap": comparator.compare(
1426-
aot_intermediate_output, runtime_intermediate_output
1426+
"gap": compare_intermediate_outputs(
1427+
aot_intermediate_output, runtime_intermediate_output, comparator
14271428
),
14281429
}
14291430
)

devtools/inspector/_inspector_utils.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -770,32 +770,29 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
770770
This function handles the following types of input:
771771
- Scalar (int or float): Converts to a tensor with a single element.
772772
- Tensor: Converts to a float64 tensor on CPU.
773-
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
774773
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
775774
Parameters:
776-
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
777-
a tensor, or a list of tensors.
775+
input_data (Any): The input data to be converted to a tensor. It can be a scalar
776+
or a tensor.
778777
Returns:
779778
torch.Tensor: A tensor on CPU with dtype torch.float64.
780-
Raises:
781-
ValueError: If the input_data cannot be converted to a tensor.
779+
Raises error if the input is not a scalar or a tensor
782780
"""
781+
# Assert that the input is not a Sequence
782+
assert not isinstance(input_data, Sequence)
783783
try:
784-
# Check if the input is a Sequence of tensors
785-
if isinstance(input_data, Sequence):
786-
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
787784
# Try to convert the input to a tensor
788-
else:
789-
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
785+
input_tensor = torch.as_tensor(input_data, dtype=torch.float64)
790786
except Exception as e:
791787
raise ValueError(
792788
f"Cannot convert value of type {type(input_data)} to a tensor: {e}"
793789
)
794-
input_tensor = input_tensor.detach().cpu().double()
795790

791+
input_tensor = input_tensor.detach().cpu().double()
796792
# Convert NaN to 0.0
797793
if torch.isnan(input_tensor).any():
798794
input_tensor = torch.nan_to_num(input_tensor)
795+
799796
return input_tensor
800797

801798

@@ -845,3 +842,33 @@ def find_op_names(
845842
result.append(op_name)
846843

847844
return result
845+
846+
847+
def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
848+
"""
849+
Compare two outputs, handling both sequence and non-sequence cases,
850+
and return a list of comparison results.
851+
Parameters:
852+
a: The first intermediate output to compare.
853+
b: The second intermediate output to compare.
854+
comparator: A comparator object with a `compare` method.
855+
Returns:
856+
List[float]: A list of comparison results.
857+
Raises:
858+
ValueError: If one input is a sequence and the other is not, or if sequences have different lengths.
859+
"""
860+
is_a_sequence = isinstance(a, Sequence)
861+
is_b_sequence = isinstance(b, Sequence)
862+
if is_a_sequence and is_b_sequence:
863+
# Ensure both sequences have the same length
864+
if len(a) != len(b):
865+
raise ValueError("Sequences must have the same length for comparison.")
866+
867+
# Compare each element in the sequences and return the list of results
868+
return [comparator.compare(x, y) for x, y in zip(a, b)]
869+
elif not is_a_sequence and not is_b_sequence:
870+
# Compare non-sequence items and return the result in a list
871+
return [comparator.compare(a, b)]
872+
else:
873+
# Raise an error if one is a sequence and the other is not
874+
raise ValueError("Both inputs must be sequences or both must be non-sequences.")

devtools/inspector/tests/inspector_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -647,22 +647,22 @@ def test_calculate_numeric_gap(self):
647647
for i, row in df.iterrows():
648648
# Dummpy key to get the expected aot/runtime internmediate outputs
649649
key = (i,)
650-
# aot_intermediate_output should equal aot_intermediate_outputs[h]
650+
# aot_intermediate_output should equal aot_intermediate_outputs[key]
651651
self.assertTrue(
652652
torch.allclose(
653653
row["aot_intermediate_output"],
654654
aot_intermediate_outputs[key],
655655
)
656656
)
657-
# runtime_intermediate_output should equal runtime_intermediate_outputs[h]
657+
# runtime_intermediate_output should equal runtime_intermediate_outputs[key]
658658
self.assertTrue(
659659
torch.allclose(
660660
row["runtime_intermediate_output"],
661661
runtime_intermediate_outputs[key],
662662
)
663663
)
664664
# gap should equal 3.0
665-
self.assertEqual(row["gap"], 3.0)
665+
self.assertEqual(row["gap"][0], 3.0)
666666

667667
def _gen_random_float_list(self) -> List[float]:
668668
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]

devtools/inspector/tests/inspector_utils_test.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
calculate_mse,
3030
calculate_snr,
3131
calculate_time_scale_factor,
32+
compare_intermediate_outputs,
3233
convert_to_float_tensor,
3334
create_debug_handle_to_op_node_mapping,
3435
EDGE_DIALECT_GRAPH_KEY,
@@ -42,6 +43,7 @@
4243
NodeFilter,
4344
TimeScale,
4445
)
46+
from executorch.devtools.inspector.numerical_comparator import L1Comparator
4547

4648

4749
class TestInspectorUtils(unittest.TestCase):
@@ -420,19 +422,10 @@ def test_convert_input_to_tensor_convertible_inputs(self):
420422
)
421423
self.assertEqual(actual_output2.device.type, "cpu")
422424

423-
# List of tensors -> stacked tensor float32 CPU
425+
# List of tensors -> AssertionError
424426
t_list = [torch.tensor([1, 2]), torch.tensor([2, 3]), torch.tensor([3, 4])]
425-
actual_output3 = convert_to_float_tensor(t_list)
426-
self.assertIsInstance(actual_output3, torch.Tensor)
427-
self.assertEqual(actual_output3.dtype, torch.float64)
428-
self.assertEqual(tuple(actual_output3.shape), (3, 2))
429-
self.assertTrue(
430-
torch.allclose(
431-
actual_output3,
432-
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float64),
433-
)
434-
)
435-
self.assertEqual(actual_output3.device.type, "cpu")
427+
with self.assertRaises(AssertionError):
428+
convert_to_float_tensor(t_list)
436429

437430
def test_convert_input_to_tensor_non_convertible_raises(self):
438431
class X:
@@ -566,6 +559,24 @@ def test_find_op_names_matching_handles(self):
566559
find_op_names(debug_handle, debug_handle_to_op_name), ["op1", "op2"]
567560
)
568561

562+
def test_compare_intermediate_outputs_sequences(self):
563+
a = [1.0, 2.0, 3.0]
564+
b = [1.0, 2.5, 3.5]
565+
result = compare_intermediate_outputs(a, b, L1Comparator())
566+
self.assertEqual(result, [0.0, 0.5, 0.5])
567+
568+
def test_compare_intermediate_outputs_diff_len_sequences(self):
569+
a = [1.0, 2.0]
570+
b = [1.0, 2.0, 3.0]
571+
with self.assertRaises(ValueError):
572+
compare_intermediate_outputs(a, b, L1Comparator())
573+
574+
def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
575+
a = [1.0, 2.0]
576+
b = 1.0
577+
with self.assertRaises(ValueError):
578+
compare_intermediate_outputs(a, b, L1Comparator())
579+
569580

570581
def gen_mock_operator_graph_with_expected_map() -> (
571582
Tuple[OperatorGraph, Dict[int, OperatorNode]]

devtools/inspector/tests/l1_comparator_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,3 @@ def test_2D_tensors(self):
4747
expected = 14.0
4848
result = self.l1_comparator.compare(a, b)
4949
self.assertAlmostEqual(result, expected)
50-
51-
def test_list_of_tensors(self):
52-
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
53-
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
54-
expected = 8.0
55-
result = self.l1_comparator.compare(a, b)
56-
self.assertAlmostEqual(result, expected)

devtools/inspector/tests/mse_comparator_test.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,3 @@ def test_2D_tensors(self):
4747
expected = (9.0 + 49.0 + 9.0 + 36.0) / 4.0
4848
result = self.mse_comparator.compare(a, b)
4949
self.assertAlmostEqual(result, expected)
50-
51-
def test_list_of_tensors(self):
52-
a = [torch.tensor([2, 4]), torch.tensor([15, 2])]
53-
b = [torch.tensor([1, 2]), torch.tensor([9, 5])]
54-
expected = (1.0 + 4.0 + 36.0 + 9.0) / 4.0
55-
result = self.mse_comparator.compare(a, b)
56-
self.assertAlmostEqual(result, expected)

devtools/inspector/tests/snr_comparator_test.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,3 @@ def test_2D_tensors(self):
5050
expected = 10 * math.log10(37.25 / 17.0)
5151
result = self.snr_comparator.compare(a, b)
5252
self.assertAlmostEqual(result, expected)
53-
54-
def test_list_of_tensors(self):
55-
# original_power = mean(4, 16, 25, 4]) = 12.25
56-
# error = a - b = [1, 2, 2, -3] squared = [1, 4, 4, 9] mean = 18/4 = 4.5
57-
# SNR = 10 * log10(37.25/17.0)
58-
a = [torch.tensor([2, 4]), torch.tensor([5, 2])]
59-
b = [torch.tensor([1, 2]), torch.tensor([3, 5])]
60-
expected = 10 * math.log10(12.25 / 4.5)
61-
result = self.snr_comparator.compare(a, b)
62-
self.assertAlmostEqual(result, expected)

0 commit comments

Comments
 (0)