Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
from torch.fx import Node
from torch.fx.passes.infra.pass_base import PassResult

# Operator aliases for better readability.
AddMM = exir_ops.edge.aten.addmm.default
ViewCopy = exir_ops.edge.aten.view_copy.default
MM = exir_ops.edge.aten.mm.default


def insert_qdq_pair_after_node(
graph: torch.fx.Graph, anchor: torch.fx.Node, q_params: tuple
Expand Down Expand Up @@ -41,15 +46,17 @@ def insert_qdq_pair_after_node(

def _is_dequantize(node_: Node) -> bool:
return (
node_.op == "call_function"
hasattr(node_, "op")
and node_.op == "call_function"
and node_.target
== exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
)


def _is_quantize(node_: Node) -> bool:
return (
node_.op == "call_function"
hasattr(node_, "op")
and node_.op == "call_function"
and node_.target
== exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
)
Expand Down Expand Up @@ -82,20 +89,19 @@ class MoveLeadingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
"""

allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

# List of approved nodes to which the <aux_node> can be connected in order for the pass to make the modification.
allowed_main_cluster_nodes = [
exir_ops.edge.aten.addmm.default,
exir_ops.edge.aten.mm.default,
]
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
main_cluster_node_to_auxiliary_nodes = {
AddMM: [
ViewCopy,
],
MM: [
ViewCopy,
],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
for aux_node in graph_module.graph.nodes:
if (
aux_node.op != "call_function"
or aux_node.target not in self.allowed_auxiliary_nodes
):
if aux_node.op != "call_function":
continue

dequantize_node = aux_node.args[0]
Expand All @@ -109,11 +115,13 @@ def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
continue

main_cluster_node = users[0]
if (
main_cluster_node.op != "call_function"
or main_cluster_node.target not in self.allowed_main_cluster_nodes
if main_cluster_node.op != "call_function":
continue

if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get(
main_cluster_node.target, []
):
# Unsupported `main_cluster_node`.
# Unsupported main cluster node and auxiliary node pair.
continue

# Make sure the nodes are part of the same QDQ cluster.
Expand Down Expand Up @@ -163,29 +171,33 @@ class MoveTrailingAuxiliaryOperatorIntoSeparateQDQClusterPass(NeutronEdgePass):
"""

allowed_auxiliary_nodes = [exir_ops.edge.aten.view_copy.default]

# List of approved nodes to which the `<aux_node>` can be connected in order for the pass to make the modification.
allowed_main_cluster_nodes = [
exir_ops.edge.aten.addmm.default,
exir_ops.edge.aten.mm.default,
]
# Dictionary mapping main cluster nodes to auxiliary nodes, for which this optimization will be applied.
main_cluster_node_to_auxiliary_nodes = {
AddMM: [
ViewCopy,
],
MM: [
ViewCopy,
],
}

def run(self, graph_module: torch.fx.GraphModule) -> PassResult:

for aux_node in graph_module.graph.nodes:
if (
aux_node.op != "call_function"
or aux_node.target not in self.allowed_auxiliary_nodes
):
if aux_node.op != "call_function":
continue

main_cluster_node = aux_node.args[0]
if (
main_cluster_node.op != "call_function"
or main_cluster_node.target not in self.allowed_main_cluster_nodes
if not (
hasattr(main_cluster_node, "op")
and main_cluster_node.op == "call_function"
):
continue

if aux_node.target not in self.main_cluster_node_to_auxiliary_nodes.get(
main_cluster_node.target, []
):
# Unsupported `main_cluster_node`.
# Unsupported main cluster node and auxiliary node pair.
continue

users = list(aux_node.users.keys())
Expand Down
Loading