
Commit 562e94b

RahulC7 authored and meta-codesync[bot] committed
Enable 16-bit activations and 8-bit weights in Cadence Quantizer for Matmul (#15929)
Summary: Pull Request resolved: #15929

# Context

We continue from D84284794 to add support for 16-bit activations. Note that although the operators already support 16-bit activations, they only do so when the weights are also 16-bit. To support 16-bit activations with 8-bit weights, we need to change the way we template some functions.

# Current Behavior

Right now, we compose two macros: the `ET_FORALL_JARVIS_QUANTIZED_TYPES_WITH_INT16` macro:

https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h?lines=22-25

and the function macro (`quantized_linear` chosen as an example):

https://www.internalfb.com/code/fbsource/[9e8c6d8466107f58aa3de1b9e4ec71c49d670a8f]/fbcode/on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/quantized_linear_out.cpp?lines=30-41

Together, they expand to a switch statement that calls the `quantized_linear` function with the correct template parameter. However, this assumes that the input activations and weights have the same dtype, which is no longer the case.

# This Diff

We fix this by checking the dtypes and calling the functions with the correct data types, as in D86538176.

Differential Revision: D86644079
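To make the dispatch change concrete, here is a hedged, self-contained C++ sketch. The names (`quantized_matmul_impl`, `dispatch`, the `ScalarType` enum) are illustrative assumptions, not the actual Jarvis macro expansion or the Cadence kernel signatures; the point is only that the kernel is templated separately on activation and weight types, so an int16-activation / int8-weight pair becomes expressible.

```cpp
// Hedged sketch only: function and type names are assumptions, not the
// actual Jarvis macros or Cadence kernel signatures.
#include <cstdint>
#include <cstdio>

enum class ScalarType { Byte, Char, Short };

// Hypothetical kernel templated separately on activation and weight types.
template <typename ActT, typename WeightT>
void quantized_matmul_impl(const ActT* x, const WeightT* y, int32_t* out, int n) {
  int32_t acc = 0;
  for (int i = 0; i < n; ++i) {
    acc += static_cast<int32_t>(x[i]) * static_cast<int32_t>(y[i]);
  }
  *out = acc;
}

// Before this diff, the macro-generated switch assumed one dtype for both
// operands. Checking the activation and weight dtypes independently lets
// int16 activations pair with int8 weights.
void dispatch(ScalarType act_dtype, ScalarType weight_dtype,
              const void* x, const void* y, int32_t* out, int n) {
  if (act_dtype == ScalarType::Short && weight_dtype == ScalarType::Char) {
    quantized_matmul_impl(static_cast<const int16_t*>(x),
                          static_cast<const int8_t*>(y), out, n);
  } else if (act_dtype == ScalarType::Byte && weight_dtype == ScalarType::Byte) {
    quantized_matmul_impl(static_cast<const uint8_t*>(x),
                          static_cast<const uint8_t*>(y), out, n);
  }
}

int main() {
  const int16_t x[4] = {1, 2, 3, 4}; // 16-bit activations
  const int8_t y[4] = {1, 1, 1, 1};  // 8-bit weights
  int32_t out = 0;
  dispatch(ScalarType::Short, ScalarType::Char, x, y, &out, 4);
  std::printf("%d\n", out); // prints 10
}
```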
1 parent a78f023 commit 562e94b

File tree

5 files changed: +203 −3 lines changed


backends/cadence/aot/quantizer/quantizer.py

Lines changed: 27 additions & 0 deletions
@@ -372,3 +372,30 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
         # Add 16-bit quantizers for LinearPattern
         quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
         super().__init__(quantizers)
+
+
+class CadenceWith16BitConvActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 conv
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Add 16-bit quantizers for Conv patterns
+        quantizers.append(CadenceAtenQuantizer(Conv1dPattern(), qconfig_A16))
+        quantizers.append(CadenceAtenQuantizer(Conv2dPattern(), qconfig_A16))
+        super().__init__(quantizers)
+
+
+class CadenceWith16BitMatmulActivationsQuantizer(CadenceQuantizer):
+    """
+    Quantizer including A16 matmul
+    """
+
+    def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
+        if quantizers is None:
+            quantizers = []
+        # Add 16-bit quantizers for MatmulPattern
+        quantizers.append(CadenceAtenQuantizer(MatmulPattern(), qconfig_A16))
+        super().__init__(quantizers)

backends/cadence/hifi/operators/op_quantized_matmul_out.cpp

Lines changed: 15 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <executorch/backends/cadence/hifi/kernels/kernels.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
+#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/operators.h>
 #include <stdlib.h>
 
 using executorch::aten::ScalarType;
@@ -192,8 +193,20 @@ void quantized_matmul_out(
   size_t leading_dim = X.size(X.dim() - 2);
   size_t out_dim = Y.size(Y.dim() - 1 - transposed);
   size_t in_dim = X.size(X.dim() - 1);
-
-  if (out.scalar_type() == exec_aten::ScalarType::Byte) {
+  if (out.scalar_type() == exec_aten::ScalarType::Short) {
+    ::impl::generic::native::quantized_matmul_out(
+        ctx,
+        X,
+        X_zero_point,
+        Y,
+        Y_zero_point,
+        bias,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        transposed,
+        out);
+  } else if (out.scalar_type() == exec_aten::ScalarType::Byte) {
     _typed_quantized_matmul<uint8_t>(
         ctx,
         X,

backends/cadence/hifi/operators/operators.h

Lines changed: 13 additions & 0 deletions
@@ -83,6 +83,19 @@ void quantized_linear_per_tensor_out(
     const ::executorch::aten::optional<::executorch::aten::Tensor>& offset,
     ::executorch::aten::Tensor& out);
 
+void quantized_matmul_out(
+    ::executorch::runtime::KernelRuntimeContext& ctx,
+    const ::executorch::aten::Tensor& X,
+    int64_t X_zero_point,
+    const ::executorch::aten::Tensor& Y,
+    int64_t Y_zero_point,
+    const ::executorch::aten::optional<::executorch::aten::Tensor>& bias,
+    int64_t out_multiplier,
+    int64_t out_shift,
+    int64_t out_zero_point,
+    bool transposed,
+    ::executorch::aten::Tensor& out);
+
 void quantized_conv2d_nhwc_out(
     ::executorch::runtime::KernelRuntimeContext& ctx,
     const ::executorch::aten::Tensor& input,

backends/cadence/hifi/operators/targets.bzl

Lines changed: 3 additions & 1 deletion
@@ -90,7 +90,6 @@ OPERATORS = [
     "quantized_linear_out",
     "quantized_linear_asym8sxasym8s_asym8s_per_tensor_out",
     "quantized_linear_asym8uxasym8u_asym8u_per_tensor_out",
-    "quantized_matmul_out",
     "quantized_matmul_asym8sxasym8s_asym8s_out",
     "quantized_matmul_asym8uxasym8u_asym8u_out",
     "quantized_relu_out",
@@ -122,3 +121,6 @@ def define_common_targets():
     # Define build targets for all operators registered in the tables above.
     for op in OPERATORS:
         define_operator(op)
+
+    # quantized_matmul_out needs additional dependency for int16 support
+    define_operator("quantized_matmul_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:quantize_matmul_out", "fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators:headers",])
(new test file)

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <gtest/gtest.h>
+#include <sys/times.h>
+
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/runtime/core/error.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+#include <executorch/runtime/platform/runtime.h>
+
+#include <executorch/backends/cadence/hifi/operators/operators.h>
+
+namespace impl {
+namespace HiFi {
+namespace native {
+namespace {
+
+using ::executorch::aten::Scalar;
+using ::executorch::aten::ScalarType;
+using ::executorch::aten::Tensor;
+using ::executorch::aten::TensorImpl;
+using ::executorch::runtime::Error;
+using ::executorch::runtime::KernelRuntimeContext;
+using ::executorch::runtime::runtime_init;
+using ::executorch::runtime::testing::TensorFactory;
+
+class HiFiQuantizedMatmulTest : public OperatorTest {
+ public:
+ protected:
+  void quantized_matmul_out(
+      const Tensor& X,
+      int64_t X_zero_point,
+      const Tensor& Y,
+      int64_t Y_zero_point,
+      const std::optional<Tensor>& bias,
+      int64_t out_multiplier,
+      int64_t out_shift,
+      int64_t out_zero_point,
+      bool transposed,
+      Tensor& output) {
+    return ::impl::HiFi::native::quantized_matmul_out(
+        context_,
+        X,
+        X_zero_point,
+        Y,
+        Y_zero_point,
+        bias,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        transposed,
+        output);
+  }
+};
+
+// Test quantized_matmul_out with int16 activations and int8 weights
+TEST_F(HiFiQuantizedMatmulTest, QuantizedMatmulInt16Test) {
+  TensorFactory<ScalarType::Short> tf_int16;
+  TensorFactory<ScalarType::Int> tf_int32;
+  TensorFactory<ScalarType::Char> tf_int8;
+
+  // Simple 2D case: X [64, 33] x Y [33, 128] = output [64, 128]
+  // Using simple values for testing
+  Tensor X = tf_int16.ones({64, 33});
+  Tensor Y = tf_int8.ones({33, 128});
+  // Bias not used
+  Tensor bias = tf_int32.full({128}, -30);
+  Tensor output = tf_int16.zeros({64, 128});
+
+  int64_t X_zero_point = 0;
+  int64_t Y_zero_point = 0;
+  int64_t out_multiplier = 1073741824; // 0.5 * 2^31
+  int64_t out_shift = 0;
+  int64_t out_zero_point = 0;
+
+  quantized_matmul_out(
+      X,
+      X_zero_point,
+      Y,
+      Y_zero_point,
+      bias, // pass bias tensor
+      out_multiplier,
+      out_shift,
+      out_zero_point,
+      false, // transposed
+      output);
+
+  // Verify the output is correct
+  // With all ones input and weights, inner dimension is 33
+  // Matmul result: 33, with out_multiplier = 0.5 * 2^31 (scales by 0.5)
+  // Expected value: 33 * 0.5 = 16.5 ≈ 16
+  EXPECT_EQ(output.const_data_ptr<int16_t>()[0], 16);
+}
+
+// Test quantized_matmul_out with transposed Y (int16 activations and int8
+// weights)
+TEST_F(HiFiQuantizedMatmulTest, QuantizedMatmulInt16TransposedTest) {
+  TensorFactory<ScalarType::Short> tf_int16;
+  TensorFactory<ScalarType::Int> tf_int32;
+  TensorFactory<ScalarType::Char> tf_int8;
+
+  // Transposed case: X [64, 33] x Y^T [128, 33] = output [64, 128]
+  Tensor X = tf_int16.ones({64, 33});
+  Tensor Y = tf_int8.ones({128, 33}); // Transposed
+  // Bias not used
+  Tensor bias = tf_int32.full({128}, -30);
+  Tensor output = tf_int16.zeros({64, 128});
+
+  int64_t X_zero_point = 0;
+  int64_t Y_zero_point = 0;
+  int64_t out_multiplier = 1073741824; // 0.5 * 2^31
+  int64_t out_shift = 0;
+  int64_t out_zero_point = 0;
+
+  quantized_matmul_out(
+      X,
+      X_zero_point,
+      Y,
+      Y_zero_point,
+      bias, // pass bias tensor
+      out_multiplier,
+      out_shift,
+      out_zero_point,
+      true, // transposed
+      output);
+
+  // Verify the output is correct
+  // With all ones input and weights, inner dimension is 33
+  // Matmul result: 33, with out_multiplier = 0.5 * 2^31 (scales by 0.5)
+  // Expected value: 33 * 0.5 = 16.5 ≈ 16
+  EXPECT_EQ(output.const_data_ptr<int16_t>()[0], 16);
+}
+
+} // namespace
+} // namespace native
+} // namespace HiFi
+} // namespace impl
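For reference, the expected value of 16 in both tests follows from the fixed-point rescale described in the test comments. The sketch below is an assumption about that arithmetic (a plain Q31 multiply with truncation, all zero points zero, and a simple right shift for `out_shift`); the HiFi kernel's exact rounding behavior may differ.

```cpp
// Hedged sketch of the requantization arithmetic the test comments describe.
// Assumes out = ((acc * out_multiplier) >> 31) >> out_shift with zero points
// of 0; not the kernel's exact rounding implementation.
#include <cstdint>
#include <cstdio>

int16_t requantize(int64_t acc, int64_t out_multiplier, int64_t out_shift) {
  // out_multiplier is a Q31 fixed-point scale: 1073741824 == 0.5 * 2^31.
  int64_t scaled = (acc * out_multiplier) >> 31; // 33 * 0.5 -> 16 (truncated)
  return static_cast<int16_t>(scaled >> out_shift);
}

int main() {
  // Inner dimension 33: all-ones int16 activations times all-ones int8 weights.
  int64_t acc = 33;
  std::printf("%d\n", requantize(acc, 1073741824, 0)); // prints 16
}
```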
