@@ -17,24 +17,12 @@ using executorch::aten::Tensor;
1717using executorch::runtime::getLeadingDims;
1818using executorch::runtime::KernelRuntimeContext;
1919
20- void quantized_linear_out (
21- KernelRuntimeContext& ctx,
22- const Tensor& src,
23- const Tensor& weight,
24- const Tensor& bias,
25- int64_t src_zero_point,
26- const Tensor& weight_zero_point_t ,
27- const Tensor& out_multiplier,
28- const Tensor& out_shift,
29- int64_t out_zero_point,
30- const executorch::aten::optional<Tensor>& offset,
31- Tensor& out) {
32- // Assuming uint8_t for now, but needs to be updated for other quantization
33- // types
34- const uint8_t * __restrict__ src_data = src.const_data_ptr <uint8_t >();
35- const uint8_t * __restrict__ weight_data = weight.const_data_ptr <uint8_t >();
20+ template <typename T>
21+ void inline _typed_quantized_linear (const Tensor& src, const Tensor& weight, const Tensor& bias, int64_t src_zero_point, const Tensor& weight_zero_point_t , const Tensor& out_multiplier, const Tensor& out_shift, int64_t out_zero_point, Tensor& out) {
22+ const T* __restrict__ src_data = src.const_data_ptr <T>();
23+ const T* __restrict__ weight_data = weight.const_data_ptr <T>();
3624 const int32_t * __restrict__ bias_data = bias.const_data_ptr <int32_t >();
37- uint8_t * __restrict__ out_data = out.mutable_data_ptr <uint8_t >();
25+ T * __restrict__ out_data = out.mutable_data_ptr <T >();
3826
3927 int32_t weight_zero_point = weight_zero_point_t .const_data_ptr <int32_t >()[0 ];
4028
@@ -71,11 +59,50 @@ void quantized_linear_out(
7159 (weight_data[j * N + k] - weight_zero_point);
7260 }
7361 out_data[i * M + j] =
74- kernels::quantize<uint8_t >(sum, out_scale, out_zero_point);
62+ kernels::quantize<T >(sum, out_scale, out_zero_point);
7563 }
7664 }
7765}
7866
67+ void quantized_linear_out (
68+ KernelRuntimeContext& ctx,
69+ const Tensor& src,
70+ const Tensor& weight,
71+ const Tensor& bias,
72+ int64_t src_zero_point,
73+ const Tensor& weight_zero_point_t ,
74+ const Tensor& out_multiplier,
75+ const Tensor& out_shift,
76+ int64_t out_zero_point,
77+ const executorch::aten::optional<Tensor>& offset,
78+ Tensor& out) {
79+ if (out.scalar_type () == executorch::aten::ScalarType::Byte) {
80+ _typed_quantized_linear<uint8_t >(
81+ src,
82+ weight,
83+ bias,
84+ src_zero_point,
85+ weight_zero_point_t ,
86+ out_multiplier,
87+ out_shift,
88+ out_zero_point,
89+ out);
90+ } else if (out.scalar_type () == executorch::aten::ScalarType::Char) {
91+ _typed_quantized_linear<int8_t >(
92+ src,
93+ weight,
94+ bias,
95+ src_zero_point,
96+ weight_zero_point_t ,
97+ out_multiplier,
98+ out_shift,
99+ out_zero_point,
100+ out);
101+ } else {
102+ ET_CHECK_MSG (false , " Unhandled input dtype %hhd" , static_cast <int8_t >(src.scalar_type ()));
103+ }
104+ }
105+
79106}; // namespace native
80107}; // namespace reference
81108}; // namespace impl
0 commit comments