
Commit 9359481

Add support for strongly typed quantized_op_add
Differential Revision: D80570364
Pull Request resolved: #13531
1 parent b427bd7 commit 9359481

9 files changed: +703 additions, -0 deletions

backends/cadence/aot/functions.yaml

Lines changed: 15 additions & 0 deletions
@@ -249,6 +249,21 @@
     - arg_meta: null
       kernel_name: impl::reference::quantized_relu_asym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_add_per_tensor_out
+
+- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_add_asym8sxasym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::reference::quantized_add_asym8uxasym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/functions_hifi.yaml

Lines changed: 10 additions & 0 deletions
@@ -404,6 +404,16 @@
     - arg_meta: null
       kernel_name: cadence::impl::HiFi::quantized_relu_asym8u_asym8u_per_tensor_out
 
+- func: cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_add_asym8sxasym8s_asym8s_per_tensor_out
+
+- func: cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)
+  kernels:
+    - arg_meta: null
+      kernel_name: cadence::impl::HiFi::quantized_add_asym8uxasym8u_asym8u_per_tensor_out
+
 - func: cadence::quantized_matmul.out(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed, *, Tensor(a!) out) -> Tensor(a!)
   kernels:
     - arg_meta: null

backends/cadence/aot/ops_registrations.py

Lines changed: 46 additions & 0 deletions
@@ -325,6 +325,22 @@
     "quantized_add.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
     "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_add_asym8sxasym8s_asym8s.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point) -> Tensor"
+)
+lib.define(
+    "quantized_add_asym8sxasym8s_asym8s.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.define(
+    "quantized_add_asym8uxasym8u_asym8u.per_tensor(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point) -> Tensor"
+)
+lib.define(
+    "quantized_add_asym8uxasym8u_asym8u.per_tensor_out(Tensor X, float X_scale, int X_zero_point, Tensor Y, float Y_scale, "
+    "int Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
 lib.define(
     "quantized_mul.out(Tensor X, Tensor X_scale, Tensor X_zero_point, Tensor Y, Tensor Y_scale, "
     "Tensor Y_zero_point, float out_scale, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"

@@ -503,6 +519,36 @@ def quantized_add_per_tensor_meta(
     return X.new_empty(out_size, dtype=X.dtype)
 
 
+@register_fake("cadence::quantized_add_asym8sxasym8s_asym8s.per_tensor")
+def quantized_add_asym8sxasym8s_asym8s_per_tensor_meta(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = torch.broadcast_shapes(X.size(), Y.size())
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
+@register_fake("cadence::quantized_add_asym8uxasym8u_asym8u.per_tensor")
+def quantized_add_asym8uxasym8u_asym8u_per_tensor_meta(
+    X: torch.Tensor,
+    X_scale: float,
+    X_zero_point: int,
+    Y: torch.Tensor,
+    Y_scale: float,
+    Y_zero_point: int,
+    out_scale: float,
+    out_zero_point: int,
+) -> torch.Tensor:
+    out_size = torch.broadcast_shapes(X.size(), Y.size())
+    return X.new_empty(out_size, dtype=X.dtype)
+
+
 @register_fake("cadence::quantized_linear")
 def quantized_linear_meta(
     src: torch.Tensor,
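
A quick, hedged check of the new registrations (not part of the commit): under FakeTensorMode, a call to the newly defined per_tensor op dispatches to the register_fake kernel above, so the broadcasted output shape and dtype can be confirmed without a Cadence kernel. The import path of the registration module is assumed from the file location shown here.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumed module path; importing it runs the lib.define(...) and register_fake(...) calls above.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401

with FakeTensorMode():
    x = torch.empty(2, 1, dtype=torch.int8)
    y = torch.empty(2, 3, dtype=torch.int8)
    out = torch.ops.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor(
        x, 1.0, 0, y, 1.0, 0, 1.0, 0
    )
    # The fake kernel broadcasts the input shapes: out is (2, 3), dtype int8.
    print(out.shape, out.dtype)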

backends/cadence/aot/tests/test_type_dispatch_passes.py

Lines changed: 50 additions & 0 deletions
@@ -445,3 +445,53 @@ def test_uint8_dispatch_quantized_conv_nhwc_dilated(self) -> None:
             ),
             1,
         )
+
+    def test_int8_dispatch_quantized_add(self) -> None:
+        """Test int8 x int8 inputs should dispatch to asym8sxasym8s_asym8s variant for quantized_add"""
+        x = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        y = torch.randint(-128, 127, (2, 3), dtype=torch.int8)
+        gm = single_op_builder(
+            placeholders=(x, y),
+            op=exir_ops.edge.cadence.quantized_add.per_tensor,
+            args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
+            0,
+        )
+        # Should be replaced with int8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_add_asym8sxasym8s_asym8s.per_tensor,
+            ),
+            1,
+        )
+
+    def test_uint8_dispatch_quantized_add(self) -> None:
+        """Test uint8 x uint8 inputs should dispatch to asym8uxasym8u_asym8u variant for quantized_add"""
+        x = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        y = torch.randint(0, 255, (2, 3), dtype=torch.uint8)
+        gm = single_op_builder(
+            placeholders=(x, y),
+            op=exir_ops.edge.cadence.quantized_add.per_tensor,
+            args=(x, 1.0, 0, y, 1.0, 0, 1.0, 0),
+        )
+        p = CompileTimeTypeDispatchPass()
+        gm = cast(PassResult, p(gm)).graph_module
+        # Original op should be replaced
+        self.assertEqual(
+            count_node(gm, exir_ops.edge.cadence.quantized_add.per_tensor),
+            0,
+        )
+        # Should be replaced with uint8 specific variant
+        self.assertEqual(
+            count_node(
+                gm,
+                exir_ops.edge.cadence.quantized_add_asym8uxasym8u_asym8u.per_tensor,
+            ),
+            1,
+        )

backends/cadence/aot/type_dispatch.py

Lines changed: 8 additions & 0 deletions
@@ -85,6 +85,14 @@ class CompileTimeTypeDispatchPass(ExportPass):
                 (torch.uint8,): "asym8u_asym8u",
             },
         ),
+        exir_ops.edge.cadence.quantized_add.per_tensor: OpConfig(
+            "quantized_add",
+            type_dispatch_suffixes={
+                (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
+                (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
+            },
+            weight_arg_idx=3,
+        ),
     }
 
     def call_operator(
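
As a rough sketch of what the new OpConfig entry drives (not the pass's actual code, which lives in CompileTimeTypeDispatchPass.call_operator), the replacement op is picked by keying the suffix table on the dtypes of the two tensor inputs; weight_arg_idx=3 presumably points at Y, the second tensor argument, as the source of the second dtype in the key:

import torch

# Hypothetical standalone lookup mirroring the type_dispatch_suffixes table above.
SUFFIXES = {
    (torch.int8, torch.int8): "asym8sxasym8s_asym8s",
    (torch.uint8, torch.uint8): "asym8uxasym8u_asym8u",
}

def typed_quantized_add_name(x: torch.Tensor, y: torch.Tensor) -> str:
    suffix = SUFFIXES[(x.dtype, y.dtype)]
    return f"cadence::quantized_add_{suffix}.per_tensor"

# int8 x int8 maps to the asym8sxasym8s_asym8s variant, matching the tests above.
print(typed_quantized_add_name(
    torch.zeros(2, 3, dtype=torch.int8), torch.zeros(2, 3, dtype=torch.int8)
))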
Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

using ::executorch::aten::Tensor;
using ::executorch::runtime::KernelRuntimeContext;

void quantized_add_asym8sxasym8s_asym8s_per_tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& X,
    double X_scale,
    int64_t X_zero_point,
    const Tensor& Y,
    double Y_scale,
    int64_t Y_zero_point,
    double out_scale,
    int64_t out_zero_point,
    Tensor& out) {
  const int8_t* __restrict__ X_data = X.const_data_ptr<int8_t>();
  const int8_t* __restrict__ Y_data = Y.const_data_ptr<int8_t>();
  int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();

  ssize_t Y_numel = Y.numel();
  ssize_t X_numel = X.numel();
  ssize_t out_numel = out.numel();

  float X_scale_f = static_cast<float>(X_scale);
  float Y_scale_f = static_cast<float>(Y_scale);
  float out_scale_f = static_cast<float>(out_scale);
  int32_t X_zero_point_i32 = static_cast<int32_t>(X_zero_point);
  int32_t Y_zero_point_i32 = static_cast<int32_t>(Y_zero_point);
  int32_t out_zero_point_i32 = static_cast<int32_t>(out_zero_point);

  float inv_out_scale = 1.0f / out_scale_f;
  constexpr float min_val =
      static_cast<float>(std::numeric_limits<int8_t>::min());
  constexpr float max_val =
      static_cast<float>(std::numeric_limits<int8_t>::max());

  /* Tensor X exactly matches Y in shape, no broadcasting */
  if (X_numel == Y_numel && Y_numel == out_numel) {
    for (size_t i = 0; i < X_numel; ++i) {
      float x = X_scale_f * (X_data[i] - X_zero_point_i32);
      float y = Y_scale_f * (Y_data[i] - Y_zero_point_i32);
      float z = x + y;
      float tmp = roundf(z * inv_out_scale + out_zero_point_i32);
      out_data[i] =
          static_cast<int8_t>(std::max(std::min(tmp, max_val), min_val));
    }
  } /* if Y is a scalar Tensor */
  else if (Y_numel == 1) {
    float y =
        kernels::dequantize<int8_t>(Y_data[0], Y_scale_f, Y_zero_point_i32);
    for (size_t i = 0; i < X_numel; ++i) {
      float x =
          kernels::dequantize<int8_t>(X_data[i], X_scale_f, X_zero_point_i32);
      float z = x + y;
      out_data[i] =
          kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
    }
  } /* if X is a scalar Tensor */
  else if (X_numel == 1) {
    float x =
        kernels::dequantize<int8_t>(X_data[0], X_scale_f, X_zero_point_i32);
    for (size_t i = 0; i < Y_numel; ++i) {
      float y =
          kernels::dequantize<int8_t>(Y_data[i], Y_scale_f, Y_zero_point_i32);
      float z = x + y;
      out_data[i] =
          kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
    }
  } /* other broadcasting cases */
  else {
    /* Broadcasting implementation */
    ssize_t X_dim = X.dim();
    ssize_t Y_dim = Y.dim();
    ssize_t out_dim = out.dim();

    /* Precompute strides for X and Y tensors */
    constexpr size_t max_dim = executorch::runtime::kTensorDimensionLimit;
    size_t X_strides[max_dim] = {0};
    size_t Y_strides[max_dim] = {0};
    size_t X_stride_val = 1;
    size_t Y_stride_val = 1;

    /* Calculate strides from last dimension to first */
    for (int d = out_dim - 1; d >= 0 && d >= out_dim - max_dim; --d) {
      int idx = out_dim - 1 - d; /* Index into the fixed-size array */
      if (d >= out_dim - X_dim) {
        size_t X_d = d - (out_dim - X_dim);
        X_strides[idx] = X_stride_val;
        X_stride_val *= X.size(X_d);
      }

      if (d >= out_dim - Y_dim) {
        size_t Y_d = d - (out_dim - Y_dim);
        Y_strides[idx] = Y_stride_val;
        Y_stride_val *= Y.size(Y_d);
      }
    }

    /* Iterate over output tensor */
    for (ssize_t i = 0; i < out_numel; ++i) {
      size_t out_idx = i;
      size_t X_idx = 0;
      size_t Y_idx = 0;

      /* Compute corresponding indices in input tensors */
      for (int d = out_dim - 1; d >= 0; --d) {
        size_t out_dim_idx = out_idx % out.size(d);
        out_idx /= out.size(d);

        /* Compute X index */
        if (d >= out_dim - X_dim) {
          size_t X_d = d - (out_dim - X_dim);
          size_t X_dim_idx = out_dim_idx % X.size(X_d);
          if (d >= out_dim - max_dim) {
            int idx = out_dim - 1 - d;
            X_idx += X_dim_idx * X_strides[idx];
          } else {
            size_t X_stride = 1;
            for (int k = out_dim - 1; k > d; --k) {
              if (k >= out_dim - X_dim) {
                size_t X_k = k - (out_dim - X_dim);
                X_stride *= X.size(X_k);
              }
            }
            X_idx += X_dim_idx * X_stride;
          }
        }

        /* Compute Y index */
        if (d >= out_dim - Y_dim) {
          size_t Y_d = d - (out_dim - Y_dim);
          size_t Y_dim_idx = out_dim_idx % Y.size(Y_d);
          if (d >= out_dim - max_dim) {
            int idx = out_dim - 1 - d;
            Y_idx += Y_dim_idx * Y_strides[idx];
          } else {
            size_t Y_stride = 1;
            for (int k = out_dim - 1; k > d; --k) {
              if (k >= out_dim - Y_dim) {
                size_t Y_k = k - (out_dim - Y_dim);
                Y_stride *= Y.size(Y_k);
              }
            }
            Y_idx += Y_dim_idx * Y_stride;
          }
        }
      }

      /* Apply the operation */
      float x = kernels::dequantize<int8_t>(
          X_data[X_idx], X_scale_f, X_zero_point_i32);
      float y = kernels::dequantize<int8_t>(
          Y_data[Y_idx], Y_scale_f, Y_zero_point_i32);
      float z = x + y;
      out_data[i] =
          kernels::quantize<int8_t>(z, inv_out_scale, out_zero_point_i32);
    }
  }
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
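
For reference, a minimal Python sketch of the per-element math in the kernel's fast path above: dequantize both operands, add in float, then requantize with rounding and int8 saturation. It is a worked check, not part of the commit; Python's round() differs from roundf only at exact .5 ties.

def quantized_add_scalar(x_q, x_scale, x_zp, y_q, y_scale, y_zp, out_scale, out_zp):
    # Dequantize each int8 operand to float.
    x = x_scale * (x_q - x_zp)
    y = y_scale * (y_q - y_zp)
    z = x + y
    # Requantize: scale, shift by the output zero point, round, clamp to int8.
    q = round(z / out_scale + out_zp)
    return max(-128, min(127, int(q)))

# Example: 3 + 5 with unit scales and zero offsets requantizes to 8.
print(quantized_add_scalar(3, 1.0, 0, 5, 1.0, 0, 1.0, 0))  # -> 8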
