Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions devtools/inspector/_inspector_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@
from executorch.devtools.etdump.serialize import deserialize_from_etdump_flatcc
from executorch.devtools.etrecord import ETRecord

from executorch.exir.debug_handle_utils import (
DEBUG_HANDLE_KEY,
get_greatest_ancestor_node_identifier,
)

from executorch.exir.graph_module import bfs_trace_with_node_process

from tabulate import tabulate

from torch.export import ExportedProgram

FORWARD = "forward"
EDGE_DIALECT_GRAPH_KEY = "edge_dialect_graph_module"

Expand Down Expand Up @@ -888,3 +897,71 @@ def compare_intermediate_outputs(a: Any, b: Any, comparator) -> List[float]:
else:
# Raise an error if one is a sequence and the other is not
raise ValueError("Both inputs must be sequences or both must be non-sequences.")


def propagate_back_debug_handle(
    exported_program: ExportedProgram,
    exported_program_graph_id: int,
    edge_dialect_program: ExportedProgram,
) -> bool:
    """
    Propagate debug handles from the edge dialect program back to the exported
    program while maintaining the correctness of operator tracing.

    e.g.
    export program: op1 -> op2 -> op3
    edge dialect program: op1_0 -> op3_0 -> op3_1
    where op1_0 is from op1, op3_0 and op3_1 are from op3, and op2 is removed by
    the to_edge pipeline (e.g. RemoveNoopPass).

    Then the debug handle of op1 should be the same as op1_0, and the debug
    handle of op3 should be the same as op3_0 and op3_1. The debug handle of op2
    will be a non-existing debug handle in the edge dialect program so it can be
    skipped later.

    Args:
        exported_program: the program whose nodes receive debug handles.
        exported_program_graph_id: the unique id of ``exported_program.graph``,
            used to build node identifiers matching the from_node provenance
            recorded in the edge dialect program.
        edge_dialect_program: the lowered program whose nodes already carry
            debug handles.

    Returns:
        True if:
        a. every debug handle in the edge dialect program has a corresponding
           node in the exported program, and
        b. the exported program is the greatest ancestor of the edge dialect
           program.
        Otherwise False, in which case the exported program is left untouched.
    """

    # 1. Build a mapping from the identifier of an exported-program node to the
    # debug handle of its descendant in the edge dialect program, using the
    # edge nodes' debug handles and from_node provenance info.
    export_graph_node_id_to_debug_handle = {
        get_greatest_ancestor_node_identifier(node): node.meta[DEBUG_HANDLE_KEY]
        for node in edge_dialect_program.graph.nodes
        if node.op not in ("placeholder", "output")
    }

    # 2. Equip the exported program's nodes with debug handles via the mapping.
    # Number of exported-program nodes with a matching entry in the mapping.
    n_matched_node = 0

    # Debug handle assigned to nodes present in the exported program but absent
    # from the edge dialect program; ``default=0`` guards against an edge graph
    # with no compute nodes, where a bare max() would raise ValueError.
    debug_handle_for_removed_node = (
        max(export_graph_node_id_to_debug_handle.values(), default=0) + 1
    )

    def _find_n_match_node(node: torch.fx.Node) -> None:
        nonlocal n_matched_node
        # NOTE(review): this skips by node.name while the mapping above skips by
        # node.op — confirm skipping by name is intentional here.
        if node.name in ("output", "placeholder"):
            return
        node_id = f"{node.name}.{exported_program_graph_id}"
        if node_id in export_graph_node_id_to_debug_handle:
            n_matched_node += 1

    def _equip_debug_handle(node: torch.fx.Node) -> None:
        if node.name in ("output", "placeholder"):
            return
        node_id = f"{node.name}.{exported_program_graph_id}"
        if node_id in export_graph_node_id_to_debug_handle:
            node.meta[DEBUG_HANDLE_KEY] = export_graph_node_id_to_debug_handle[node_id]
        else:
            node.meta[DEBUG_HANDLE_KEY] = debug_handle_for_removed_node

    bfs_trace_with_node_process(exported_program.graph_module, _find_n_match_node)

    # If any edge-dialect debug handle lacks a corresponding node in the
    # exported program, the match failed; report it without mutating anything.
    if n_matched_node != len(export_graph_node_id_to_debug_handle):
        return False

    bfs_trace_with_node_process(exported_program.graph_module, _equip_debug_handle)
    return True
114 changes: 113 additions & 1 deletion devtools/inspector/tests/inspector_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
import unittest
from typing import Dict, Tuple

import torch
import executorch.exir.tests.models as models

import torch
from executorch.devtools import generate_etrecord, parse_etrecord

from executorch.devtools.debug_format.base_schema import (
Expand Down Expand Up @@ -41,9 +42,13 @@
map_runtime_aot_intermediate_outputs,
merge_runtime_overlapping_debug_handles,
NodeFilter,
propagate_back_debug_handle,
TimeScale,
)
from executorch.devtools.inspector.numerical_comparator import L1Comparator
from executorch.exir import to_edge
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
from torch.export import export


class TestInspectorUtils(unittest.TestCase):
Expand Down Expand Up @@ -583,6 +588,113 @@ def test_compare_intermediate_outputs_sequence_and_non_sequence(self):
with self.assertRaises(ValueError):
compare_intermediate_outputs(a, b, L1Comparator())

def test_equip_debug_handle_to_export_program_success(self):
    """Test that propagate_back_debug_handle returns True and properly equips debug handles."""
    model = models.FeedForwardBlock(5, 10)
    example_inputs = (torch.rand(5, 5),)

    # Export, record the graph id, then lower to the edge dialect.
    ep = export(model, example_inputs)
    graph_id = id(ep.graph)
    edge_ep = to_edge(ep).exported_program()

    # Propagation must succeed for a straight export -> to_edge pipeline.
    self.assertTrue(propagate_back_debug_handle(ep, graph_id, edge_ep))

    def collect_handles(program):
        # Gather the debug handle of every compute node, asserting each one
        # is present and non-None along the way.
        handles = []
        for node in program.graph.nodes:
            if node.op in ("placeholder", "output"):
                continue
            self.assertIn(DEBUG_HANDLE_KEY, node.meta)
            self.assertIsNotNone(node.meta[DEBUG_HANDLE_KEY])
            handles.append(node.meta[DEBUG_HANDLE_KEY])
        return handles

    export_handles = collect_handles(ep)
    edge_handles = collect_handles(edge_ep)

    # The 0th operator in the exported program (layer_norm) has been decomposed
    # into the 0th and 1st ops in the edge dialect graph (native_layer_norm and
    # getitem), so all three must share the same debug handle.
    self.assertEqual(export_handles[0], edge_handles[0])
    self.assertEqual(export_handles[0], edge_handles[1])

def test_equip_debug_handle_to_export_program_failure(self):
    """Test that propagate_back_debug_handle returns False when there's a mismatch."""
    model = models.FeedForwardBlock(5, 10)
    example_inputs = (torch.rand(5, 5),)

    # Lower one export of the model to the edge dialect...
    edge_ep = to_edge(export(model, example_inputs)).exported_program()

    # ...then re-export the model so node identifiers no longer line up with
    # the provenance recorded in the edge dialect program.
    reexported = export(model, example_inputs)
    reexport_graph_id = id(reexported.graph)

    # The mismatch must be detected and reported via a False return value.
    self.assertFalse(
        propagate_back_debug_handle(reexported, reexport_graph_id, edge_ep)
    )

def test_equip_debug_handle_to_export_program_op_to_be_removed_in_to_edge(self):
    """Test that propagate_back_debug_handle returns True and properly equips debug handles when an op is removed in to_edge"""

    class M(torch.nn.Module):
        """Simple model with an op (dtype-preserving .to) that to_edge removes."""

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            x = x + 1
            x = x.to(x.dtype)
            x = x + 1
            return x

    example_inputs = (torch.rand(5, 5),)
    ep = torch.export.export(M(), example_inputs)
    graph_id = id(ep.graph)
    edge_ep = to_edge(ep).exported_program()

    self.assertTrue(propagate_back_debug_handle(ep, graph_id, edge_ep))

    # Only the two add ops survive into the edge dialect program (handles 1 and
    # 2), so the removed op must receive the next free handle: 3.
    expected_handles = {"add": 1, "add_1": 2}
    removed_node_handle = 3

    for node in ep.graph.nodes:
        if node.name in expected_handles:
            self.assertEqual(
                node.meta[DEBUG_HANDLE_KEY], expected_handles[node.name]
            )
        elif node.op not in ("placeholder", "output"):
            self.assertEqual(node.meta[DEBUG_HANDLE_KEY], removed_node_handle)


def gen_mock_operator_graph_with_expected_map() -> (
Tuple[OperatorGraph, Dict[int, OperatorNode]]
Expand Down
8 changes: 8 additions & 0 deletions exir/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,11 @@ python_library(
"fbsource//third-party/pypi/typing-extensions:typing-extensions",
],
)

# Shared debug-handle helpers (constants and ancestor-identifier lookup) used by
# the debug handle generator pass and the inspector utilities.
python_library(
    name = "debug_handle_utils",
    srcs = ["debug_handle_utils.py"],
    deps = [
        "//caffe2:torch",
    ],
)
27 changes: 27 additions & 0 deletions exir/debug_handle_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torch.fx import Node

# node.meta key under which the export flow records a node's provenance chain.
FROM_NODE_KEY = "from_node"
# node.meta key under which ExecuTorch stores a node's debug handle.
DEBUG_HANDLE_KEY = "debug_handle"


def get_greatest_ancestor_node_identifier(node: Node) -> str:
    """Get the identifier of the greatest ancestor node of the given node.

    The identifier is the concatenation of the node name and graph id of the
    greatest ancestor node, where the graph id is the unique id for every graph
    module in the export flow and node name is unique within the same graph
    module.

    Args:
        node: an fx node whose ``meta`` carries a non-empty ``from_node``
            provenance chain.

    Returns:
        ``"<name>.<graph_id>"`` of the oldest ancestor in the chain.

    Raises:
        KeyError: if ``node.meta`` has no ``from_node`` entry.
    """

    # The last entry of each from_node list is the most recent source; walk
    # that chain until a source with no further provenance is reached.
    node_source = node.meta[FROM_NODE_KEY][-1]
    while len(node_source.from_node) > 0:
        node_source = node_source.from_node[-1]

    # graph_id is formatted by the f-string; an explicit str() is redundant.
    return f"{node_source.name}.{node_source.graph_id}"
1 change: 1 addition & 0 deletions exir/passes/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ python_library(
],
deps = [
"//caffe2:torch",
"//executorch/exir:debug_handle_utils",
"//executorch/exir:graph_module",
"//executorch/exir:pass_base",
],
Expand Down
26 changes: 6 additions & 20 deletions exir/passes/debug_handle_generator_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

from typing import Dict

from executorch.exir.debug_handle_utils import (
DEBUG_HANDLE_KEY,
FROM_NODE_KEY,
get_greatest_ancestor_node_identifier,
)
from executorch.exir.graph_module import bfs_trace_with_node_process
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
Expand All @@ -21,27 +26,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
greatest ancestor node in the export flow.
"""

FROM_NODE_KEY = "from_node"
DEBUG_HANDLE_KEY = "debug_handle"

source_node_id_to_debug_handle: Dict[str, int] = {}

def _get_greatest_ancestor_node_identifier(node: Node) -> str:
"""Get the identifier of the greatest ancestor node of the given node.

The identifier is the concatenation of the node name and graph id of the
greatest ancestor node, where the graph id is the unique id for every graph
module in the export flow and node name is unique within the same graph module.
"""

node_source = node.meta[FROM_NODE_KEY]
node_source = node_source[-1]

while len(node_source.from_node) > 0:
node_source = node_source.from_node[-1]

return node_source.name + str(node_source.graph_id)

def _extract_debug_handles_from_node(node: Node) -> None:
"""
Generate a debug handle based on node's oldest ancestor node's name
Expand All @@ -56,7 +42,7 @@ def _extract_debug_handles_from_node(node: Node) -> None:
FROM_NODE_KEY in node.meta
), f"Node {node} does not have meta key {FROM_NODE_KEY}"

greatest_ancestor_node_id = _get_greatest_ancestor_node_identifier(node)
greatest_ancestor_node_id = get_greatest_ancestor_node_identifier(node)

debug_handle = (
len(source_node_id_to_debug_handle) + 1
Expand Down
Loading