@@ -8,10 +8,7 @@
 
 #include <executorch/kernels/quantized/cpu/embeddingxb.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
-#include <algorithm>
-#include <cassert>
 #include <cinttypes>
-#include <cmath>
 
 namespace torch {
 namespace executor {
@@ -144,8 +141,9 @@ void check_embedding_xbit_args(
   }
 
   ET_CHECK_MSG(
-      indices.scalar_type() == ScalarType::Long,
-      "indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
+      indices.scalar_type() == ScalarType::Long ||
+          indices.scalar_type() == ScalarType::Int,
+      "indices.scalar_type() %" PRId8 " is not Long or Int",
       static_cast<int8_t>(indices.scalar_type()));
 
   ET_CHECK_MSG(
@@ -166,7 +164,7 @@ void check_embedding_xbit_args(
  * Retrieves the embeddings specified by indices, dequantizes them, and stores
  * them in out. Weight will always be uint8
  */
-template <typename CTYPE_PARAMS, typename CTYPE_OUT>
+template <typename CTYPE_PARAMS, typename CTYPE_OUT, typename CTYPE_INDICES>
 void embedding_xbit_per_channel(
     const Tensor& weight,
     const Tensor& weight_scales,
@@ -183,7 +181,7 @@ void embedding_xbit_per_channel(
   int32_t group_size = embedding_dim / num_groups_per_channel;
 
   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
+  const CTYPE_INDICES* indices_ptr = indices.const_data_ptr<CTYPE_INDICES>();
 
   const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
   const CTYPE_PARAMS* zero_points = nullptr;
@@ -192,7 +190,7 @@ void embedding_xbit_per_channel(
   }
 
   for (int i = 0; i < indices.numel(); i++) {
-    int64_t index = indices_ptr[i];
+    CTYPE_INDICES index = indices_ptr[i];
     // If using groupwise embedding
     int32_t qparams_index = index * num_groups_per_channel;
     CTYPE_PARAMS zp = 0.0;
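
The loop above amounts to a groupwise dequantizing gather, now generic over the index type. A minimal standalone sketch of that logic, assuming 8-bit weights, float quantization parameters, and the illustrative name dequant_embedding_sketch (none of these are the kernel's actual names or types):

#include <cstdint>

// Sketch only: dequantize one embedding row per index, with the index type
// templated as in this patch. Assumes row-major 8-bit weights of shape
// [num_rows, embedding_dim] and one (scale, zero_point) per group of
// group_size consecutive elements; zero_points may be null (treated as 0).
template <typename IndexT>
void dequant_embedding_sketch(
    const uint8_t* qweight,
    const float* scales,
    const float* zero_points,
    const IndexT* indices, // int32_t or int64_t, as dispatched below
    int64_t num_indices,
    int32_t embedding_dim,
    int32_t group_size,
    float* out) {
  const int32_t num_groups_per_channel = embedding_dim / group_size;
  for (int64_t i = 0; i < num_indices; i++) {
    const IndexT index = indices[i];
    const int32_t qparams_index = index * num_groups_per_channel;
    const uint8_t* row = qweight + index * embedding_dim;
    float* out_row = out + i * embedding_dim;
    for (int32_t j = 0; j < embedding_dim; j++) {
      const int32_t group = qparams_index + j / group_size;
      const float zp = zero_points != nullptr ? zero_points[group] : 0.0f;
      // Standard affine dequantization: value = scale * (q - zero_point).
      out_row[j] = scales[group] * (static_cast<float>(row[j]) - zp);
    }
  }
}
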
@@ -285,14 +283,17 @@ Tensor& quantized_embedding_xbit_out(
       weight_nbit);
 
   constexpr auto name = "quantized_decomposed::embedding_xbit.out";
+  ScalarType indices_type = indices.scalar_type();
   ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-    embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
-        weight,
-        weight_scales,
-        opt_weight_zero_points,
-        indices,
-        out,
-        weight_nbit);
+    ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
+      embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT, CTYPE_IDX>(
+          weight,
+          weight_scales,
+          opt_weight_zero_points,
+          indices,
+          out,
+          weight_nbit);
+    });
   });
 
   return out;
@@ -356,15 +357,18 @@ Tensor& quantized_embedding_xbit_dtype_out(
   ScalarType out_type = out.scalar_type();
 
   constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
+  ScalarType indices_type = indices.scalar_type();
   ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
     ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-      embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
-          weight,
-          weight_scales,
-          opt_weight_zero_points,
-          indices,
-          out,
-          weight_nbit);
+      ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
+        embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT, CTYPE_IDX>(
+            weight,
+            weight_scales,
+            opt_weight_zero_points,
+            indices,
+            out,
+            weight_nbit);
+      });
     });
   });
 
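
The nested ET_SWITCH_TWO_TYPES calls expand to runtime switches over dtypes, so the templated kernel is instantiated for each (output, index) type pair and both Long and Int index tensors reach the same code path. A hedged plain-C++ sketch of the inner index dispatch (ScalarType and switch_index_type here are stand-ins for illustration, not the ExecuTorch API):

#include <cstdint>
#include <utility>

enum class ScalarType { Float, Half, Int, Long };

// Stand-in for the inner ET_SWITCH_TWO_TYPES(Int, Long, ...): selects the
// concrete C++ index type at runtime and invokes fn with a value of that
// type, letting a templated kernel deduce its CTYPE_INDICES parameter.
template <typename Fn>
void switch_index_type(ScalarType indices_type, Fn&& fn) {
  switch (indices_type) {
    case ScalarType::Int:
      std::forward<Fn>(fn)(int32_t{0});
      break;
    case ScalarType::Long:
      std::forward<Fn>(fn)(int64_t{0});
      break;
    default:
      break; // the real macro reports unsupported dtypes via the context
  }
}

// Usage sketch: both Int and Long indices route to one templated kernel.
// switch_index_type(indices_type, [&](auto tag) {
//   embedding_xbit_per_channel<float, float, decltype(tag)>(/* ... */);
// });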