Skip to content

Commit f0edae2

Browse files
Support for int32 indices in sub8b quantized embedding op (#16518)
Summary: Add support for int32 (in addition to int64) indices in the sub-8-bit quantized embedding op. Differential Revision: D90402567
1 parent 88cfb1d commit f0edae2

File tree

3 files changed

+118
-23
lines changed

3 files changed

+118
-23
lines changed

kernels/quantized/cpu/embeddingxb.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,7 @@
88

99
#include <executorch/kernels/quantized/cpu/embeddingxb.h>
1010
#include <executorch/runtime/kernel/kernel_includes.h>
11-
#include <algorithm>
12-
#include <cassert>
1311
#include <cinttypes>
14-
#include <cmath>
1512

1613
namespace torch {
1714
namespace executor {
@@ -144,8 +141,9 @@ void check_embedding_xbit_args(
144141
}
145142

146143
ET_CHECK_MSG(
147-
indices.scalar_type() == ScalarType::Long,
148-
"indices.scalar_type() %" PRId8 " is not Long only Long is supported:",
144+
indices.scalar_type() == ScalarType::Long ||
145+
indices.scalar_type() == ScalarType::Int,
146+
"indices.scalar_type() %" PRId8 " is not Long or Int",
149147
static_cast<int8_t>(indices.scalar_type()));
150148

151149
ET_CHECK_MSG(
@@ -166,7 +164,7 @@ void check_embedding_xbit_args(
166164
* Retrieves the embeddings specified by indices, dequantizes them, and stores
167165
* them in out. Weight will always be uint8
168166
*/
169-
template <typename CTYPE_PARAMS, typename CTYPE_OUT>
167+
template <typename CTYPE_PARAMS, typename CTYPE_OUT, typename CTYPE_INDICES>
170168
void embedding_xbit_per_channel(
171169
const Tensor& weight,
172170
const Tensor& weight_scales,
@@ -183,7 +181,7 @@ void embedding_xbit_per_channel(
183181
int32_t group_size = embedding_dim / num_groups_per_channel;
184182

185183
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
186-
const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
184+
const CTYPE_INDICES* indices_ptr = indices.const_data_ptr<CTYPE_INDICES>();
187185

188186
const CTYPE_PARAMS* scales = weight_scales.const_data_ptr<CTYPE_PARAMS>();
189187
const CTYPE_PARAMS* zero_points = nullptr;
@@ -192,7 +190,7 @@ void embedding_xbit_per_channel(
192190
}
193191

194192
for (int i = 0; i < indices.numel(); i++) {
195-
int64_t index = indices_ptr[i];
193+
CTYPE_INDICES index = indices_ptr[i];
196194
// If using groupwise embedding
197195
int32_t qparams_index = index * num_groups_per_channel;
198196
CTYPE_PARAMS zp = 0.0;
@@ -285,14 +283,17 @@ Tensor& quantized_embedding_xbit_out(
285283
weight_nbit);
286284

287285
constexpr auto name = "quantized_decomposed::embedding_xbit.out";
286+
ScalarType indices_type = indices.scalar_type();
288287
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
289-
embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT>(
290-
weight,
291-
weight_scales,
292-
opt_weight_zero_points,
293-
indices,
294-
out,
295-
weight_nbit);
288+
ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
289+
embedding_xbit_per_channel<CTYPE_OUT, CTYPE_OUT, CTYPE_IDX>(
290+
weight,
291+
weight_scales,
292+
opt_weight_zero_points,
293+
indices,
294+
out,
295+
weight_nbit);
296+
});
296297
});
297298

298299
return out;
@@ -356,15 +357,18 @@ Tensor& quantized_embedding_xbit_dtype_out(
356357
ScalarType out_type = out.scalar_type();
357358

358359
constexpr auto name = "quantized_decomposed::embedding_xbit.dtype_out";
360+
ScalarType indices_type = indices.scalar_type();
359361
ET_SWITCH_TWO_TYPES(Float, Half, params_type, ctx, name, CTYPE_P, [&]() {
360362
ET_SWITCH_TWO_TYPES(Float, Half, out_type, ctx, name, CTYPE_OUT, [&]() {
361-
embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT>(
362-
weight,
363-
weight_scales,
364-
opt_weight_zero_points,
365-
indices,
366-
out,
367-
weight_nbit);
363+
ET_SWITCH_TWO_TYPES(Int, Long, indices_type, ctx, name, CTYPE_IDX, [&]() {
364+
embedding_xbit_per_channel<CTYPE_P, CTYPE_OUT, CTYPE_IDX>(
365+
weight,
366+
weight_scales,
367+
opt_weight_zero_points,
368+
indices,
369+
out,
370+
weight_nbit);
371+
});
368372
});
369373
});
370374

kernels/quantized/test/op_embedding2b_test.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,52 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
104104
EXPECT_TENSOR_EQ(out, expected);
105105
}
106106

107+
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingInt32Indices) {
108+
et_pal_init();
109+
TensorFactory<ScalarType::Byte> tfb;
110+
TensorFactory<ScalarType::Float> tf;
111+
TensorFactory<ScalarType::Int> tfi;
112+
113+
int64_t quant_min = -2;
114+
int64_t quant_max = 1;
115+
116+
Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
117+
Tensor weight_zero_points = tf.make({3}, {1, -2, 0});
118+
119+
Tensor qweight = tfb.make({3, 1}, {236, 134, 228});
120+
121+
Tensor indices = tfi.make({3}, {0, 2, 1});
122+
123+
Tensor out = tf.zeros({3, 4});
124+
Tensor expected = tf.make(
125+
{3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, 2.0, 1.0, 0.0, 2.0});
126+
127+
quantized_embedding_2bit_out(
128+
qweight,
129+
weight_scales,
130+
weight_zero_points,
131+
quant_min,
132+
quant_max,
133+
indices,
134+
out);
135+
136+
EXPECT_TENSOR_EQ(out, expected);
137+
138+
out = tf.zeros({3, 4});
139+
auto context = KernelRuntimeContext();
140+
torch::executor::native::quantized_embedding_2bit_out(
141+
context,
142+
qweight,
143+
weight_scales,
144+
weight_zero_points,
145+
quant_min,
146+
quant_max,
147+
indices,
148+
out);
149+
150+
EXPECT_TENSOR_EQ(out, expected);
151+
}
152+
107153
TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
108154
et_pal_init();
109155
TensorFactory<ScalarType::Byte> tfb;

kernels/quantized/test/op_embedding4b_test.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#include <executorch/test/utils/DeathTest.h>
1515

1616
#include <gtest/gtest.h>
17-
#include <limits>
1817

1918
using namespace ::testing;
2019
using executorch::aten::ArrayRef;
@@ -101,6 +100,52 @@ TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbedding) {
101100
EXPECT_TENSOR_EQ(out, expected);
102101
}
103102

103+
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingInt32Indices) {
104+
et_pal_init();
105+
TensorFactory<ScalarType::Byte> tfb;
106+
TensorFactory<ScalarType::Float> tf;
107+
TensorFactory<ScalarType::Int> tfi;
108+
109+
int64_t quant_min = -8;
110+
int64_t quant_max = 7;
111+
112+
Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
113+
Tensor weight_zero_points = tf.make({3}, {1, -5, 0});
114+
115+
Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
116+
117+
Tensor indices = tfi.make({3}, {0, 2, 1});
118+
119+
Tensor out = tf.zeros({3, 4});
120+
Tensor expected = tf.make(
121+
{3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});
122+
123+
quantized_embedding_4bit_out(
124+
qweight,
125+
weight_scales,
126+
weight_zero_points,
127+
quant_min,
128+
quant_max,
129+
indices,
130+
out);
131+
132+
EXPECT_TENSOR_EQ(out, expected);
133+
134+
out = tf.zeros({3, 4});
135+
auto context = KernelRuntimeContext();
136+
torch::executor::native::quantized_embedding_4bit_out(
137+
context,
138+
qweight,
139+
weight_scales,
140+
weight_zero_points,
141+
quant_min,
142+
quant_max,
143+
indices,
144+
out);
145+
146+
EXPECT_TENSOR_EQ(out, expected);
147+
}
148+
104149
TEST(OpQuantizedEmbedding4bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
105150
et_pal_init();
106151
TensorFactory<ScalarType::Byte> tfb;

0 commit comments

Comments (0)