@@ -26,53 +26,28 @@ using ::executorch::aten::Tensor;
 using ::executorch::runtime::getLeadingDims;
 using ::executorch::runtime::KernelRuntimeContext;
 
-void quantized_linear_out(
-    __ET_UNUSED KernelRuntimeContext& ctx,
-    const Tensor& src,
+
+// The nnlib kernel to compute quantized linear via matmul.
+
+void _quantized_linear_asym8u(
+    const Tensor& in,
     const Tensor& weight,
     const Tensor& bias,
-    int64_t src_zero_point,
-    const Tensor& weight_zero_point_t,
+    int64_t in_zero_point,
+    const Tensor& weight_zero_point,
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    __ET_UNUSED const executorch::aten::optional<Tensor>& offset,
+    __ET_UNUSED const optional<Tensor>& offset,
     Tensor& out) {
-  int64_t leading_dims = getLeadingDims(src, src.dim() - 1);
-  int64_t out_dim = weight.size(0);
-  int64_t in_dim = weight.size(1);
-
-  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
-    const uint8_t* __restrict__ in_data = src.const_data_ptr<uint8_t>();
-    const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
-    const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-    uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
-
-    // The nnlib kernel to compute quantized linear via matmul.
-    xa_nn_matmul_asym8uxasym8u_asym8u(
-        out_data,
-        weight_data,
-        in_data,
-        bias_data,
-        out_dim,
-        in_dim,
-        in_dim,
-        leading_dims,
-        in_dim,
-        out_dim,
-        1,
-        -weight_zero_point_t.const_data_ptr<int32_t>()[0],
-        -src_zero_point,
-        out_multiplier.const_data_ptr<int32_t>()[0],
-        out_shift.const_data_ptr<int32_t>()[0],
-        out_zero_point);
-  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
-    const int8_t* __restrict__ in_data = src.const_data_ptr<int8_t>();
-    const int8_t* __restrict__ weight_data = weight.const_data_ptr<int8_t>();
-    const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
-    int8_t* __restrict__ out_data = out.mutable_data_ptr<int8_t>();
-
-    xa_nn_matmul_asym8sxasym8s_asym8s(
+  const int64_t leading_dims = getLeadingDims(in, in.dim() - 1);
+  const int64_t out_dim = weight.size(0); // = out_dim
+  const int64_t in_dim = weight.size(1); // = in_dim
+  const uint8_t* __restrict__ in_data = in.const_data_ptr<uint8_t>();
+  const uint8_t* __restrict__ weight_data = weight.const_data_ptr<uint8_t>();
+  const int32_t* __restrict__ bias_data = bias.const_data_ptr<int32_t>();
+  uint8_t* __restrict__ out_data = out.mutable_data_ptr<uint8_t>();
+  int32_t ret = xa_nn_matmul_asym8uxasym8u_asym8u(
       out_data,
       weight_data,
       in_data,
@@ -84,17 +59,12 @@ void quantized_linear_out(
       in_dim,
       out_dim,
       1,
-      -weight_zero_point_t.const_data_ptr<int32_t>()[0],
-      -src_zero_point,
+      -weight_zero_point.const_data_ptr<int32_t>()[0], // mat1_zero_bias
+      -in_zero_point, // mat2_zero_bias
       out_multiplier.const_data_ptr<int32_t>()[0],
       out_shift.const_data_ptr<int32_t>()[0],
       out_zero_point);
-  } else {
-    ET_CHECK_MSG(
-        false,
-        "Unhandled input dtype %hhd",
-        static_cast<int8_t>(src.scalar_type()));
-  }
+  ET_DCHECK_MSG(ret == 0, "HiFi quantized::linear failed");
 }
 
 void inline _quantized_linear_asym8s(
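
For context, this hunk removes the output-dtype switch from `quantized_linear_out` and leaves only the per-dtype helpers (`_quantized_linear_asym8u`, and `_quantized_linear_asym8s` further down); the wrapper that routes between them is not visible in the shown portion of the diff. The sketch below is a hypothetical reconstruction of such a dispatcher based on the removed branch and the helper signatures in this hunk, assuming the same headers and `using` declarations as the rest of the file; it is not the code this commit actually adds.

```cpp
// Hypothetical dispatcher sketch (illustration only): route to the
// per-dtype nnlib helpers the same way the removed inline switch did.
void quantized_linear_out(
    __ET_UNUSED KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& bias,
    int64_t in_zero_point,
    const Tensor& weight_zero_point,
    const Tensor& out_multiplier,
    const Tensor& out_shift,
    int64_t out_zero_point,
    __ET_UNUSED const optional<Tensor>& offset,
    Tensor& out) {
  // uint8 (Byte) outputs go to the asym8u kernel, int8 (Char) outputs to
  // the asym8s kernel; anything else is rejected, mirroring the old branch.
  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
    _quantized_linear_asym8u(
        in, weight, bias, in_zero_point, weight_zero_point,
        out_multiplier, out_shift, out_zero_point, offset, out);
  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
    _quantized_linear_asym8s(
        in, weight, bias, in_zero_point, weight_zero_point,
        out_multiplier, out_shift, out_zero_point, offset, out);
  } else {
    ET_CHECK_MSG(
        false,
        "Unhandled output dtype %hhd",
        static_cast<int8_t>(out.scalar_type()));
  }
}
```

Splitting the dtype branches into separate helpers keeps each nnlib call monomorphic and lets the new `ret`/`ET_DCHECK_MSG` error handling sit next to the kernel invocation. The negated `weight_zero_point` and `in_zero_point` arguments match the `mat1_zero_bias` / `mat2_zero_bias` comments added in the diff, i.e. the matmul kernel expects additive zero biases for its 8-bit operands.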