Commit ee54d95

Merge branch 'main' into pt/py313
2 parents cb0ecb4 + aea2784 commit ee54d95

File tree

11 files changed: +438 −27 lines changed


backends/cadence/hifi/operators/op_quantized_linear_out.cpp

Lines changed: 33 additions & 2 deletions
@@ -9,6 +9,7 @@
 #include <executorch/backends/cadence/hifi/kernels/kernels.h>
 #include <executorch/backends/cadence/hifi/operators/operators.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_linear.h>
 #include <xa_nnlib_kernels_api.h>
 #include <xtensa/tie/xt_datacache.h>
 #include <algorithm>
@@ -218,7 +219,22 @@ void quantized_linear_out(
     int64_t out_zero_point,
     __ET_UNUSED const optional<Tensor>& offset,
     Tensor& out) {
-  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+  if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
+      in.scalar_type() == ::executorch::aten::ScalarType::Short &&
+      weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
+    ::impl::generic::native::quantized_linear_out(
+        ctx,
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
     _quantized_linear_asym8u(
         in,
         weight,
@@ -260,7 +276,22 @@ void quantized_linear_per_tensor_out(
     int64_t out_zero_point,
     __ET_UNUSED const optional<Tensor>& offset,
     Tensor& out) {
-  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+  if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
+      in.scalar_type() == ::executorch::aten::ScalarType::Short &&
+      weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
+    ::impl::generic::native::quantized_linear_per_tensor_out(
+        ctx,
+        in,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
     _quantized_linear_per_tensor_asym8u(
         in,
         weight,
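Both dispatch branches implement the same affine-quantization math: subtract the zero points, accumulate the products in a wide integer type, add the bias, then requantize through a fixed-point multiplier and shift. A minimal NumPy sketch of that reference computation, as an illustration only (the function name and the exact rounding behavior of the shift are assumptions, not the HiFi kernel source):

import numpy as np

def quantized_linear_ref(x, w, bias, x_zp, w_zp, out_mult, out_shift, out_zp):
    # Zero-point-corrected matmul, accumulated in int64 to avoid overflow.
    acc = (x.astype(np.int64) - x_zp) @ (w.astype(np.int64) - w_zp).T + bias
    # Requantize: fixed-point multiply by a Q31 multiplier, then shift.
    scaled = (acc * out_mult) >> 31
    scaled = scaled >> -out_shift if out_shift < 0 else scaled << out_shift
    # Re-center on the output zero point and saturate to the int16 range.
    return np.clip(scaled + out_zp, -32768, 32767).astype(np.int16)

With the values from the test below (out_mult = 2011373824, out_shift = -8), the effective output scale is roughly (2011373824 / 2^31) * 2^-8 ≈ 0.00366; exact results may differ from the kernel by rounding.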

backends/cadence/hifi/operators/targets.bzl

Lines changed: 4 additions & 1 deletion
@@ -87,7 +87,6 @@ OPERATORS = [
     "quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out",
     "quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out",
     "quantized_layer_norm",
-    "quantized_linear_out",
     "quantized_linear_asym8sxasym8s_asym8s_per_tensor_out",
     "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out",
     "quantized_matmul_out",
@@ -122,3 +121,7 @@ def define_common_targets():
     # Define build targets for all operators registered in the tables above.
     for op in OPERATORS:
         define_operator(op)
+
+    # quantized_linear_out and quantized_linear_per_tensor_out need an additional dependency for int16 support.
+    define_operator("quantized_linear_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])
+    define_operator("quantized_linear_per_tensor_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+#include <sys/times.h>
+
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/platform/runtime.h>
+
+#include <executorch/backends/cadence/hifi/operators/operators.h>
+
+namespace impl {
+namespace HiFi {
+namespace native {
+namespace {
+
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::aten::TensorImpl;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;
+using ::executorch::runtime::runtime_init;
+using ::executorch::runtime::testing::TensorFactory;
+using std::optional;
+using std::string_view;
+
+class HiFiQuantizedLinearTest : public OperatorTest {
+ public:
+ protected:
+  void quantized_linear_out(
+      const Tensor& input,
+      const Tensor& weight,
+      const Tensor& bias,
+      int64_t in_zero_point,
+      const Tensor& weight_zero_point,
+      const Tensor& out_multiplier,
+      const Tensor& out_shift,
+      int64_t out_zero_point,
+      const optional<Tensor>& offset,
+      Tensor& output) {
+    return ::impl::HiFi::native::quantized_linear_out(
+        context_,
+        input,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        output);
+  }
+
+  void quantized_linear_per_tensor_out(
+      const Tensor& input,
+      const Tensor& weight,
+      const Tensor& bias,
+      int64_t in_zero_point,
+      int64_t weight_zero_point,
+      int64_t out_multiplier,
+      int64_t out_shift,
+      int64_t out_zero_point,
+      const optional<Tensor>& offset,
+      Tensor& output) {
+    return ::impl::HiFi::native::quantized_linear_per_tensor_out(
+        context_,
+        input,
+        weight,
+        bias,
+        in_zero_point,
+        weight_zero_point,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        offset,
+        output);
+  }
+};
+
+// Test quantized_linear_out with int16 activations and int8 weights.
+TEST_F(HiFiQuantizedLinearTest, QuantizedLinearInt16Test) {
+  TensorFactory<ScalarType::Short> tf_int16;
+  TensorFactory<ScalarType::Int> tf_int32;
+  TensorFactory<ScalarType::Char> tf_int8;
+
+  // Simple 2D case: input [2, 3] x weight [4, 3] = output [2, 4]
+  // Values captured from e2e test with
+  // CadenceWith16BitLinearActivationsQuantizer
+  Tensor input =
+      tf_int16.make({2, 3}, {-28170, -26389, -32768, -31474, -32266, -29076});
+  Tensor weight = tf_int8.make(
+      {4, 3}, {1, 87, -128, -114, -59, 44, -1, 127, -12, 44, -46, -29});
+  Tensor bias = tf_int32.zeros({4});
+  Tensor output = tf_int16.zeros({2, 4});
+
+  int64_t in_zero_point = -29822;
+  Tensor weight_zero_point = tf_int32.make({1}, {2});
+  Tensor out_multiplier = tf_int32.make({1}, {2011373824});
+  Tensor out_shift = tf_int32.make({1}, {-8});
+  int64_t out_zero_point = -30847;
+  quantized_linear_out(
+      input,
+      weight,
+      bias,
+      in_zero_point,
+      weight_zero_point,
+      out_multiplier,
+      out_shift,
+      out_zero_point,
+      std::nullopt,
+      output);
+  // Expected output from e2e test
+  Tensor expected_output = tf_int16.make(
+      {2, 4}, {-28384, -32767, -29144, -30862, -31956, -29486, -31985, -30756});
+  EXPECT_TENSOR_CLOSE(output, expected_output);
+}
+
+} // namespace
+} // namespace native
+} // namespace HiFi
+} // namespace impl

backends/nxp/backend/edge_helper.py

Lines changed: 11 additions & 0 deletions
@@ -125,3 +125,14 @@ def previous_non_qdq_node(node: Node, input_index: int = 0) -> Node | None:
             current_node = current_node.args[0]
         else:
             return current_node
+
+
+Scale = list[float] | float
+ZeroPoint = list[int] | int
+
+
+def get_quantization_parameters_for(node: Node) -> tuple[Scale, ZeroPoint] | None:
+    if "quantize" not in node.target.__name__ or len(node.args) < 3:
+        return None
+
+    return node.args[1], node.args[2]  # Scale and zero_point.
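A hypothetical usage sketch for the new helper (the wrapper function and its printing are illustrative, not part of this commit): walk an exported FX graph and report the scale and zero point of every quantize/dequantize call the helper recognizes.

import torch

from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for


def log_quant_params(graph_module: torch.fx.GraphModule) -> None:
    # Only call_function nodes have a callable target with a __name__;
    # placeholders and outputs carry string targets, so skip them.
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        params = get_quantization_parameters_for(node)
        if params is not None:
            scale, zero_point = params
            print(f"{node.name}: scale={scale}, zero_point={zero_point}")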
backends/nxp/edge_passes/remove_additional_quantize_dequantize_nodes_pass.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+# Copyright 2025 NXP
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+
+from executorch.backends.nxp.backend.edge_helper import get_quantization_parameters_for
+from executorch.backends.nxp.edge_passes.neutron_edge_pass import NeutronEdgePass
+from executorch.backends.nxp.neutron_partitioner import QDQClusterRecognizer
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch.fx.passes.infra.pass_base import PassResult
+
+
+class RemoveAdditionalQDQClustersPass(NeutronEdgePass):
+    """
+    After delegation of partitions, there may be leftover dequantize/quantize nodes around QDQ clusters that
+    were not delegated. If those nodes are quantized per tensor and the quantization parameters of the
+    dequantize and quantize nodes in a QDQ cluster are equal, the nodes can be removed, and the inner nodes
+    are then computed directly in int8.
+
+    ┌────────────▼──────────┐
+    │ dequantize_per_tensor │
+    └────────────┬──────────┘
+                 │                              │
+             ┌───▼──┐    replace with       ┌───▼──┐
+             │ node │ ──────────────►       │ node │
+             └───┬──┘                       └───┬──┘
+                 │                              ▼
+     ┌───────────▼─────────┐
+     │ quantize_per_tensor │
+     └───────────┬─────────┘
+
+    """
+
+    qdq_per_channel_nodes = (
+        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
+    )
+
+    qdq_per_tensor_nodes = (
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor,
+    )
+
+    def run(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        nodes = list(graph_module.graph.nodes)
+        qdq_clusterer = QDQClusterRecognizer()
+        qdq_clusterer.tag_qdq_clusters(nodes)
+
+        for cluster in qdq_clusterer.cluster_map.values():
+            # For now, enable only permute_copy and cat.
+            if cluster.compute_node.target not in [
+                exir_ops.edge.aten.permute_copy.default,
+                exir_ops.edge.aten.cat.default,
+            ]:
+                continue
+
+            # Ensure the cluster doesn't contain per-channel dequantize/quantize nodes.
+            if any(
+                node
+                for node in cluster.ops
+                if node.target in self.qdq_per_channel_nodes
+            ):
+                continue
+
+            qdq_nodes = [
+                node for node in cluster.ops if node.target in self.qdq_per_tensor_nodes
+            ]
+
+            qdq_nodes_quant_params = [
+                get_quantization_parameters_for(node) for node in qdq_nodes
+            ]
+
+            equal_quant_scales = [
+                np.allclose(
+                    qdq_nodes_quant_params[idx][0], qdq_nodes_quant_params[idx + 1][0]
+                )
+                for idx in range(len(qdq_nodes_quant_params[:-1]))
+            ]
+
+            equal_quant_zero_points = [
+                np.allclose(
+                    qdq_nodes_quant_params[idx][1], qdq_nodes_quant_params[idx + 1][1]
+                )
+                for idx in range(len(qdq_nodes_quant_params[:-1]))
+            ]
+
+            # Check that all quantization params are equal, to ensure the QDQ cluster can be removed.
+            if not all(equal_quant_scales + equal_quant_zero_points):
+                continue
+
+            # Replace the uses of each dequantize/quantize node with its arg node.
+            for qdq_node in qdq_nodes:
+                qdq_node.replace_all_uses_with(qdq_node.args[0])
+                graph_module.graph.erase_node(qdq_node)
+
+            # Remove compute node cluster info from node meta.
+            cluster.compute_node.meta.pop("cluster")
+
+            graph_module = self.recompile_module(graph_module)
+
+            # The graph has now changed, and we cannot keep iterating through it. Return the new graph
+            # and the parent class will call this pass again.
+            return PassResult(graph_module, True)
+
+        return PassResult(graph_module, False)
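The removal condition above reduces to: every per-tensor quantize/dequantize node in the cluster must carry the same scale and the same zero point. A standalone sketch of that check (an illustrative helper mirroring the np.allclose chains in the pass, not code from this commit):

import numpy as np

def qdq_params_match(params: list[tuple[float, int]]) -> bool:
    # Compare each (scale, zero_point) pair against the next one;
    # the cluster is removable only if every comparison holds.
    return all(
        np.allclose(a[0], b[0]) and np.allclose(a[1], b[1])
        for a, b in zip(params, params[1:])
    )

# e.g. qdq_params_match([(0.02, 3), (0.02, 3)]) -> True
#      qdq_params_match([(0.02, 3), (0.05, 3)]) -> False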

backends/nxp/tests/executorch_pipeline.py

Lines changed: 7 additions & 2 deletions
@@ -17,6 +17,9 @@
 from executorch.backends.nxp.edge_passes.neutron_edge_pass_manager import (
     NeutronEdgePassManager,
 )
+from executorch.backends.nxp.edge_passes.remove_additional_quantize_dequantize_nodes_pass import (
+    RemoveAdditionalQDQClustersPass,
+)
 from executorch.backends.nxp.edge_passes.remove_io_quant_ops_pass import (
     RemoveIOQuantOpsPass,
 )
@@ -35,7 +38,6 @@
 from torch.export import export
 from torchao.quantization.pt2e.quantizer import Quantizer
 
-
 neutron_converter_flavor = "SDK_25_09"
 neutron_target_spec = NeutronTargetSpec(
     target="imxrt700", neutron_converter_flavor=neutron_converter_flavor
@@ -64,7 +66,6 @@ def _get_default_quantizer(target_spec: NeutronTargetSpec) -> Quantizer:
 def to_model_input_spec(
     input_spec: tuple[ModelInputSpec, ...] | tuple[int, ...] | list[tuple[int, ...]]
 ) -> tuple[ModelInputSpec, ...]:
-
     if isinstance(input_spec, tuple) and all(
         isinstance(spec, ModelInputSpec) for spec in input_spec
     ):
@@ -139,6 +140,10 @@ def to_quantized_edge_program(
         [RemoveIOQuantOpsPass(edge_program_manager=edge_program_manager)]
     )
 
+    edge_program_manager = edge_program_manager.transform(
+        NeutronEdgePassManager([RemoveAdditionalQDQClustersPass()])
+    )
+
     return edge_program_manager
 
 
backends/nxp/tests/ir/converter/node_converter/test_permute_copy_converter.py

Lines changed: 5 additions & 5 deletions
@@ -104,7 +104,7 @@ def forward(self, x):
         return torch.permute(x, self.perm)
 
 
-class TestPermuteCopyConversion(kgb.SpyAgency, unittest.TestCase):
+class TestPermuteCopyConversion(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
         torch.manual_seed(23)
@@ -302,9 +302,9 @@ def test_permute_copy_non_delegated_conversion__from_permute_4D__quantized(
         edge_program = to_quantized_edge_program(model, input_shape).exported_program()
 
         nodes = list(edge_program.graph.nodes)
-        assert len(nodes) == 10
+        assert len(nodes) == 8
         assert (
-            nodes[6].target == exir_ops.edge.aten.permute_copy.default
+            nodes[5].target == exir_ops.edge.aten.permute_copy.default
         )  # PermuteCopy not delegated.
 
     @parameterized.expand(
@@ -320,7 +320,7 @@ def test_permute_copy_non_delegated_conversion__from_transpose_4D__quantized(
         edge_program = to_quantized_edge_program(model, input_shape).exported_program()
 
         nodes = list(edge_program.graph.nodes)
-        assert len(nodes) == 10
+        assert len(nodes) == 8
         assert (
-            nodes[6].target == exir_ops.edge.aten.permute_copy.default
+            nodes[5].target == exir_ops.edge.aten.permute_copy.default
         )  # PermuteCopy not delegated.
