diff --git a/kernels/quantized/cpu/embeddingxb.cpp b/kernels/quantized/cpu/embeddingxb.cpp
index 0ad5470c2c3..f642e360abb 100644
--- a/kernels/quantized/cpu/embeddingxb.cpp
+++ b/kernels/quantized/cpu/embeddingxb.cpp
@@ -8,10 +8,7 @@
 #include
 #include
 
-#include
-#include
 #include
-#include
 
 namespace torch {
 namespace executor {
@@ -144,8 +141,9 @@ void check_embedding_xbit_args(
   }
 
   ET_CHECK_MSG(
-      indices.scalar_type() == ScalarType::Long,
-      "indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
+      indices.scalar_type() == ScalarType::Long ||
+          indices.scalar_type() == ScalarType::Int,
+      "indices.scalar_type() %" PRId8 " is not Long or Int",
       static_cast<int8_t>(indices.scalar_type()));
 
   ET_CHECK_MSG(
@@ -166,7 +164,7 @@ void check_embedding_xbit_args(
  * Retrieves the embeddings specified by indices, dequantizes them, and stores
  * them in out. Weight will always be uint8
  */
-template <typename CTYPE_PARAMS, typename CTYPE_OUT>
+template <typename CTYPE_PARAMS, typename CTYPE_OUT, typename CTYPE_INDICES>
 void embedding_xbit_per_channel(
     const Tensor& weight,
     const Tensor& weight_scales,
@@ -183,7 +181,7 @@ void embedding_xbit_per_channel(
   int32_t group_size = embedding_dim / num_groups_per_channel;
 
   CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
-  const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
+  const CTYPE_INDICES* indices_ptr = indices.const_data_ptr<CTYPE_INDICES>();
 
   const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
   const CTYPE_PARAMS* zero_points = nullptr;
@@ -192,7 +190,7 @@ void embedding_xbit_per_channel(
   }
 
   for (int i = 0; i < indices.numel(); i++) {
-    int64_t index = indices_ptr[i];
+    CTYPE_INDICES index = indices_ptr[i];
     // If using groupwise embedding
     int32_t qparams_index = index * num_groups_per_channel;
    CTYPE_PARAMS zp = 0.0;
@@ -285,14 +283,17 @@ Tensor& quantized_embedding_xbit_out(
       weight_nbit);
 
   constexpr auto name = "quantized_decomposed::embedding_xbit.out";
+  ScalarType indices_type = indices.scalar_type();
   ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-    embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
-        weight,
-        weight_scales,
-        opt_weight_zero_points,
-        indices,
-        out,
-        weight_nbit);
+    ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
+      embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT, CTYPE_IDX>(
+          weight,
+          weight_scales,
+          opt_weight_zero_points,
+          indices,
+          out,
+          weight_nbit);
+    });
   });
 
   return out;
@@ -356,15 +357,18 @@ Tensor& quantized_embedding_xbit_dtype_out(
   ScalarType out_type = out.scalar_type();
 
   constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
+  ScalarType indices_type = indices.scalar_type();
   ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
     ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
-      embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
-          weight,
-          weight_scales,
-          opt_weight_zero_points,
-          indices,
-          out,
-          weight_nbit);
+      ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
+        embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT, CTYPE_IDX>(
+            weight,
+            weight_scales,
+            opt_weight_zero_points,
+            indices,
+            out,
+            weight_nbit);
+      });
     });
   });
 
diff --git a/kernels/quantized/test/op_embedding2b_test.cpp b/kernels/quantized/test/op_embedding2b_test.cpp
index 66a3b589dde..597492ea7b9 100644
--- a/kernels/quantized/test/op_embedding2b_test.cpp
+++ b/kernels/quantized/test/op_embedding2b_test.cpp
@@ -104,6 +104,52 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
+TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingInt32Indices) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tfb;
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tfi;
+
+  int64_t quant_min = -2;
+  int64_t quant_max = 1;
+
+  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
+  Tensor weight_zero_points = tf.make({3}, {1, -2, 0});
+
+  Tensor qweight = tfb.make({3, 1}, {236, 134, 228});
+
+  Tensor indices = tfi.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  Tensor expected = tf.make(
+      {3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0});
+
+  quantized_embedding_2bit_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
+  out = tf.zeros({3, 4});
+  auto context = KernelRuntimeContext();
+  torch::executor::native::quantized_embedding_2bit_out(
+      context,
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
 TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
   et_pal_init();
   TensorFactory<ScalarType::Byte> tfb;
diff --git a/kernels/quantized/test/op_embedding4b_test.cpp b/kernels/quantized/test/op_embedding4b_test.cpp
index b8d5c639c7e..4646f189eaf 100644
--- a/kernels/quantized/test/op_embedding4b_test.cpp
+++ b/kernels/quantized/test/op_embedding4b_test.cpp
@@ -14,7 +14,6 @@
 #include
 #include
-#include
 
 using namespace ::testing;
 using executorch::aten::ArrayRef;
 
@@ -101,6 +100,52 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
   EXPECT_TENSOR_EQ(out, expected);
 }
 
+TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingInt32Indices) {
+  et_pal_init();
+  TensorFactory<ScalarType::Byte> tfb;
+  TensorFactory<ScalarType::Float> tf;
+  TensorFactory<ScalarType::Int> tfi;
+
+  int64_t quant_min = -8;
+  int64_t quant_max = 7;
+
+  Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
+  Tensor weight_zero_points = tf.make({3}, {1, -5, 0});
+
+  Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
+
+  Tensor indices = tfi.make({3}, {0, 2, 1});
+
+  Tensor out = tf.zeros({3, 4});
+  Tensor expected = tf.make(
+      {3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});
+
+  quantized_embedding_4bit_out(
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+
+  out = tf.zeros({3, 4});
+  auto context = KernelRuntimeContext();
+  torch::executor::native::quantized_embedding_4bit_out(
+      context,
+      qweight,
+      weight_scales,
+      weight_zero_points,
+      quant_min,
+      quant_max,
+      indices,
+      out);
+
+  EXPECT_TENSOR_EQ(out, expected);
+}
+
 TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
   et_pal_init();
   TensorFactory<ScalarType::Byte> tfb;
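
For anyone checking the new test vectors by hand: the packed `qweight` bytes can be decoded without running the kernel. The layout used below is inferred from the expected outputs in these tests (2-bit values packed four per byte, least-significant bits first; stored field = `q - quant_min`; dequantization = `scale * (q - zero_point)`), not quoted from `embeddingxb.cpp`, so treat this standalone sketch as illustrative only. It reproduces expected row 0 of the 2-bit test:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t packed = 236;     // qweight[0][0] from the 2-bit test above
  const int32_t quant_min = -2;   // matches the test's quant_min
  const float scale = 0.5f;       // weight_scales[0]
  const float zero_point = 1.0f;  // weight_zero_points[0]

  // Unpack four 2-bit fields, least-significant bits first.
  for (int sub = 0; sub < 4; ++sub) {
    int32_t q = ((packed >> (2 * sub)) & 0x3) + quant_min;  // -> -2, 1, 0, 1
    float deq = scale * (static_cast<float>(q) - zero_point);
    std::printf("%g ", deq);  // prints: -1.5 0 -0.5 0
  }
  std::printf("\n");
  return 0;
}
```

The 4-bit vectors follow the same arithmetic with two values per byte, high nibble first: byte 89 (0x59) holds stored fields {5, 9}, i.e. q = {-3, 1}, which dequantizes to the {-2.0, 0.0, ...} opening that test's expected row 0.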