
Commit 92f089f

mcremon-meta authored and facebook-github-bot committed
Enable int8 support for quantized_linear reference
Summary: As titled.

Differential Revision: D64553726
1 parent 6669e18 commit 92f089f

File tree

3 files changed (+49, -18 lines)


backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 2 additions & 0 deletions
@@ -248,6 +248,8 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(input.scalar_type()));
   }
 }
 
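
The added branch turns an unsupported input dtype into a hard failure instead of a silent no-op. ET_CHECK_MSG fails with a formatted message when its condition is false, so invoking it with `false` always fires. A minimal standalone sketch of that behavior (CHECK_MSG_SKETCH is a hypothetical stand-in; the real macro hooks into the ExecuTorch runtime's logging rather than a bare abort):

```cpp
#include <cstdio>
#include <cstdlib>

// Hypothetical stand-in for ET_CHECK_MSG: abort with a printf-style
// message when the condition is false. ##__VA_ARGS__ (a GCC/Clang/MSVC
// extension) drops the comma when no format arguments are supplied.
#define CHECK_MSG_SKETCH(cond, fmt, ...)                               \
  do {                                                                 \
    if (!(cond)) {                                                     \
      std::fprintf(stderr, "Check failed: " fmt "\n", ##__VA_ARGS__);  \
      std::abort();                                                    \
    }                                                                  \
  } while (0)
```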

backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 45 additions & 18 deletions
@@ -17,24 +17,12 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;
 
-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
-    const Tensor& src,
-    const Tensor& weight,
-    const Tensor& bias,
-    int64_t src_zero_point,
-    const Tensor& weight_zero_point_t,
-    const Tensor& out_multiplier,
-    const Tensor& out_shift,
-    int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
-    Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+template <typename T>
+void inline _typed_quantized_linear(const Tensor& src, const Tensor& weight, const Tensor& bias, int64_t src_zero_point, const Tensor& weight_zero_point_t, const Tensor& out_multiplier, const Tensor& out_shift, int64_t out_zero_point, Tensor& out) {
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();
 
   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
 
@@ -71,11 +59,50 @@ void quantized_linear_out(
             (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }
 
+void quantized_linear_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
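
The refactor moves the arithmetic into a kernel templated on the element type and keeps the runtime dtype decision in one place. Below is a self-contained sketch of the same pattern with simplified stand-ins (Dtype, typed_linear, and quantize_clamp are illustrative names, and the row-major MxK/NxK layout and rounding behavior are conventional assumptions, not necessarily what the Cadence reference kernel or kernels::quantize does):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <limits>

// Illustrative dtype tag, standing in for executorch::aten::ScalarType.
enum class Dtype : int8_t { Byte /* uint8 */, Char /* int8 */ };

// Requantize an int32 accumulator: scale, add the zero point, and saturate
// to T's representable range (0..255 for uint8, -128..127 for int8).
template <typename T>
T quantize_clamp(int32_t sum, float scale, int32_t zero_point) {
  float v = std::round(static_cast<float>(sum) * scale) + zero_point;
  v = std::min(v, static_cast<float>(std::numeric_limits<T>::max()));
  v = std::max(v, static_cast<float>(std::numeric_limits<T>::min()));
  return static_cast<T>(v);
}

// out[m][n] = quantize(bias[n] + sum_k (src[m][k] - src_zp) * (w[n][k] - w_zp))
// src is MxK, w is NxK, both row-major.
template <typename T>
void typed_linear(const T* src, const T* w, const int32_t* bias,
                  int32_t src_zp, int32_t w_zp, float out_scale,
                  int32_t out_zp, int M, int N, int K, T* out) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      int32_t sum = bias[n];
      for (int k = 0; k < K; ++k) {
        sum += (static_cast<int32_t>(src[m * K + k]) - src_zp) *
               (static_cast<int32_t>(w[n * K + k]) - w_zp);
      }
      out[m * N + n] = quantize_clamp<T>(sum, out_scale, out_zp);
    }
  }
}

// Runtime dispatch on the dtype tag, mirroring the shape of
// quantized_linear_out above: one branch per supported type,
// hard failure otherwise.
void linear_dispatch(Dtype dtype, const void* src, const void* w,
                     const int32_t* bias, int32_t src_zp, int32_t w_zp,
                     float out_scale, int32_t out_zp, int M, int N, int K,
                     void* out) {
  if (dtype == Dtype::Byte) {
    typed_linear<uint8_t>(static_cast<const uint8_t*>(src),
                          static_cast<const uint8_t*>(w), bias, src_zp,
                          w_zp, out_scale, out_zp, M, N, K,
                          static_cast<uint8_t*>(out));
  } else if (dtype == Dtype::Char) {
    typed_linear<int8_t>(static_cast<const int8_t*>(src),
                         static_cast<const int8_t*>(w), bias, src_zp,
                         w_zp, out_scale, out_zp, M, N, K,
                         static_cast<int8_t*>(out));
  } else {
    std::fprintf(stderr, "Unhandled input dtype %hhd\n",
                 static_cast<int8_t>(dtype));
    std::abort();
  }
}
```

The payoff of the template is that the loop nest exists once and the saturation bounds fall out of std::numeric_limits<T>; the dtype check runs once per call rather than once per element.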

backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 2 additions & 0 deletions
@@ -144,6 +144,8 @@ void quantized_matmul_out(
       out_zero_point,
       transposed,
       out);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(X.scalar_type()));
   }
 }
 

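All three error messages cast the ScalarType to int8_t before printing: an enum class won't pass through a varargs formatter implicitly, so the cast exposes the raw tag value, which %hhd then prints as a signed 8-bit integer. A tiny illustration (the enum and its numeric values below are made-up stand-ins, not the real ScalarType tags):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for executorch::aten::ScalarType; the values
// here are assumptions for the demo, not the real dtype tags.
enum class ScalarType : int8_t { Byte = 0, Char = 1, Float = 6 };

int main() {
  ScalarType st = ScalarType::Float;
  // Cast to the underlying 8-bit width so "%hhd" can print the raw tag.
  std::printf("Unhandled input dtype %hhd\n", static_cast<int8_t>(st));
  return 0;
}
```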