78 changes: 68 additions & 10 deletions exir/backend/canonical_partitioners/config_partitioner.py
@@ -10,16 +10,24 @@
import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_partitions_from_list_of_nodes,
generate_grouped_partitions_from_list_of_nodes,
)
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)

from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.fx.passes.infra.partitioner import Partition


def is_constant_data(ep: ExportedProgram, node: torch.fx.Node) -> bool:
return (
is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node)
)


def format_target_name(target_name: str) -> str:
"""
We remove the dialect name space from the target name. We generally
@@ -100,6 +108,35 @@ def get_partition(
pass


class DSJ:
"""
Disjoint set union data structure used to find connected components in the graph.
"""

def __init__(self):
self.parent = {}

def find(self, x):
self.parent.setdefault(x, x)
if self.parent[x] != x:
self.parent[x] = self.find(self.parent[x])
return self.parent[x]

def union(self, x, y):
self.parent[self.find(x)] = self.find(y)

def contains(self, x):
return x in self.parent

def gen_groups(self):
groups = {}
for node in self.parent.keys():
root = self.find(node)
groups.setdefault(root, set()).add(node)

return [list(group) for group in groups.values()]


class ConfigerationBasedPartitioner(Partitioner):
def __init__(
self,
@@ -162,23 +199,44 @@ def filter_fn(node: torch.fx.Node) -> bool:
def get_matched_nodes_from_configs(
self, ep: ExportedProgram
) -> List[List[torch.fx.Node]]:
# disjoint set union for merging partitions
dsj = DSJ()

# gather supported nodes
matched_nodes = []
gm = ep.graph_module
for node in gm.graph.nodes:
if node.op == "call_function":
target = format_target_name(node.target.__name__)
if target in self.target_partitioner_configs:
node_config = self.target_partitioner_configs[target]
if node_config.check_constraints(node, ep):
matched_nodes.append(node_config.get_partition(node, ep))
if node.op != "call_function":
continue
target = format_target_name(node.target.__name__)

if target not in self.target_partitioner_configs:
continue

node_config = self.target_partitioner_configs[target]
if not node_config.check_constraints(node, ep):
continue

partition_candidate = node_config.get_partition(node, ep)
partition = []
for candidate in partition_candidate:
# The partitioner infra copies constant data across partitions, so it
# is ok if this partition doesn't claim a constant that already
# belongs to another group.
if is_constant_data(ep, candidate) and dsj.contains(candidate):
continue
partition.append(candidate)

# Union overlaps into a single group
if len(partition) > 0:
dsj.find(partition[0])
for i in range(1, len(partition)):
dsj.union(partition[0], partition[i])

return matched_nodes
return dsj.gen_groups()

def generate_partitions(self, ep: ExportedProgram) -> List[Partition]:
matched_nodes = self.get_matched_nodes_from_configs(ep)
# create partitions
partitions = generate_partitions_from_list_of_nodes(
partitions = generate_grouped_partitions_from_list_of_nodes(
ep.graph_module,
matched_nodes,
)
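For reference, here is a minimal, self-contained sketch of how the DSJ above collapses overlapping matched partitions into a single group. The node names are illustrative strings; the real partitioner operates on torch.fx.Node objects, and the class body simply mirrors the one added in this diff.

class DSJ:
    """Disjoint set union with path compression, as in the diff above."""

    def __init__(self):
        self.parent = {}

    def find(self, x):
        # Register x on first use, then compress the path to the root.
        self.parent.setdefault(x, x)
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self, x, y):
        self.parent[self.find(x)] = self.find(y)

    def gen_groups(self):
        groups = {}
        for node in self.parent:
            groups.setdefault(self.find(node), set()).add(node)
        return [list(group) for group in groups.values()]

# Two matched partitions share "conv", so they merge into one group,
# while "add" stays in a group of its own.
dsj = DSJ()
for partition in [["conv", "bn"], ["conv", "relu"], ["add"]]:
    dsj.find(partition[0])  # registers single-node partitions too
    for i in range(1, len(partition)):
        dsj.union(partition[0], partition[i])

print(sorted(sorted(g) for g in dsj.gen_groups()))
# [['add'], ['bn', 'conv', 'relu']]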
60 changes: 44 additions & 16 deletions exir/backend/canonical_partitioners/group_partitioner.py
@@ -86,7 +86,7 @@ def __init__(
)
self.node_to_group = collections.defaultdict(int)
self.all_nodes_in_groups = set()
if node_groups:
if self.node_groups:
for i, group in enumerate(self.node_groups):
for node in group:
# Node is in multiple groups - not allowed
@@ -101,19 +101,25 @@ def _can_merge_partitions(self, p1, p2, partitions_by_id):
p2_nodes = set(partitions_by_id[p2].nodes.keys())
combined_nodes = p1_nodes.union(p2_nodes)

for node in combined_nodes:
# Get all downstream nodes that are not in the combined partition
external_downstreams = {
n
for n in self.dependency_viewer.downstreams_of(node)
if n not in combined_nodes
}
user_nodes = []
# Topologically, p2_nodes comes before p1_nodes, so we only
# need to check the downstream nodes of p2.
# Additionally, we don't need to check all of p2's downstream
# nodes; we only need to check the nodes directly outside of p2.
# Example:
# partition[a --> b --> c] --> d --> e --> f
# we don't need to check [d, e, f], only [d], because the
# downstream nodes of [d] already include [e, f]
for node in p2_nodes:
for user in node.users:
if user not in combined_nodes:
user_nodes.append(user)

for external_node in user_nodes:
# Check if any external downstream nodes have downstream nodes in the combined partition
for external_node in external_downstreams:
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
if any(n in combined_nodes for n in downstream_nodes):
return False
downstream_nodes = self.dependency_viewer.downstreams_of(external_node)
if any(n in combined_nodes for n in downstream_nodes):
return False

return True

@@ -133,13 +139,30 @@ def _process_node_groups(
if not self.node_groups:
return group_to_partition_id

for i, group in enumerate(self.node_groups):
# Create a partition for each group
processed_nodes = set()

# We have to create the partitions in reverse topological order,
# discovering the groups as we traverse backwards through the graph.
# This likely needs to be combined with _process_remaining_nodes.
# TODO: this currently doesn't work with _process_remaining_nodes, so
# if a user provides node groups together with an OperatorSupportBase
# filter, this will fail.
for node in reversed(self.graph_module.graph.nodes):
if node not in self.node_to_group:
continue

if node in processed_nodes:
continue

group_idx = self.node_to_group[node]
group = self.node_groups[group_idx]

# Create a partition for this group
partition_id = next(new_partition_id)
partition = Partition(id=partition_id, nodes=set())
partitions_by_id[partition_id] = partition
partitions_order[partition_id] = partition_id
group_to_partition_id[i] = partition_id
group_to_partition_id[group_idx] = partition_id

# Add all supported nodes from the group to the partition
for node in group:
@@ -164,6 +187,12 @@
partition_map[partition_id].add(target_id)
partition_map[partition_id].update(partition_map[target_id])

# All the nodes in the group have now been processed, so skip them
# if we encounter them again in our reverse topological iteration.
for node in group:
processed_nodes.add(node)

return group_to_partition_id

def _process_remaining_nodes(
@@ -209,7 +238,6 @@ def _merge_partitions(

# Set to track removed partitions from initial static list so we can skip them
already_merged = set()

# Try to merge each pair of partitions
for i, p1 in enumerate(partition_ids):
# Skip if this partition has been already merged
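The merge-legality check above can be illustrated outside torch.fx with a toy dependency helper. Everything below is illustrative: the users adjacency dict, the downstreams_of helper, and the node names stand in for the fx graph and its DependencyViewer; none of it is the real API.

from functools import lru_cache

# Toy user graph: a -> b -> c, b -> d, d -> c.
users = {"a": ["b"], "b": ["c", "d"], "c": [], "d": ["c"]}

@lru_cache(maxsize=None)
def downstreams_of(node):
    # Transitive closure over users, standing in for the DependencyViewer.
    out = set()
    for u in users[node]:
        out.add(u)
        out |= downstreams_of(u)
    return frozenset(out)

def can_merge(p1_nodes, p2_nodes):
    combined = p1_nodes | p2_nodes
    # Only the direct users of p2 that fall outside the merged set need
    # checking; their transitive downstreams cover everything further out.
    external = {u for n in p2_nodes for u in users[n] if u not in combined}
    # Merging is illegal if any external user feeds back into the merged
    # set, since fusing would then create a cycle.
    return not any(downstreams_of(e) & combined for e in external)

print(can_merge({"c"}, {"a", "b"}))  # False: d sits between b and c
print(can_merge({"d"}, {"a", "b"}))  # True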
48 changes: 48 additions & 0 deletions exir/backend/canonical_partitioners/pattern_op_partitioner.py
@@ -8,6 +8,10 @@
from typing import List, Optional

import torch

from executorch.exir.backend.canonical_partitioners.group_partitioner import (
GroupBasedPartitioner,
)
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.passes.operator_support import any_chain, OperatorSupportBase
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher
@@ -56,6 +60,50 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return partition_list


def generate_grouped_partitions_from_list_of_nodes(
graph_module: torch.fx.GraphModule,
pattern_list: Optional[List[List[torch.fx.Node]]] = None,
op_support: Optional[OperatorSupportBase] = None,
) -> List[Partition]:
final_op_support: Optional[OperatorSupportBase] = op_support

if pattern_list is not None:
# Tag all the nodes in these patterns
for node_list in pattern_list:
for node in node_list:
node.meta["match"] = True

class MatchTag(OperatorSupportBase):
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
return node.meta.get("match", False)

final_op_support = (
MatchTag()
if final_op_support is None
else any_chain(final_op_support, MatchTag())
)

assert (
final_op_support is not None
), "Did not give a pattern or OperatorSupportBase instance to partition with"

# Run the GroupBasedPartitioner to return the largest possible
# subgraphs containing the nodes with the tags
group_partitioner = GroupBasedPartitioner(
graph_module,
final_op_support,
node_groups=pattern_list,
allows_single_node_partition=True,
)
partition_list = group_partitioner.propose_partitions()

# Remove the metadata field we added
for partition in partition_list:
for node in partition.nodes:
node.meta.pop("match", False)
return partition_list


def generate_pattern_op_partitions(
graph_module: torch.fx.GraphModule,
patterns: Optional[List[torch.fx.Graph]] = None,
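A possible end-to-end usage sketch for the new entry point, assuming the import path introduced in this PR; the traced module and the single node group are illustrative, not taken from the test suite.

import torch
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
    generate_grouped_partitions_from_list_of_nodes,
)

class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return (x + y) * y

gm = torch.fx.symbolic_trace(AddMul())

# Ask for the add and mul call_function nodes to land in one partition.
group = [n for n in gm.graph.nodes if n.op == "call_function"]
partitions = generate_grouped_partitions_from_list_of_nodes(gm, [group])

for p in partitions:
    print(p.id, sorted(n.name for n in p.nodes))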