@@ -17,8 +17,8 @@ using executorch::aten::Tensor;
1717using executorch::runtime::getLeadingDims;
1818using executorch::runtime::KernelRuntimeContext;
1919
20- void quantized_linear_out (
21- KernelRuntimeContext& ctx,
20+ template < typename T>
21+ void inline _typed_quantized_linear (
2222 const Tensor& src,
2323 const Tensor& weight,
2424 const Tensor& bias,
@@ -27,14 +27,11 @@ void quantized_linear_out(
2727 const Tensor& out_multiplier,
2828 const Tensor& out_shift,
2929 int64_t out_zero_point,
30- const executorch::aten::optional<Tensor>& offset,
3130 Tensor& out) {
32- // Assuming uint8_t for now, but needs to be updated for other quantization
33- // types
34- const uint8_t * __restrict__ src_data = src.const_data_ptr <uint8_t >();
35- const uint8_t * __restrict__ weight_data = weight.const_data_ptr <uint8_t >();
31+ const T* __restrict__ src_data = src.const_data_ptr <T>();
32+ const T* __restrict__ weight_data = weight.const_data_ptr <T>();
3633 const int32_t * __restrict__ bias_data = bias.const_data_ptr <int32_t >();
37- uint8_t * __restrict__ out_data = out.mutable_data_ptr <uint8_t >();
34+ T * __restrict__ out_data = out.mutable_data_ptr <T >();
3835
3936 int32_t weight_zero_point = weight_zero_point_t .const_data_ptr <int32_t >()[0 ];
4037
@@ -71,11 +68,50 @@ void quantized_linear_out(
7168 (weight_data[j * N + k] - weight_zero_point);
7269 }
7370 out_data[i * M + j] =
74- kernels::quantize<uint8_t >(sum, out_scale, out_zero_point);
71+ kernels::quantize<T >(sum, out_scale, out_zero_point);
7572 }
7673 }
7774}
7875
76+ void quantized_linear_out (
77+ KernelRuntimeContext& ctx,
78+ const Tensor& src,
79+ const Tensor& weight,
80+ const Tensor& bias,
81+ int64_t src_zero_point,
82+ const Tensor& weight_zero_point_t ,
83+ const Tensor& out_multiplier,
84+ const Tensor& out_shift,
85+ int64_t out_zero_point,
86+ const executorch::aten::optional<Tensor>& offset,
87+ Tensor& out) {
88+ if (out.scalar_type () == executorch::aten::ScalarType::Byte) {
89+ _typed_quantized_linear<uint8_t >(
90+ src,
91+ weight,
92+ bias,
93+ src_zero_point,
94+ weight_zero_point_t ,
95+ out_multiplier,
96+ out_shift,
97+ out_zero_point,
98+ out);
99+ } else if (out.scalar_type () == executorch::aten::ScalarType::Char) {
100+ _typed_quantized_linear<int8_t >(
101+ src,
102+ weight,
103+ bias,
104+ src_zero_point,
105+ weight_zero_point_t ,
106+ out_multiplier,
107+ out_shift,
108+ out_zero_point,
109+ out);
110+ } else {
111+ ET_CHECK_MSG (false , " Unhandled input dtype %hhd" , static_cast <int8_t >(src.scalar_type ()));
112+ }
113+ }
114+
79115}; // namespace native
80116}; // namespace reference
81117}; // namespace impl
0 commit comments