
Commit 0617422

mcremon-meta authored and facebook-github-bot committed
Enable int8 support for quantized_linear reference (#6334)
Summary: As titled. Differential Revision: D64553726
1 parent 6669e18 commit 0617422

File tree

3 files changed: +49 additions, -9 deletions


backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 2 additions & 0 deletions
@@ -248,6 +248,8 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(input.scalar_type()));
   }
 }

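The two added lines turn a silent fall-through for unsupported dtypes into a hard failure. A minimal standalone sketch of that guard pattern, with a stand-in enum and macro in place of executorch::aten::ScalarType and ET_CHECK_MSG (illustrative only, not the ExecuTorch sources):

// Standalone sketch: FakeScalarType and CHECK_MSG stand in for
// executorch::aten::ScalarType and ET_CHECK_MSG.
#include <cstdint>
#include <cstdio>
#include <cstdlib>

enum class FakeScalarType : int8_t { Byte = 0, Char = 1, Float = 6 };

#define CHECK_MSG(cond, fmt, ...)                  \
  do {                                             \
    if (!(cond)) {                                 \
      std::fprintf(stderr, fmt "\n", __VA_ARGS__); \
      std::abort();                                \
    }                                              \
  } while (0)

void dispatch(FakeScalarType dtype) {
  if (dtype == FakeScalarType::Byte) {
    // uint8_t kernel would run here.
  } else if (dtype == FakeScalarType::Char) {
    // int8_t kernel would run here.
  } else {
    // Any other dtype now aborts loudly instead of silently doing nothing.
    CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(dtype));
  }
}
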
backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 45 additions & 9 deletions
@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;
 
-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
+template <typename T>
+void inline _typed_quantized_linear(
     const Tensor& src,
     const Tensor& weight,
     const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();
 
   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
 
@@ -71,11 +68,50 @@ void quantized_linear_out(
             (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }
 
+void quantized_linear_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl

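For reference, a standalone sketch of the requantization step that kernels::quantize<T> is used for above: scale the int32 accumulator, round, add the output zero point, and saturate to T's range, so the int8_t (ScalarType::Char) path clamps to [-128, 127] where the uint8_t (ScalarType::Byte) path clamps to [0, 255]. This is an assumption about the rounding/clamping behavior, not the ExecuTorch source.

// Standalone sketch: requantize an int32 accumulator to the output type T.
// The real kernels::quantize may differ in rounding details.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

template <typename T>
T requantize_sketch(int32_t acc, float out_scale, int32_t out_zero_point) {
  float v = std::round(acc * out_scale) + static_cast<float>(out_zero_point);
  v = std::max(v, static_cast<float>(std::numeric_limits<T>::min()));
  v = std::min(v, static_cast<float>(std::numeric_limits<T>::max()));
  return static_cast<T>(v);
}

int main() {
  // round(1000 * 0.05) + (-5) = 45, which fits in int8_t's [-128, 127].
  return requantize_sketch<int8_t>(1000, 0.05f, -5) == 45 ? 0 : 1;
}
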
backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 2 additions & 0 deletions
@@ -144,6 +144,8 @@ void quantized_matmul_out(
       out_zero_point,
       transposed,
       out);
+  } else {
+    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", static_cast<int8_t>(X.scalar_type()));
   }
 }

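Across all three operators the integer math itself is dtype-agnostic: uint8/int8 values are widened to int32 before the multiply-accumulate, as in this standalone sketch of the zero-point-corrected inner product (illustrative names, not the ExecuTorch sources):

// Standalone sketch of the inner product the reference kernels accumulate
// in int32 before requantizing to the output dtype.
#include <cstdint>

template <typename T>  // T is uint8_t or, after this commit, int8_t
int32_t dot_with_zero_points(
    const T* a,
    const T* b,
    int32_t a_zero_point,
    int32_t b_zero_point,
    int32_t k) {
  int32_t sum = 0;
  for (int32_t i = 0; i < k; ++i) {
    sum += (static_cast<int32_t>(a[i]) - a_zero_point) *
        (static_cast<int32_t>(b[i]) - b_zero_point);
  }
  return sum;
}
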