
Commit 98a1012

mcremon-meta authored and facebook-github-bot committed
Enable int8 support for quantized_linear and quantized_relu reference (#6334)
Summary: As titled.
Reviewed By: zonglinpeng
Differential Revision: D64553726
1 parent 16b633b commit 98a1012

7 files changed (+71 / -14 lines)


backends/cadence/aot/ops_registrations.py

Lines changed: 1 addition & 1 deletion
@@ -185,7 +185,7 @@ def quantized_relu_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
 ) -> torch.Tensor:
-    return X.new_empty(X.size(), dtype=torch.uint8)
+    return X.new_empty(X.size(), dtype=X.dtype)


 @register_fake("cadence::quantized_matmul")

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 4 additions & 1 deletion
@@ -44,7 +44,10 @@ void dequantize_per_tensor_out(
     impl::HiFi::kernels::dequantize<int32_t>(
         out_data, input_data, scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }
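
For context, ET_CHECK_MSG takes a printf-style format string, and "%hhd" expects a plain signed 8-bit integer; passing the scoped ScalarType enum directly through the varargs can trip -Wformat warnings or undefined behavior, which is what the new static_cast avoids. A minimal standalone sketch of the same idea (the enum below is a stand-in, not the real executorch ScalarType):

#include <cstdint>
#include <cstdio>

// Stand-in for executorch's ScalarType; names and values are illustrative only.
enum class FakeScalarType : int8_t { Byte = 0, Char = 1, Float = 6 };

int main() {
  FakeScalarType t = FakeScalarType::Float;
  // The explicit cast hands "%hhd" the int8_t it expects instead of a
  // scoped enum, mirroring the ET_CHECK_MSG change in this commit.
  std::printf("Unhandled input dtype %hhd\n", static_cast<int8_t>(t));
  return 0;
}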

backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 4 additions & 2 deletions
@@ -50,8 +50,10 @@ void quantize_per_tensor_out(
     cadence::impl::HiFi::kernels::quantize<int32_t>(
         out_data, input_data, 1. / scale, zero_point, numel);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", out.scalar_type());
-  }
+    ET_CHECK_MSG(
+        false,
+        "Unhandled output dtype %hhd",
+        static_cast<int8_t>(out.scalar_type())); }
 }

 }; // namespace native

backends/cadence/hifi/operators/quantized_layer_norm.cpp

Lines changed: 4 additions & 1 deletion
@@ -151,7 +151,10 @@ void quantized_layer_norm_out(
         output_zero_point,
         out);
   } else {
-    ET_CHECK_MSG(false, "Unhandled input dtype %hhd", input.scalar_type());
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }

backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 5 additions & 0 deletions
@@ -248,6 +248,11 @@ void quantized_conv_out(
         output_scale,
         (int8_t)output_zero_point,
         per_tensor_quantized);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(input.scalar_type()));
   }
 }

backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 48 additions & 9 deletions
@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
 using executorch::runtime::getLeadingDims;
 using executorch::runtime::KernelRuntimeContext;

-void quantized_linear_out(
-    KernelRuntimeContext& ctx,
+template <typename T>
+void inline _typed_quantized_linear(
     const Tensor& src,
     const Tensor& weight,
     const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
-  // Assuming uint8_t for now, but needs to be updated for other quantization
-  // types
-  const uint8_t* __restrict__ src_data = src.const_data_ptr<uint8_t>();
-  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const T* __restrict__ src_data = src.const_data_ptr<T>();
+  const T* __restrict__ weight_data = weight.const_data_ptr<T>();
   const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  T* __restrict__ out_data = out.mutable_data_ptr<T>();

   int32_t weight_zero_point = weight_zero_point_t.const_data_ptr<int32_t>()[0];
@@ -71,11 +68,53 @@ void quantized_linear_out(
           (weight_data[j * N + k] - weight_zero_point);
       }
       out_data[i * M + j] =
-          kernels::quantize<uint8_t>(sum, out_scale, out_zero_point);
+          kernels::quantize<T>(sum, out_scale, out_zero_point);
     }
   }
 }

+void quantized_linear_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& src,
+    const Tensor& weight,
+    const Tensor& bias,
+    int64_t src_zero_point,
+    const Tensor& weight_zero_point_t,
+    const Tensor& out_multiplier,
+    const Tensor& out_shift,
+    int64_t out_zero_point,
+    const executorch::aten::optional<Tensor>& offset,
+    Tensor& out) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
+    _typed_quantized_linear<uint8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
+    _typed_quantized_linear<int8_t>(
+        src,
+        weight,
+        bias,
+        src_zero_point,
+        weight_zero_point_t,
+        out_multiplier,
+        out_shift,
+        out_zero_point,
+        out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(src.scalar_type()));
+  }
+}
+
 }; // namespace native
 }; // namespace reference
 }; // namespace impl
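
The per-element kernels::quantize<T> helper is not part of this diff, so as a hedged sketch only: such a requantization step typically scales the accumulator, rounds, adds the output zero point, and clamps to T's representable range, which is exactly where the uint8_t (Byte) vs int8_t (Char) template argument changes behavior. Names and semantics below are assumptions, not the actual cadence kernel:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <limits>

// Illustrative sketch of a round-and-clamp requantization; not the real
// cadence::kernels::quantize implementation.
template <typename T>
T quantize_sketch(float value, float scale, int32_t zero_point) {
  const float transformed = std::round(value * scale) + zero_point;
  // Clamp bounds come from T: [-128, 127] for int8_t, [0, 255] for uint8_t.
  const float lo = static_cast<float>(std::numeric_limits<T>::min());
  const float hi = static_cast<float>(std::numeric_limits<T>::max());
  return static_cast<T>(std::clamp(transformed, lo, hi));
}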

backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 5 additions & 0 deletions
@@ -144,6 +144,11 @@ void quantized_matmul_out(
       out_zero_point,
       transposed,
       out);
+  } else {
+    ET_CHECK_MSG(
+        false,
+        "Unhandled input dtype %hhd",
+        static_cast<int8_t>(X.scalar_type()));
   }
 }
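
Taken together, the commit applies the same dispatch-or-fail pattern across the reference operators: handle the known dtypes explicitly, and abort loudly on anything else. A condensed C++ sketch of that pattern, with placeholder names rather than the actual ExecuTorch/cadence APIs:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Placeholder enum standing in for executorch::aten::ScalarType.
enum class FakeScalarType : int8_t { Byte = 0, Char = 1, Float = 6 };

template <typename T>
void typed_kernel(const void* in, void* out, int64_t numel) {
  // Per-dtype implementation would go here (e.g. _typed_quantized_linear<T>).
}

void dispatch(FakeScalarType dtype, const void* in, void* out, int64_t numel) {
  if (dtype == FakeScalarType::Byte) {
    typed_kernel<uint8_t>(in, out, numel);  // uint8 path (pre-existing)
  } else if (dtype == FakeScalarType::Char) {
    typed_kernel<int8_t>(in, out, numel);   // int8 path (newly enabled)
  } else {
    // Mirrors ET_CHECK_MSG(false, ...): fail instead of silently
    // reinterpreting data of an unsupported dtype.
    std::fprintf(stderr, "Unhandled input dtype %hhd\n",
                 static_cast<int8_t>(dtype));
    std::abort();
  }
}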
