3 changes: 3 additions & 0 deletions backends/cortex_m/README.md
@@ -0,0 +1,3 @@
# Cortex-M Backend

WIP. This is a temporary/placeholder backend for Cortex-M CPUs. It is not intended for production use; rather, it is a proof of concept. Things will change without notice.
21 changes: 21 additions & 0 deletions backends/cortex_m/ops/TARGETS
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")

oncall("executorch")

python_library(
    name = "ops",
    srcs = [
        "operators.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
    ],
)
98 changes: 98 additions & 0 deletions backends/cortex_m/ops/operators.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# quantize_per_tensor
###

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The quantize_per_tensor operator shares its implementation with the
    quantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The dequantize_per_tensor operator shares its implementation with the
    dequantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
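For orientation, here is a minimal eager-mode sketch (not part of this diff) of how the registered ops can be exercised once this module has been imported; the tensor values and quantization parameters are illustrative assumptions:

import torch
import executorch.backends.cortex_m.ops.operators  # noqa -- registers the cortex_m library

x = torch.randn(4)
scale, zero_point = 0.05, 0
# Ops defined via Library("cortex_m", "DEF") become reachable under torch.ops.cortex_m
q = torch.ops.cortex_m.quantize_per_tensor(x, scale, zero_point, -128, 127, torch.int8)
d = torch.ops.cortex_m.dequantize_per_tensor(q, scale, zero_point, -128, 127, torch.int8)
# d approximates x up to quantization rounding error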
21 changes: 21 additions & 0 deletions backends/cortex_m/passes/TARGETS
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "replace_quant_nodes_pass",
    srcs = ["replace_quant_nodes_pass.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
62 changes: 62 additions & 0 deletions backends/cortex_m/passes/replace_quant_nodes_pass.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Tuple

import executorch.backends.cortex_m.ops.operators  # noqa
import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


class ReplaceQuantNodesPass(ExportPass):
    """
    Replace quantize and dequantize nodes with the corresponding
    cortex_m.quantize_per_tensor and cortex_m.dequantize_per_tensor nodes.
    """

    @staticmethod
    def _is_qualified_int8_node(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # dtype
        )

    def __init__(self):
        super().__init__()
        self.op_replacements = {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.quantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.dequantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
        }

    def call_operator(
        self,
        op: Callable[..., object],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
        meta: NodeMetadata,
    ) -> ProxyValue:
        assert isinstance(
            op, EdgeOpOverload
        ), "Op must be an EdgeOpOverload. Run this pass after to_edge()."

        if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
            return super().call_operator(
                self.op_replacements[op]["new_target"],
                args,
                kwargs,
                meta,
            )
        return super().call_operator(op, args, kwargs, meta)
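A hypothetical wiring sketch (not part of this diff) of where the pass is intended to sit in the lowering flow; MyQuantizedModule and example_inputs are placeholders, and the pass is assumed to run after to_edge() so that every op is an EdgeOpOverload:

import torch
from executorch.exir import to_edge
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import ReplaceQuantNodesPass

# MyQuantizedModule / example_inputs are placeholders for a quantized model
exported = torch.export.export(MyQuantizedModule().eval(), example_inputs)
edge = to_edge(exported)
# Swap qualifying int8 quant/dequant nodes for their cortex_m counterparts
edge = edge.transform([ReplaceQuantNodesPass()])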
18 changes: 18 additions & 0 deletions backends/cortex_m/test/TARGETS
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_unittest(
    name = "test_replace_quant_nodes",
    srcs = ["test_replace_quant_nodes.py"],
    deps = [
        "//pytorch/ao:torchao",  # @manual
        "//caffe2:torch",
        "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
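Assuming a Buck-based checkout, the unit test would typically be invoked with a command along these lines:

buck test //executorch/backends/cortex_m/test:test_replace_quant_nodes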