Commit 30ee8bd

digantdesai authored and facebook-github-bot committed
init
Summary: Just placeholder q/dq AoT ops in a new namespace, with a test.

Differential Revision: D72987759
1 parent fc01661 commit 30ee8bd

File tree

7 files changed: +413 −0 lines changed


backends/cortex_m/README.md

Lines changed: 3 additions & 0 deletions
# Cortex-M Backend

WIP. This is a temporary backend for Cortex-M CPUs. It is not intended to be used in production, but rather as a proof of concept. Things will change without notice.

backends/cortex_m/ops/TARGETS

Lines changed: 21 additions & 0 deletions
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")

oncall("executorch")

python_library(
    name = "ops",
    srcs = [
        "operators.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
    ],
)

backends/cortex_m/ops/operators.py

Lines changed: 90 additions & 0 deletions
import torch
from torch.library import impl, Library, register_fake

from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# quantize_per_tensor
###

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The implementation of the quantize_per_tensor operator is the same as the
    quantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)
lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    The implementation of the dequantize_per_tensor operator is the same as the
    dequantize_per_tensor operator in the edge dialect.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
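
Once operators.py is imported, the new ops should be callable through the torch.ops.cortex_m namespace like any other registered custom op. A minimal smoke-test sketch — the input tensor, scale, and zero point below are illustrative values, not taken from this commit:

import torch
import executorch.backends.cortex_m.ops.operators  # noqa -- registers the cortex_m ops

x = torch.randn(4)
# Quantize to int8 with an illustrative scale/zero-point pair.
q = torch.ops.cortex_m.quantize_per_tensor(x, 0.1, 0, -128, 127, torch.int8)
# Dequantize back to float; round-trip error is bounded by the scale.
xd = torch.ops.cortex_m.dequantize_per_tensor(q, 0.1, 0, -128, 127, torch.int8)
print(q.dtype, xd.dtype)  # torch.int8 torch.float32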

backends/cortex_m/passes/TARGETS

Lines changed: 15 additions & 0 deletions
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "cortex_m_passes",
    srcs = ["replace_quant_nodes_pass.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
backends/cortex_m/passes/replace_quant_nodes_pass.py

Lines changed: 65 additions & 0 deletions
from typing import Callable, Dict, Tuple

import torch

import executorch.backends.cortex_m.ops.operators  # noqa

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


class ReplaceQuantNodesPass(ExportPass):
    """
    Replace quantize and dequantize nodes with the corresponding
    cortex_m.quantize_per_tensor and cortex_m.dequantize_per_tensor nodes.
    """

    @staticmethod
    def is_qualified_quantize_per_tensor(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # output dtype
        )

    @staticmethod
    def is_qualified_dequantize_per_tensor(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # input dtype
        )

    def call_operator(
        self,
        op: Callable[..., object],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
        meta: NodeMetadata,
    ) -> ProxyValue:
        assert isinstance(
            op, EdgeOpOverload
        ), f"Op must be an EdgeOpOverload, got {type(op)} for op {op}. Try running this pass after to_edge()."

        if (
            op == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
            and self.is_qualified_quantize_per_tensor(args)
        ):
            return super().call_operator(
                exir_ops.edge.cortex_m.quantize_per_tensor.default,
                args,
                kwargs,
                meta,
            )
        elif (
            op == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            and self.is_qualified_dequantize_per_tensor(args)
        ):
            return super().call_operator(
                exir_ops.edge.cortex_m.dequantize_per_tensor.default,
                args,
                kwargs,
                meta,
            )
        # For all other operators, pass through unchanged
        else:
            return super().call_operator(op, args, kwargs, meta)
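
Since the pass asserts that every op it visits is an EdgeOpOverload, it is meant to run after to_edge(). A sketch of how it might be wired in — MyQuantizedModule and the example input are hypothetical placeholders, and a real flow would first produce quantized_decomposed q/dq nodes in the graph (e.g. via PT2E quantization from torchao):

import torch
from executorch.exir import to_edge
from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)

class MyQuantizedModule(torch.nn.Module):  # hypothetical stand-in
    def forward(self, x):
        return x + x

exported = torch.export.export(MyQuantizedModule(), (torch.randn(4),))
edge = to_edge(exported)
# Rewrite qualifying int8 q/dq nodes into the cortex_m namespace.
edge = edge.transform([ReplaceQuantNodesPass()])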

backends/cortex_m/test/TARGETS

Lines changed: 12 additions & 0 deletions
load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_unittest(
    name = "test_replace_quant_nodes",
    srcs = ["test_replace_quant_nodes.py"],
    deps = [
        "//pytorch/ao:torchao",  # @manual
        "//caffe2:torch",
        "//executorch/backends/cortex_m/passes:cortex_m_passes",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
