Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions exir/backend/backend_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
import logging
from contextlib import contextmanager, nullcontext
from functools import singledispatch
from typing import Generator, List
from typing import Generator, List, Optional

import torch

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.backend_details import (
BackendDetails,
PreprocessResult,
)
from executorch.exir._serialize._named_data_store import NamedDataStore
from executorch.exir.backend.compile_spec_schema import CompileSpec

from executorch.exir.backend.partitioner import Partitioner, PartitionResult
Expand Down Expand Up @@ -120,6 +124,7 @@ def to_backend(
backend_id=backend_id,
processed_bytes=preprocess_result.processed_bytes,
compile_specs=compile_specs,
named_data_store_output=preprocess_result.data_store_output
)
lowered_module.meta = {
"debug_handle_map": preprocess_result.debug_handle_map
Expand Down
6 changes: 6 additions & 0 deletions exir/backend/backend_details.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput


def enforcedmethod(func):
Expand All @@ -24,6 +25,11 @@ class PreprocessResult:
debug_handle_map: Optional[Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]] = (
None
)
# Data Store output created from NamedDataStore.

# Named Data store contains all the named data that is stored in the PTE file,
# but retrieveable by delegates via the NamedDataMap at runtime.
data_store_output: Optional[NamedDataStoreOutput] = None


"""
Expand Down
56 changes: 56 additions & 0 deletions exir/backend/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,62 @@ python_library(
],
)

python_library(
name = "backend_with_named_data_map",
srcs = [
"backend_with_named_data_map.py",
],
visibility = [
"//executorch/...",
"//executorch/test/...",
],
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch_src",
"//executorch/exir:delegate",
"//executorch/exir:graph_module",
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/dialects:lib",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
"//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib",
],
)

python_unittest(
name = "test_backend_with_named_data_map",
srcs = [
"test_backend_with_named_data_map.py",
],
visibility = [
"//executorch/...",
"//executorch/test/...",
],
deps = [
"//caffe2:torch",
"//caffe2/functorch:functorch_src",
"//executorch/exir:delegate",
"//executorch/exir:graph_module",
"//executorch/exir:lib",
"//executorch/exir:lowered_backend_module",
"//executorch/exir:print_program",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/backend:partitioner",
"//executorch/exir/dialects:lib",
"//executorch/extension/pybindings:portable_lib", # @manual
"//executorch/extension/pytree:pylib",
":backend_with_named_data_map",
],
)

python_library(
name = "qnn_backend_demo",
srcs = [
Expand Down
114 changes: 114 additions & 0 deletions exir/backend/test/backend_with_named_data_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import final, List, Dict, Tuple

import torch

from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_pattern_op_partitions,
)

from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export import ExportedProgram
from torch.fx.passes.operator_support import OperatorSupportBase
from executorch.exir._serialize._named_data_store import NamedDataStore


# Backend details are final (cannot be subclassed).
@final
class BackendWithNamedDataMap(BackendDetails):
"""
Test Backend for Named Data Map Functionality

This backend returns no processed_bytes, instead it uses
the named data store and serializes the name of the op
as the key and the data as its code value
"""

@staticmethod
def preprocess(
edge_program: ExportedProgram,
compile_specs: List[CompileSpec],
) -> PreprocessResult:
op_codes = {
exir_ops.edge.aten.sin.default: 0,
exir_ops.edge.aten.add.Tensor: 1,
exir_ops.edge.aten.sub.Tensor: 2,
exir_ops.edge.aten.mul.Tensor: 3,
exir_ops.edge.aten.div.Tensor: 4
}
ndm = NamedDataStore()
for node in edge_program.graph.nodes:
if node.op == "call_function":
if node.target in op_codes.keys():
ndm.add_named_data(node.target.__name__, bytes(op_codes[node.target]))


return PreprocessResult(
processed_bytes=bytes(b""),
debug_handle_map={},
data_store_output=ndm.get_named_data_store_output(),
)

class SimpleOperatorSupport(OperatorSupportBase):
def is_node_supported(self, submodules, node:torch.fx.Node) -> bool:
return node.op == "call_function" and node.target in [
exir_ops.edge.aten.sin.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.div.Tensor
]

@final
class BackendWithNDMPartitioner(Partitioner):
def __init__(self) -> None:
self._op_support = SimpleOperatorSupport()
self.backend_id = BackendWithNamedDataMap.__name__

def _partition_gm(self, graph_module: torch.fx.GraphModule, id_start:int = 0) -> Tuple[int, Dict[str, DelegationSpec]]:
partition_tags: Dict[str, DelegationSpec] = {}
partition_list = generate_pattern_op_partitions(
graph_module, op_support=self._op_support
)

num_partitions_in_gm = len(partition_list)
for partition in partition_list:
curr_par_id = partition.id or 0
delegation_tag =f"tag_{curr_par_id + id_start}"
for node in partition.nodes:
node.meta["delegation_tag"] = delegation_tag
delegation_spec = DelegationSpec(self.backend_id, [])
partition_tags[delegation_tag] = delegation_spec

start_idx_for_submodules = num_partitions_in_gm
for _, submodule, _ in get_control_flow_submodules(graph_module):
start_idx_for_submodules, ret_partition_tags = self._partition_gm(
submodule, start_idx_for_submodules
)
partition_tags.update(ret_partition_tags)


return start_idx_for_submodules, partition_tags

def partition(self, edge_program: ExportedProgram) -> PartitionResult:
_, partition_tags = self._partition_gm(edge_program.graph_module)
return PartitionResult(
tagged_exported_program=edge_program,
partition_tags=partition_tags,
)
85 changes: 85 additions & 0 deletions exir/backend/test/test_backend_with_named_data_map.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.exir import to_edge

from executorch.exir.backend.test.backend_with_named_data_map import (
BackendWithNamedDataMap,
BackendWithNDMPartitioner
)
from executorch.exir.backend.backend_api import to_backend

from torch.testing import FileCheck
from torch.export.exported_program import ExportedProgram

class TestBackendWithNamedDataMap(unittest.TestCase):
def test_lowered_backend_module_has_output(self):
class M(torch.nn.Module):
def forward(self, x):
return x + x

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
lowered = to_backend(
BackendWithNamedDataMap.__name__, ep.exported_program(), []
)

buffer_entries = lowered.named_data_store_output.buffers
self.assertTrue(len(buffer_entries) == 1)
stored_data = lowered.named_data_store_output.pte_data

self.assertTrue("aten.add.Tensor" in stored_data)
self.assertTrue(buffer_entries[0].buffer == bytes(1))

def test_named_data_with_partitioner(self):
class M(torch.nn.Module):
def forward(self, x):
y = x + x
y = torch.cos(y)
y = y + y
y = torch.sin(y)
return y - y

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2),)))
ep.to_backend(BackendWithNDMPartitioner())

ndm_output = ep._named_data_store.get_named_data_store_output()
buffer_entries = ndm_output.buffers
stored_data =ndm_output.pte_data
self.assertEqual(len(buffer_entries), 3)
self.assertTrue("aten.add.Tensor" in stored_data)
self.assertTrue("aten.sub.Tensor" in stored_data)
self.assertTrue("aten.sin.default" in stored_data)

def test_named_data_with_control_flow(self):
class M(torch.nn.Module):
def true_branch(self, x):
y = x * x
y = torch.cos(y)
return torch.sin(y)

def false_branch(self, x):
return torch.sin(x)

def forward(self, x, y):
z = x/y
z = torch.cond(z > 1, self.true_branch, self.false_branch, [x])
return z - z

ep = to_edge(torch.export.export(M(), (torch.randn(1, 2), torch.randn(1, 2))))
ep.to_backend(BackendWithNDMPartitioner())

ndm_output = ep._named_data_store.get_named_data_store_output()
buffer_entries = ndm_output.buffers
stored_data =ndm_output.pte_data
self.assertEqual(len(buffer_entries), 4)
self.assertTrue("aten.sub.Tensor" in stored_data)
self.assertTrue("aten.div.Tensor" in stored_data)
self.assertTrue("aten.sin.default" in stored_data)
self.assertTrue("aten.mul.Tensor" in stored_data)
12 changes: 12 additions & 0 deletions exir/lowered_backend_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
from executorch.exir.schema import Program
from executorch.exir._serialize._named_data_store import NamedDataStoreOutput

from executorch.exir.tracer import Value
from torch._library.fake_class_registry import FakeScriptObject
Expand Down Expand Up @@ -62,19 +63,22 @@ class LoweredBackendModule(torch.nn.Module):
CompileSpec
] # A list of backend-specific objects with static metadata to configure the "compilation" process.
_original_exported_program: ExportedProgram # The original EXIR module
_named_data_store_output: Optional[NamedDataStoreOutput] # Named Data serialized by the backend

def __init__(
self,
edge_program: ExportedProgram,
backend_id: str,
processed_bytes: bytes,
compile_specs: List[CompileSpec],
named_data_store_output: Optional[NamedDataStoreOutput] = None,
) -> None:
super().__init__()
self._original_exported_program = edge_program
self._backend_id = backend_id
self._processed_bytes = processed_bytes
self._compile_specs = compile_specs
self._named_data_store_output = named_data_store_output

# pyre-ignore
def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "LoweredBackendModule":
Expand Down Expand Up @@ -133,6 +137,13 @@ def original_module(self) -> ExportedProgram:
Returns the original EXIR module
"""
return self._original_exported_program

@property
def named_data_store_output(self) -> Optional[NamedDataStoreOutput]:
"""
Returns the Named Data Store Output
"""
return self._named_data_store_output

# TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api
def buffer(
Expand All @@ -154,6 +165,7 @@ def buffer(
segment_alignment=segment_alignment,
constant_tensor_alignment=constant_tensor_alignment,
delegate_alignment=delegate_alignment,
named_data=self.named_data_store_output,
)
)
return out
Expand Down
Loading
Loading