Skip to content

Commit 0735510

Browse files
Juntian Liu authored and facebook-github-bot committed
Add inspector numeric gap calculation between AOT and runtime intermediate outputs
Summary: This PR introduces a method to calculate the numeric gap between logged intermediate outputs from an exported graph and runtime outputs. The method currently supports MSE and L1 distance metrics for comparison. It maps corresponding intermediate outputs from both stages and computes the numerical gaps, returning the results in a pandas DataFrame. This enhancement aids in identifying discrepancies between AOT intermediate outputs and actual intermediate outputs during runtime.

Reviewed By: Gasoonjia

Differential Revision: D76831086
1 parent 18e4240 commit 0735510

File tree

5 files changed

+142
-9
lines changed

5 files changed

+142
-9
lines changed

devtools/inspector/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ python_library(
1919
"//executorch/devtools/etrecord:etrecord",
2020
"//executorch/exir:lib",
2121
"//executorch/devtools/inspector:intermediate_output_capturer",
22+
"//executorch/devtools/inspector/numerical_comparator:lib",
2223
],
2324
)
2425

devtools/inspector/_inspector.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
inflate_runtime_output,
5656
is_debug_output,
5757
is_inference_output_equal,
58+
map_runtime_aot_intermediate_outputs,
5859
ProgramOutput,
5960
RESERVED_FRAMEWORK_EVENT_NAMES,
6061
TimeScale,
@@ -63,6 +64,10 @@
6364
from executorch.devtools.inspector._intermediate_output_capturer import (
6465
IntermediateOutputCapturer,
6566
)
67+
from executorch.devtools.inspector.numerical_comparator import (
68+
L1Comparator,
69+
MSEComparator,
70+
)
6671
from executorch.exir import ExportedProgram
6772

6873

@@ -1337,3 +1342,50 @@ def get_exported_program(
13371342
if graph is None
13381343
else self._etrecord.graph_map.get(graph)
13391344
)
1345+
1346+
def calculate_numeric_gap(self, distance: str = "MSE") -> pd.DataFrame:
    """
    Compares logged intermediate outputs from the exported graph (in ETRecord)
    with runtime outputs (in ETDump) using a user-specified numerical comparator.

    Args:
        distance: The metric the inspector will use for the gap calculation.
            Matching is case-insensitive and ignores surrounding whitespace.
            Supported values are "MSE" and "L1".

    Returns:
        pd.DataFrame: One row per mapped pair of AOT/runtime outputs, with the
        columns "aot_debug_handle", "aot_intermediate_output",
        "runtime_debug_handle", "runtime_intermediate_output" and "gap".
        Pairs in which either output is None are skipped. The columns are
        present even when no row is produced.

    Raises:
        ValueError: If the AOT intermediate outputs are not populated, or if
            `distance` is not one of the supported metrics.
    """
    if self._aot_intermediate_outputs is None:
        raise ValueError(
            "The aot intermediate outputs is required but not populated."
        )

    # Validate the requested metric before building the mapping so that a bad
    # argument fails immediately instead of after the runtime outputs have
    # been collected and matched.
    metric = distance.strip().upper()
    if metric == "MSE":
        comparator = MSEComparator()
    elif metric == "L1":
        comparator = L1Comparator()
    else:
        raise ValueError(f"Unsupported distance metric {distance!r}")

    mapping = map_runtime_aot_intermediate_outputs(
        self._aot_intermediate_outputs, self._get_runtime_intermediate_outputs()
    )

    # Fixing the column list keeps the schema stable when `rows` is empty.
    columns = [
        "aot_debug_handle",
        "aot_intermediate_output",
        "runtime_debug_handle",
        "runtime_intermediate_output",
        "gap",
    ]
    rows = []
    for (aot_debug_handle, aot_intermediate_output), (
        runtime_debug_handle,
        runtime_intermediate_output,
    ) in mapping.items():
        # A gap can only be computed when both sides produced an output.
        if aot_intermediate_output is None or runtime_intermediate_output is None:
            continue
        rows.append(
            {
                "aot_debug_handle": aot_debug_handle,
                "aot_intermediate_output": aot_intermediate_output,
                "runtime_debug_handle": runtime_debug_handle,
                "runtime_intermediate_output": runtime_intermediate_output,
                "gap": comparator.compare(
                    aot_intermediate_output, runtime_intermediate_output
                ),
            }
        )
    return pd.DataFrame(rows, columns=columns)

devtools/inspector/_inspector_utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import math
1010
import sys
11+
from collections.abc import Sequence
1112
from dataclasses import dataclass
1213
from enum import Enum
1314
from typing import Any, Dict, IO, List, Mapping, Optional, Tuple, TypeAlias, Union
@@ -676,17 +677,25 @@ def map_runtime_aot_intermediate_outputs(
676677
# Map only if both AOT and runtime data are present.
677678
if len(aot_list) != 0 and len(runtime_list) != 0:
678679
# Combine aot debug handles into a single key
679-
aot_combined_debug_handle, aot_output = (
680+
aot_combined_debug_handle, aot_intermediate_output = (
680681
_combine_overlapped_intermediate_outputs(aot_list)
681682
)
682683
# Combine runtime debug handles into a single key
683-
runtime_combined_debug_handle, runtime_output = (
684+
runtime_combined_debug_handle, runtime_intermediate_output = (
684685
_combine_overlapped_intermediate_outputs(runtime_list)
685686
)
687+
# List can't be used as a key, so convert to tuple
688+
if isinstance(aot_intermediate_output, list):
689+
aot_intermediate_output = tuple(aot_intermediate_output)
690+
# runtime follows the same format as aot, so it's safe to convert to tuple
691+
if isinstance(runtime_intermediate_output, list):
692+
runtime_intermediate_output = tuple(runtime_intermediate_output)
686693
# Create a mapping between runtime and aot
687-
aot_runtime_mapping[(aot_combined_debug_handle, aot_output)] = (
694+
aot_runtime_mapping[
695+
(aot_combined_debug_handle, aot_intermediate_output)
696+
] = (
688697
runtime_combined_debug_handle,
689-
runtime_output,
698+
runtime_intermediate_output,
690699
)
691700

692701
return aot_runtime_mapping
@@ -698,7 +707,7 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
698707
This function handles the following types of input:
699708
- Scalar (int or float): Converts to a tensor with a single element.
700709
- Tensor: Converts to a float64 tensor on CPU.
701-
- List of Tensors: Stacks the tensors into a single float64 tensor on CPU.
710+
- Sequence of Tensors: Stacks the tensors into a single float64 tensor on CPU.
702711
The resulting tensor is detached, moved to CPU, and cast to torch.float64.
703712
Parameters:
704713
input_data (Any): The input data to be converted to a tensor. It can be a scalar,
@@ -709,8 +718,8 @@ def convert_to_float_tensor(input_data: Any) -> torch.Tensor:
709718
ValueError: If the input_data cannot be converted to a tensor.
710719
"""
711720
try:
712-
# Check if the input is a list of tensors
713-
if isinstance(input_data, list):
721+
# Check if the input is a Sequence of tensors
722+
if isinstance(input_data, Sequence):
714723
input_tensor = torch.stack([convert_to_float_tensor(a) for a in input_data])
715724
# Try to convert the input to a tensor
716725
else:

devtools/inspector/numerical_comparator/TARGETS

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ python_library(
1414
srcs = ["l1_numerical_comparator.py"],
1515
deps = [
1616
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
17-
"//executorch/devtools/inspector:lib",
17+
"//executorch/devtools/inspector:inspector_utils",
1818
],
1919
)
2020

@@ -23,7 +23,7 @@ python_library(
2323
srcs = ["mse_numerical_comparator.py"],
2424
deps = [
2525
"//executorch/devtools/inspector/numerical_comparator:numerical_comparator_base",
26-
"//executorch/devtools/inspector:lib",
26+
"//executorch/devtools/inspector:inspector_utils",
2727
],
2828
)
2929

devtools/inspector/tests/inspector_test.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
from unittest.mock import patch
1919

20+
import pandas as pd
21+
2022
import torch
2123
import torch.fx
2224

@@ -578,6 +580,75 @@ def test_get_runtime_intermediate_outputs(self):
578580
self.assertIn((key,), runtime_outputs)
579581
self.assertEqual(len(runtime_outputs[(key,)]), RAW_DATA_SIZE)
580582

583+
def test_calculate_numeric_gap(self):
    """calculate_numeric_gap should emit one L1 gap row per mapped debug handle."""
    # Patch the helpers called by Inspector.__init__ so an instance can be
    # constructed for this test.
    with patch.object(
        _inspector, "parse_etrecord", return_value=None
    ), patch.object(
        _inspector, "gen_etdump_object", return_value=None
    ), patch.object(
        EventBlock, "_gen_from_etdump"
    ), patch.object(
        _inspector, "gen_graphs_from_etrecord"
    ):
        inspector = Inspector(etdump_path=ETDUMP_PATH, etrecord=ETRECORD_PATH)

        # Each runtime tensor differs from its AOT counterpart by exactly 1.0
        # in every element.
        aot_outputs = {
            (0,): torch.tensor([1.0, 2.0, 3.0]),
            (1,): torch.tensor([4.0, 5.0, 6.0]),
        }
        runtime_outputs = {
            (0,): torch.tensor([2.0, 1.0, 4.0]),
            (1,): torch.tensor([3.0, 6.0, 5.0]),
        }

        # Inject both sides directly instead of reading them from artifacts.
        inspector._aot_intermediate_outputs = aot_outputs
        inspector._get_runtime_intermediate_outputs = lambda: runtime_outputs

        gap_df = inspector.calculate_numeric_gap(distance="L1")

        self.assertIsInstance(gap_df, pd.DataFrame)
        self.assertEqual(len(gap_df), 2)
        self.assertEqual(
            set(gap_df.columns),
            {
                "aot_debug_handle",
                "aot_intermediate_output",
                "runtime_debug_handle",
                "runtime_intermediate_output",
                "gap",
            },
        )
        # Every AOT debug handle must show up exactly as it was provided.
        self.assertEqual(set(gap_df["aot_debug_handle"]), set(aot_outputs.keys()))

        for _, row in gap_df.iterrows():
            handle = row["aot_debug_handle"]
            # The AOT output in the row must match the injected AOT tensor.
            self.assertTrue(
                torch.allclose(row["aot_intermediate_output"], aot_outputs[handle])
            )
            # In this scenario the runtime debug handle equals the AOT one.
            self.assertEqual(row["runtime_debug_handle"], handle)
            # The runtime output in the row must match the injected runtime tensor.
            self.assertTrue(
                torch.allclose(
                    row["runtime_intermediate_output"], runtime_outputs[handle]
                )
            )
            # The test expects an L1 gap of 3.0 for each pair.
            self.assertEqual(row["gap"], 3.0)
651+
581652
def _gen_random_float_list(self) -> List[float]:
582653
return [random.uniform(0, 10) for _ in range(RAW_DATA_SIZE)]
583654

0 commit comments

Comments
 (0)