@@ -120,6 +120,88 @@ void _q_at_k_gemm(
   }
 }
 
+// Refactor op_dequantize.cpp to avoid code duplication
+void dequantize_optimized(
+    const int8_t* in,
+    const float scale,
+    const int8_t zero_point,
+    float* out,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t numel) {
+  size_t i = 0;
+#if defined(__aarch64__) || defined(__ARM_NEON)
+  int8x8_t zero_point_vec = vdup_n_s8(zero_point);
+  float32x4_t scales = vdupq_n_f32(static_cast<float>(scale));
+  constexpr int32_t kVecSize = 16;
+  const size_t num_vecs = numel / kVecSize;
+  const int8_t* in_copy = in;
+  float* out_copy = out;
+  for (; i < num_vecs; i++) {
+    int8x16_t in_vec = vld1q_s8(in_copy);
+    int16x8_t sub_vec_0_7 = vsubl_s8(vget_low_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_0_3 = vmovl_s16(vget_low_s16(sub_vec_0_7));
+    int32x4_t sub_vec_4_7 = vmovl_s16(vget_high_s16(sub_vec_0_7));
+    float32x4_t out_vec_0_3 = vmulq_f32(vcvtq_f32_s32(sub_vec_0_3), scales);
+    float32x4_t out_vec_4_7 = vmulq_f32(vcvtq_f32_s32(sub_vec_4_7), scales);
+
+    int16x8_t sub_vec_8_15 = vsubl_s8(vget_high_s8(in_vec), zero_point_vec);
+    int32x4_t sub_vec_8_11 = vmovl_s16(vget_low_s16(sub_vec_8_15));
+    int32x4_t sub_vec_12_15 = vmovl_s16(vget_high_s16(sub_vec_8_15));
+    float32x4_t out_vec_8_11 = vmulq_f32(vcvtq_f32_s32(sub_vec_8_11), scales);
+    float32x4_t out_vec_12_15 = vmulq_f32(vcvtq_f32_s32(sub_vec_12_15), scales);
+    vst1q_f32(out_copy + 0, out_vec_0_3);
+    vst1q_f32(out_copy + 4, out_vec_4_7);
+    vst1q_f32(out_copy + 8, out_vec_8_11);
+    vst1q_f32(out_copy + 12, out_vec_12_15);
+    in_copy += kVecSize;
+    out_copy += kVecSize;
+  }
+  i = i * kVecSize;
+#endif
+  for (; i < numel; i++) {
+    out[i] = (static_cast<int16_t>(in[i]) - static_cast<int16_t>(zero_point)) *
+        scale;
+  }
+}
+
+void dequantize_per_channel_optimized(
+    const int8_t* in_data,
+    const float* scales_data,
+    const int8_t* zero_points_data,
+    float* out_data,
+    int64_t quant_min,
+    int64_t quant_max,
+    size_t outer_size,
+    size_t in_outer_stride,
+    size_t out_outer_stride,
+    size_t num_channels,
+    size_t in_channel_stride,
+    size_t out_channel_stride,
+    size_t channel_size,
+    size_t qparams_stride) {
+  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+    // Loop through dim
+    for (size_t channel_idx = 0; channel_idx < num_channels; ++channel_idx) {
+      const int8_t* in_data_local = in_data + outer_idx * in_outer_stride +
+          channel_idx * in_channel_stride;
+      const float scale = *(scales_data + channel_idx * qparams_stride);
+      const int8_t zero_point =
+          *(zero_points_data + channel_idx * qparams_stride);
+      float* out_data_local = out_data + outer_idx * out_outer_stride +
+          channel_idx * out_channel_stride;
+      dequantize_optimized(
+          in_data_local,
+          scale,
+          zero_point,
+          out_data_local,
+          quant_min,
+          quant_max,
+          channel_size);
+    }
+  }
+}
+
 template <typename accum_t>
 void _qk_at_v_gemm(
     const int64_t m,
@@ -134,24 +216,36 @@ void _qk_at_v_gemm(
     const accum_t beta) {
   if (v_data.dtype == ScalarType::Char) {
     if constexpr (std::is_same<accum_t, float>::value) {
-      int a_stride_m_tmp, b_stride_n_tmp;
-      auto kernel = torchao::kernels::cpu::quantized_matmul::
-          get_fp32_a_input_channelwise_8bit_b_f32_c_matmul(
-              m, n, k, false, false, a_stride_m_tmp, b_stride_n_tmp);
-      kernel(
-          m,
+      std::vector<float> dequantized_v_data(v_data.m * v_data.n);
+      dequantize_per_channel_optimized(
+          static_cast<const int8_t*>(v_data.data),
+          static_cast<const float*>(v_data.scales),
+          static_cast<const int8_t*>(v_data.zero_points),
+          dequantized_v_data.data(),
+          -128,
+          127,
+          1,
+          0,
+          0,
+          v_data.m,
+          v_stride_n,
+          v_data.n,
+          v_data.n,
+          v_data.zero_points_stride);
+      ::executorch::cpublas::gemm(
+          ::executorch::cpublas::TransposeType::NoTranspose,
+          ::executorch::cpublas::TransposeType::NoTranspose,
           n,
+          m,
           k,
+          static_cast<accum_t>(1),
+          dequantized_v_data.data(),
+          v_data.n,
           qk_data,
-          qk_stride_m /*lhs_stride_m*/,
-          static_cast<const int8_t*>(v_data.data),
-          v_stride_n /*rhs_stride_n*/,
-          o_data,
-          o_stride_m /*out_stride_n*/,
-          static_cast<const int8_t*>(v_data.zero_points),
-          static_cast<const float*>(v_data.scales),
+          qk_stride_m,
           beta,
-          v_data.zero_points_stride);
+          o_data,
+          o_stride_m);
     } else {
       ET_CHECK_MSG(
           false, "Accumulation in dtype other than float not supported yet");
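
For reference, dequantize_optimized computes the plain affine dequantization out[i] = (in[i] - zero_point) * scale; the NEON path widens sixteen int8 values per iteration (s8 -> s16 -> s32 -> f32) before scaling, and the trailing scalar loop handles the remainder, while quant_min and quant_max are accepted but not used in the body shown. A minimal standalone usage sketch, assuming the helper from the hunk above is visible in the translation unit; the sizes, values, and the function name dequantize_example are illustrative only, not from the commit:

#include <cstdint>
#include <vector>

void dequantize_example() {
  // 18 elements so that one 16-wide NEON iteration and a 2-element scalar
  // tail both execute; every quantized value is 3 for easy checking.
  std::vector<int8_t> q(18, 3);
  std::vector<float> deq(q.size());
  const float scale = 0.05f;
  const int8_t zero_point = -4;
  dequantize_optimized(
      q.data(), scale, zero_point, deq.data(), -128, 127, q.size());
  // Every output equals (3 - (-4)) * 0.05f = 0.35f.
}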
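In _qk_at_v_gemm, the arguments passed to dequantize_per_channel_optimized (outer_size = 1, both outer strides 0, num_channels = v_data.m, in_channel_stride = v_stride_n, out_channel_stride = channel_size = v_data.n, qparams_stride = v_data.zero_points_stride) reduce the call to a row-wise expansion of V, one scale/zero-point pair per row. A scalar sketch of that expansion under those assumptions; the function and parameter names here are local to the example, with the comments mapping them back to the call site:

#include <cstddef>
#include <cstdint>

// Scalar equivalent of the dequantize_per_channel_optimized call above:
// rows of V act as the "channels", each with its own scale and zero point.
void dequantize_v_rowwise(
    const int8_t* v_int8,         // v_data.data
    const float* v_scales,        // v_data.scales
    const int8_t* v_zero_points,  // v_data.zero_points
    float* dequantized_v,         // dequantized_v_data.data()
    size_t rows,                  // v_data.m
    size_t cols,                  // v_data.n
    size_t in_row_stride,         // v_stride_n
    size_t qparams_stride) {      // v_data.zero_points_stride
  for (size_t row = 0; row < rows; ++row) {
    const float scale = v_scales[row * qparams_stride];
    const int8_t zp = v_zero_points[row * qparams_stride];
    for (size_t col = 0; col < cols; ++col) {
      dequantized_v[row * cols + col] =
          (static_cast<int16_t>(v_int8[row * in_row_stride + col]) -
           static_cast<int16_t>(zp)) *
          scale;
    }
  }
}

The dequantized buffer then goes through ::executorch::cpublas::gemm instead of the fused torchao channelwise-quantized matmul. Assuming that routine follows the usual column-major BLAS convention, passing the operands as (dequantized V, qk) with dimensions (n, m, k) and leading dimensions (v_data.n, qk_stride_m, o_stride_m) yields the row-major product o[m x n] = qk[m x k] x V[k x n], accumulated into o_data with beta, which is what the removed kernel call computed in fused form.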