Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion exir/graph_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

from collections import deque
from types import FunctionType as function
from typing import Any, Callable, Dict, List, Tuple, Union

import torch

Expand Down Expand Up @@ -68,3 +68,25 @@ def get_control_flow_submodules(
control_flow_submodules.append(_get_submodule(graph_module, node, 0))

return control_flow_submodules

# TODO(gasoonjia): remove this and leverage core pytorch bfs_trace_with_node_process after code freeze
def bfs_trace_with_node_process(
    gm: torch.fx.GraphModule, node_op: Callable[[torch.fx.Node], None]
) -> None:
    """Breadth-first traversal of ``gm`` and all of its control-flow
    submodules, applying ``node_op`` to every node encountered.

    Args:
        gm: The graph module whose nodes (and the nodes of any control-flow
            submodules, e.g. cond/map branches) will be visited.
        node_op: Callback invoked once per node; it may mutate the node
            (for example its ``meta`` dict). Its return value is ignored.
    """

    assert isinstance(
        gm, torch.fx.GraphModule
    ), f"Expected GraphModule, got {type(gm)}"

    # deque gives O(1) popleft; list.pop(0) shifts the whole list (O(n)).
    queue = deque([gm])
    while queue:
        current_graph_module = queue.popleft()
        for node in current_graph_module.graph.nodes:
            node_op(node)

        # Enqueue submodules owned by control-flow ops so their nodes are
        # processed as well, preserving BFS order across module boundaries.
        queue.extend(
            submodule
            for _, submodule, _ in get_control_flow_submodules(current_graph_module)
        )
57 changes: 21 additions & 36 deletions exir/passes/debug_handle_generator_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,27 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.graph_module import bfs_trace_with_node_process
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


class DebugHandleGeneratorPass(ExportPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        """Attach a unique, monotonically increasing ``debug_handle`` id to
        every node in ``graph_module``, including nodes inside control-flow
        submodules.
        """

        index = 1

        def _assign_debug_handle(node: torch.fx.Node) -> None:
            # Runs once per node in BFS order, so ids start at 1 and increase
            # without gaps across submodule boundaries.
            nonlocal index
            node.meta["debug_handle"] = index
            index += 1

        bfs_trace_with_node_process(graph_module, _assign_debug_handle)

        return PassResult(graph_module, True)


def generate_missing_debug_handles(ep: ExportedProgram):
    """Assign fresh ``debug_handle`` ids to every node of ``ep.graph_module``
    (and its control-flow submodules) that is missing one, without disturbing
    handles that already exist.
    """

    max_handle = 0

    def _track_max_debug_handle(node: torch.fx.Node) -> None:
        # First pass: find the largest handle already in use so that newly
        # assigned ids never collide with existing ones.
        nonlocal max_handle
        if "debug_handle" in node.meta:
            max_handle = max(max_handle, node.meta["debug_handle"])

    def _fill_missing_debug_handles(node: torch.fx.Node) -> None:
        # Second pass: a missing, 0, or None handle is treated as
        # "unassigned" and receives the next id after the current maximum.
        nonlocal max_handle
        if node.meta.get("debug_handle", 0) in (0, None):
            max_handle += 1
            node.meta["debug_handle"] = max_handle

    bfs_trace_with_node_process(ep.graph_module, _track_max_debug_handle)
    bfs_trace_with_node_process(ep.graph_module, _fill_missing_debug_handles)
Loading