25 changes: 16 additions & 9 deletions devtools/inspector/_inspector.py
@@ -658,7 +658,7 @@ def _populate_debugging_related_fields(
 
     def _associate_with_op_graph_nodes(
         self,
-        debug_handle_to_op_node_map: Dict[int, OperatorNode],
+        debug_handle_to_op_node_map: Dict[int, List[OperatorNode]],
    ) -> None:
        """
        Helper function to populate the stack_traces, module_hierarchy and op_types attributes
@@ -676,14 +676,21 @@ def _associate_with_op_graph_nodes(
             debug_handles = [debug_handles]
 
         for handle in debug_handles:
-            node = debug_handle_to_op_node_map.get(handle)
-            # Attach node metadata including stack traces, module hierarchy and op_types to this event
-            if node is not None and (metadata := node.metadata) is not None:
-                self.stack_traces[node.name] = metadata.get("stack_trace")
-                self.module_hierarchy[node.name] = metadata.get("nn_module_stack")
-                if node.op:
-                    # TODO: consider having this as a dict from node.name -> node.op
-                    self.op_types += [node.op]
+            nodes = debug_handle_to_op_node_map.get(handle, None)
+            if nodes is None:
+                continue
+
+            for node in nodes:
+                # Attach node metadata including stack traces, module hierarchy and op_types to this event
+                if node is not None and (metadata := node.metadata) is not None:
+                    if node.name not in self.stack_traces:
+                        self.stack_traces[node.name] = metadata.get("stack_trace")
+                        self.module_hierarchy[node.name] = metadata.get(
+                            "nn_module_stack"
+                        )
+                    if node.op:
+                        # TODO: consider having this as a dict from node.name -> node.op
+                        self.op_types += [node.op]
 
 
 @dataclass
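To make the new one-to-many behavior concrete, here is a self-contained sketch of the association loop above. OperatorNode is a simplified stand-in for the real class, and associate() only mirrors Event._associate_with_op_graph_nodes; any name or field beyond those visible in the diff is illustrative.

from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class OperatorNode:  # simplified stand-in for the real OperatorNode
    name: str
    metadata: Optional[dict] = None
    op: Optional[str] = None

def associate(debug_handles, node_map: Dict[int, List[OperatorNode]]):
    """Every node sharing a handle contributes metadata; the first
    entry per node name wins, matching the new `not in` guard."""
    stack_traces, module_hierarchy, op_types = {}, {}, []
    for handle in debug_handles:
        for node in node_map.get(handle, []):
            if node.metadata is None:
                continue
            if node.name not in stack_traces:
                stack_traces[node.name] = node.metadata.get("stack_trace")
                module_hierarchy[node.name] = node.metadata.get("nn_module_stack")
            if node.op:
                op_types.append(node.op)
    return stack_traces, module_hierarchy, op_types

# One exported-graph handle can now cover two to_edge nodes:
conv = OperatorNode("conv", {"stack_trace": "model.py:10"}, "aten.convolution")
relu = OperatorNode("relu", {"stack_trace": "model.py:10"}, "aten.relu")
print(associate([7], {7: [conv, relu]}))

The name-based guard is what keeps a node from being recorded twice when it appears under several handles.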
28 changes: 18 additions & 10 deletions devtools/inspector/_inspector_utils.py
@@ -303,14 +303,23 @@ def gen_graphs_from_etrecord(
     return op_graph_map
 
 
+# One debug handle should be associated with only one node. We are in the middle of migrating debug handle generation
+# from the graph after to_edge to the graph after torch.export, so every debug handle in the exported graph may be
+# associated with multiple nodes in the to_edge graph. Once the migration is complete, we should bring back the original return type as well as the #node check.
+#
+# Before migration: returned Dict maps 1 debug handle to 1 node in the to_edge graph
+# During migration: returned Dict maps 1 debug handle to multiple nodes in the to_edge graph
+# After migration: returned Dict maps 1 debug handle to 1 node in the exported graph
+#
+# TODO(gasoonjia): recover the return type to Dict[int, OperatorNode], re-enable the #node check.
 def create_debug_handle_to_op_node_mapping(
     op_graph: OperatorGraph,
-) -> Dict[int, OperatorNode]:
+) -> Dict[int, List[OperatorNode]]:
     """
     Recursive function to traverse all the operator graph nodes of input op_graph and build a mapping
     from each debug handle to the operator node that contains the debug handle in its metadata.
     """
-    debug_handle_to_op_node_map: Dict[int, OperatorNode] = {}
+    debug_handle_to_op_node_map: Dict[int, List[OperatorNode]] = {}
 
     # Recursively searches through the metadata of nodes
     def _extract_debug_handles(graph: OperatorGraph):
@@ -320,14 +329,13 @@ def _extract_debug_handles(graph: OperatorGraph):
             if isinstance(element, OperatorNode) and element.metadata is not None:
                 metadata = element.metadata
                 debug_handle = metadata.get("debug_handle")
-                if debug_handle is not None:
-                    existing_entry = debug_handle_to_op_node_map.get(debug_handle)
-                    if existing_entry is not None:
-                        raise ValueError(
-                            f"Duplicated debug handle {str(debug_handle)} shared between {element.name} and {existing_entry.name}. "
-                            "No two op nodes of the same graph should have the same debug handle."
-                        )
-                    debug_handle_to_op_node_map[debug_handle] = element
+                if debug_handle is None:
+                    continue
+
+                if debug_handle not in debug_handle_to_op_node_map:
+                    debug_handle_to_op_node_map[debug_handle] = []
+
+                debug_handle_to_op_node_map[debug_handle].append(element)
 
     # Start traversing
     _extract_debug_handles(op_graph)
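In isolation, the semantic change is that a duplicated handle used to raise ValueError and now accumulates. A minimal sketch under assumed stand-in types (the real traversal also recurses into nested OperatorGraphs):

from typing import Dict, List

class Node:  # stand-in carrying only the fields the mapping reads
    def __init__(self, name: str, handle: int):
        self.name = name
        self.metadata = {"debug_handle": handle}

def build_map(elements) -> Dict[int, List[Node]]:
    mapping: Dict[int, List[Node]] = {}
    for element in elements:
        handle = (element.metadata or {}).get("debug_handle")
        if handle is None:
            continue
        # setdefault collapses the diff's membership check and list init
        mapping.setdefault(handle, []).append(element)
    return mapping

m = build_map([Node("conv", 7), Node("relu", 7)])
assert [n.name for n in m[7]] == ["conv", "relu"]  # previously a ValueError

setdefault is behaviorally identical to the spelled-out membership check in the diff; the explicit version arguably reads more clearly in review.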
15 changes: 13 additions & 2 deletions devtools/inspector/tests/inspector_test.py
@@ -183,7 +183,11 @@ def test_inspector_associate_with_op_graph_nodes_single_debug_handle(self):
 
         # Call the method that's under testing and verify
         event_with_single_debug_handle._associate_with_op_graph_nodes(
-            {debug_handle: node_0}
+            {
+                debug_handle: [
+                    node_0,
+                ]
+            }
         )
 
         expected_stack_traces = {"node_0": "stack_trace_relu"}
@@ -226,7 +230,14 @@ def test_inspector_associate_with_op_graph_nodes_multiple_debug_handles(self):
 
         # Call the method that's under testing and verify
         event_with_multiple_debug_handles._associate_with_op_graph_nodes(
-            {debug_handles[0]: node_0, debug_handles[1]: node_1}
+            {
+                debug_handles[0]: [
+                    node_0,
+                ],
+                debug_handles[1]: [
+                    node_1,
+                ],
+            }
         )
 
         expected_stack_traces = {
53 changes: 22 additions & 31 deletions devtools/inspector/tests/inspector_test_utils.py
@@ -62,25 +62,17 @@ def get_expected_intermediate_outputs():
         Returns the expected outputs of the debug handles and intermediate output mapping for this model for the given input.
         """
         return {
-            (10,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
-            (11,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
-            (12,): torch.tensor(
-                [
-                    [0.1000, 0.5000],
-                    [0.2000, 0.6000],
-                    [0.3000, 0.7000],
-                    [0.4000, 0.8000],
-                ]
-            ),
-            (13,): torch.tensor([[5.0000, 14.1200]]),
-            (14,): torch.tensor([[5.5000, 13.6200]]),
-            (15,): torch.tensor([[5.4000, 13.5200]]),
-            (16,): torch.tensor([[10.8000, 6.7600]]),
-            (17,): torch.tensor([3.0000, 1.5000]),
-            (18,): torch.tensor([[3.6000, 4.5067]]),
-            (19,): torch.tensor([[3.6000, 4.5067]]),
-            (20,): torch.tensor([[0.9734, 0.9891]]),
-            (21,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
+            (1,): torch.tensor([[[[7.7000, 6.7000], [4.7000, 3.7000]]]]),
+            (2,): torch.tensor([[7.7000, 6.7000, 4.7000, 3.7000]]),
+            (3,): torch.tensor([[5.0000, 14.1200]]),
+            (4,): torch.tensor([[5.5000, 13.6200]]),
+            (5,): torch.tensor([[5.4000, 13.5200]]),
+            (6,): torch.tensor([[10.8000, 6.7600]]),
+            (7,): torch.tensor([3.0000, 1.5000]),
+            (8,): torch.tensor([[3.6000, 4.5067]]),
+            (9,): torch.tensor([[3.6000, 4.5067]]),
+            (10,): torch.tensor([[0.9734, 0.9891]]),
+            (11,): [torch.tensor([[0.9734]]), torch.tensor([[0.9891]])],
         }
 
     @staticmethod
@@ -89,18 +81,17 @@ def get_expected_debug_handle_to_op_name():
         Returns the expected debug handle and op name mapping for this model for the given input.
         """
         return {
-            (10,): "aten_convolution_default",
-            (11,): "aten_view_copy_default",
-            (12,): "aten_permute_copy_default",
-            (13,): "aten_addmm_default",
-            (14,): "aten_add_tensor",
-            (15,): "aten_sub_tensor",
-            (16,): "aten_mul_tensor",
-            (17,): "aten_add_tensor_1",
-            (18,): "aten_div_tensor",
-            (19,): "aten_relu_default",
-            (20,): "aten_sigmoid_default",
-            (21,): "aten_split_with_sizes_copy_default",
+            (1,): "aten_convolution_default",
+            (2,): "aten_view_copy_default",
+            (3,): "aten_addmm_default",
+            (4,): "aten_add_tensor",
+            (5,): "aten_sub_tensor",
+            (6,): "aten_mul_tensor",
+            (7,): "aten_add_tensor_1",
+            (8,): "aten_div_tensor",
+            (9,): "aten_relu_default",
+            (10,): "aten_sigmoid_default",
+            (11,): "aten_split_with_sizes_copy_default",
         }
 
 
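Reviewer note, inferred from the renumbering above rather than stated in the PR: the expected handles shift from (10,)–(21,) down to (1,)–(11,) because handles are now generated on the torch.export graph, where numbering starts at 1 and placeholder/output nodes no longer consume handles. The standalone aten_permute_copy_default entry disappears for the same reason: that permute only appears when linear is decomposed during to_edge, so it no longer earns its own exported-graph handle.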
16 changes: 12 additions & 4 deletions devtools/inspector/tests/inspector_utils_test.py
@@ -583,7 +583,9 @@ def gen_mock_operator_graph_with_expected_map() -> (
             "nn_module_stack": "module_hierarchy_relu",
         },
     )
-    mapping[111] = node_fused_conv_relu
+    mapping[111] = [
+        node_fused_conv_relu,
+    ]
     node_sin = OperatorNode(
         "sin",
         [node_fused_conv_relu],
@@ -594,7 +596,9 @@
             "nn_module_stack": "module_hierarchy_sin",
         },
     )
-    mapping[222] = node_sin
+    mapping[222] = [
+        node_sin,
+    ]
     node_cos = OperatorNode(
         "cos",
         [node_sin],
@@ -605,7 +609,9 @@
             "nn_module_stack": "module_hierarchy_cos",
         },
     )
-    mapping[333] = node_cos
+    mapping[333] = [
+        node_cos,
+    ]
     node_div = OperatorNode(
         "div",
         [node_cos],
@@ -616,7 +622,9 @@
             "nn_module_stack": "module_hierarchy_div",
         },
     )
-    mapping[444] = node_div
+    mapping[444] = [
+        node_div,
+    ]
     node_output = ValueNode("output", [node_div])
     return (
         OperatorGraph(
4 changes: 3 additions & 1 deletion exir/backend/test/qnn_backend_demo.py
@@ -24,7 +24,9 @@ def preprocess(
     ) -> PreprocessResult:
         processed_bytes = "imqnncompiled"
         all_nodes_debug_handle = [
-            node.meta["debug_handle"] for node in edge_program.graph.nodes
+            node.meta["debug_handle"]
+            for node in edge_program.graph.nodes
+            if node.op not in ("placeholder", "output")
         ]
         return PreprocessResult(
             processed_bytes=bytes(processed_bytes, encoding="utf8"),
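The new filter leans on the standard torch.fx node kinds: placeholders (graph inputs) and the output node are bookkeeping rather than computation, and under the new scheme they no longer carry debug handles. A quick way to inspect this on any exported module (illustrative; whether debug_handle is populated depends on which ExecuTorch lowering steps have run):

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

ep = export(M(), (torch.randn(2),))
for node in ep.graph.nodes:
    # node.op is one of: placeholder, get_attr, call_function,
    # call_method, call_module, output
    print(node.op, node.name, node.meta.get("debug_handle"))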
6 changes: 3 additions & 3 deletions exir/backend/test/test_backends.py
@@ -194,7 +194,7 @@ def forward(self, x):
             program=program,
             delegate=program.execution_plan[0].delegates[0],
             expected_id=BackendWithCompilerDemo.__name__,
-            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
         )
 
         # Check the delegate instruction
@@ -414,7 +414,7 @@ def forward(self, x):
             program=program,
             delegate=program.execution_plan[0].delegates[0],
             expected_id=BackendWithCompilerDemo.__name__,
-            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
         )
 
         # Check the delegate instruction
@@ -1320,7 +1320,7 @@ def forward(self, x):
             program=program,
             delegate=program.execution_plan[0].delegates[0],
             expected_id=BackendWithCompilerDemo.__name__,
-            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
         )
 
         # Check the delegate instruction
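The expected payload constant changes from <debug_handle>2# to <debug_handle>1# in all three call sites; presumably the sin node is now the first handle-carrying node on the exported graph, since handles start at 1 and the input placeholder no longer counts. The same constant is updated in test_backends_lifted.py below.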
4 changes: 2 additions & 2 deletions exir/backend/test/test_backends_lifted.py
@@ -227,7 +227,7 @@ def forward(self, x):
             program=program,
             delegate=program.execution_plan[0].delegates[0],
             expected_id=BackendWithCompilerDemo.__name__,
-            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
         )
 
         # Check the delegate instruction
@@ -437,7 +437,7 @@ def forward(self, x):
             program=program,
             delegate=program.execution_plan[0].delegates[0],
             expected_id=BackendWithCompilerDemo.__name__,
-            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>2#",
+            expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float32<debug_handle>1#",
         )
 
         # Check the delegate instruction
16 changes: 14 additions & 2 deletions exir/backend/test/test_debug_handle_map.py
@@ -97,7 +97,13 @@ def test_lowered_the_whole_model(self, unlift):
         all_debug_handles = list(lowered_model.meta["debug_handle_map"].values())[0]
         self.assertEqual(
             len(all_debug_handles),
-            len(lowered_model.original_module.graph.nodes),
+            len(
+                [
+                    node
+                    for node in lowered_model.original_module.graph.nodes
+                    if node.op not in ("placeholder", "output")
+                ]
+            ),
         )
 
     class ComposedModel(torch.nn.Module):
@@ -127,5 +133,11 @@ def forward(self, *args):
         )[0]
         self.assertEqual(
             len(all_debug_handles),
-            len(lowered_node.original_module.graph.nodes),
+            len(
+                [
+                    node
+                    for node in lowered_node.original_module.graph.nodes
+                    if node.op not in ("placeholder", "output")
+                ]
+            ),
         )
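A possible follow-up for the two counting asserts above (a reviewer suggestion, not part of this PR): summing booleans counts the matches without materializing an intermediate list. Variable names follow the first assert and are illustrative.

non_io_node_count = sum(
    node.op not in ("placeholder", "output")
    for node in lowered_model.original_module.graph.nodes
)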