Skip to content

Commit 65eee43

Browse files
kimishpatel authored and facebook-github-bot committed
Groupwise quantized embedding (#1864)
Summary: Pull Request resolved: #1864 Implements groupwise quantized embedding lookup within embedding_byte op ghstack-source-id: 214706823 Reviewed By: mikekgfb Differential Revision: D53499639 fbshipit-source-id: 9ec0e58789e00b3944c67d44a53d31a5e897696f
1 parent 7012fa8 commit 65eee43

File tree

2 files changed

+270
-4
lines changed

2 files changed

+270
-4
lines changed

kernels/quantized/cpu/op_embedding.cpp

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,36 @@ void check_embedding_byte_args(
3232
const int64_t weight_quant_max,
3333
const Tensor& indices,
3434
Tensor& out) {
35+
ET_CHECK_MSG(
36+
weight.dim() == 2, "weight must be 2D but got() %zd dims", weight.dim());
37+
38+
ET_CHECK_MSG(
39+
weight_scales.dim() <= 2,
40+
"weight_scales must be 1D or 2D but got() %zd dims",
41+
weight_scales.dim());
42+
43+
auto weight_scales_size = weight_scales.size(0);
44+
45+
ET_CHECK_MSG(
46+
weight_scales_size == weight.size(0),
47+
"Number of scales must be == weight.size(0)=%zd"
48+
", but got %zd",
49+
weight_scales_size,
50+
weight.size(0));
51+
52+
if (weight_scales_size >= weight.size(0)) {
53+
if (weight_scales.dim() == 2) {
54+
auto num_groups = weight_scales.size(1);
55+
auto remainder = weight.size(1) % num_groups;
56+
ET_CHECK_MSG(
57+
remainder == 0,
58+
"Number of groups must divide weight.size(1)=%zd"
59+
", but got # of groups = %zd",
60+
weight.size(1),
61+
num_groups);
62+
}
63+
}
64+
3565
ET_CHECK_MSG(
3666
weight.scalar_type() == ScalarType::Byte ||
3767
weight.scalar_type() == ScalarType::Char,
@@ -50,11 +80,29 @@ void check_embedding_byte_args(
5080
static_cast<int8_t>(weight_scales.scalar_type()));
5181

5282
if (opt_weight_zero_points.has_value()) {
83+
ET_CHECK_MSG(
84+
opt_weight_zero_points.value().dim() == weight_scales.dim(),
85+
"weight_zero_points's rank match that of weight_scales. "
86+
"weight_zero_points rank: %" PRId8 ", weight_scales rank: %" PRId8,
87+
static_cast<int8_t>(opt_weight_zero_points.value().dim()),
88+
static_cast<int8_t>(weight_scales.dim()));
89+
5390
ET_CHECK_MSG(
5491
opt_weight_zero_points.value().scalar_type() == out.scalar_type(),
5592
"weight zero points scalar type %" PRId8
5693
" does not match out.scalar_type()",
5794
static_cast<int8_t>(opt_weight_zero_points.value().scalar_type()));
95+
96+
for (int32_t i = 0; i < weight_scales.dim(); ++i) {
97+
ET_CHECK_MSG(
98+
opt_weight_zero_points.value().size(i) == weight_scales.size(i),
99+
"Dimension size misatch at dim %" PRId8
100+
"Weight_zero_point size = %zd"
101+
", weight_scales size = %zd.",
102+
i,
103+
opt_weight_zero_points.value().size(i),
104+
weight_scales.size(i));
105+
}
58106
}
59107

60108
ET_CHECK_MSG(
@@ -81,10 +129,16 @@ void embedding_byte_per_channel(
81129
const optional<Tensor>& opt_weight_zero_points,
82130
const Tensor& indices,
83131
Tensor& out) {
84-
// An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a weight
85-
// of shape (num_embeddings, embedding_dim).
132+
// An embedding layer nn.Embedding(num_embeddings, embedding_dim) has a
133+
// weight of shape (num_embeddings, embedding_dim).
86134
auto embedding_dim = weight.size(1);
87135

136+
int32_t num_groups_per_channel = 1;
137+
if (weight_scales.dim() == 2) {
138+
num_groups_per_channel = weight_scales.size(1);
139+
}
140+
int32_t group_size = weight.size(1) / num_groups_per_channel;
141+
88142
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
89143
const int64_t* indices_ptr = indices.const_data_ptr<int64_t>();
90144

@@ -96,16 +150,24 @@ void embedding_byte_per_channel(
96150

97151
for (int i = 0; i < indices.numel(); i++) {
98152
int64_t index = indices_ptr[i];
153+
// If using groupwise embedding
154+
int32_t qparams_index = index * num_groups_per_channel;
99155
CTYPE_OUT zp = 0.0;
156+
const CTYPE_OUT* scale_ptr = scales + qparams_index;
157+
const CTYPE_OUT* zero_points_ptr = nullptr;
100158
if (opt_weight_zero_points.has_value()) {
101-
zp = zero_points[index];
159+
zero_points_ptr = zero_points + qparams_index;
102160
}
103-
CTYPE_OUT scale = scales[index];
104161

105162
const CTYPE_WEIGHT* w_data =
106163
weight.data_ptr<CTYPE_WEIGHT>() + embedding_dim * index;
107164

108165
for (int j = 0; j < embedding_dim; ++j) {
166+
int32_t group_id = j / group_size;
167+
const CTYPE_OUT scale = scale_ptr[group_id];
168+
if (opt_weight_zero_points.has_value()) {
169+
zp = zero_points_ptr[group_id];
170+
}
109171
out_data[j] = static_cast<CTYPE_OUT>(
110172
(static_cast<float>(w_data[j]) - static_cast<float>(zp)) *
111173
static_cast<float>(scale));

kernels/quantized/test/op_embedding_test.cpp

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,207 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
167167
EXPECT_TENSOR_EQ(out, fp_out);
168168
EXPECT_TENSOR_EQ(out, expected);
169169
}
170+
171+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
172+
et_pal_init();
173+
TensorFactory<ScalarType::Float> tf;
174+
TensorFactory<ScalarType::Int> tf_i;
175+
TensorFactory<ScalarType::Long> tf_l;
176+
177+
int64_t quant_min = 0;
178+
int64_t quant_max = 255;
179+
180+
Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
181+
Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
182+
TensorFactory<ScalarType::Byte> tfo;
183+
Tensor qweight =
184+
tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
185+
186+
Tensor indices = tf_l.make({3}, {0, 2, 1});
187+
188+
Tensor out = tf.zeros({3, 4});
189+
Tensor expected = tf.make(
190+
{3, 4}, {3.5, 4.5, 5.5, 6.5, 1.5, 3.0, 4.5, 7.5, 5.0, 7.0, 7.0, 9.0});
191+
192+
quantized_embedding_byte_out(
193+
qweight,
194+
weight_scales,
195+
weight_zero_points,
196+
quant_min,
197+
quant_max,
198+
indices,
199+
out);
200+
201+
EXPECT_TENSOR_EQ(out, expected);
202+
203+
// Groupwise quantization. groupsize = 2
204+
weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
205+
weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
206+
/*
207+
fp_weight = [3.5, 4.5, 7, 9,
208+
4.5, 7.5, 6, 10,
209+
-7.5, -5.0, -9.0, -3.0]
210+
*/
211+
212+
out = tf.zeros({3, 4});
213+
expected = tf.make(
214+
{3, 4}, {3.5, 4.5, 7, 9, -7.5, -5.0, -9.0, -3.0, 4.5, 7.5, 6, 10});
215+
216+
quantized_embedding_byte_out(
217+
qweight,
218+
weight_scales,
219+
weight_zero_points,
220+
quant_min,
221+
quant_max,
222+
indices,
223+
out);
224+
225+
EXPECT_TENSOR_EQ(out, expected);
226+
}
227+
228+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
229+
et_pal_init();
230+
TensorFactory<ScalarType::Float> tf;
231+
TensorFactory<ScalarType::Int> tf_i;
232+
TensorFactory<ScalarType::Long> tf_l;
233+
234+
int64_t quant_min = 0;
235+
int64_t quant_max = 255;
236+
237+
Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
238+
Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
239+
TensorFactory<ScalarType::Byte> tfo;
240+
Tensor qweight =
241+
tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
242+
243+
Tensor indices = tf_l.make({3}, {0, 2, 1});
244+
245+
Tensor out = tf.zeros({3, 4});
246+
ET_EXPECT_DEATH(
247+
quantized_embedding_byte_out(
248+
qweight,
249+
weight_scales,
250+
weight_zero_points,
251+
quant_min,
252+
quant_max,
253+
indices,
254+
out),
255+
"");
256+
}
257+
258+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
259+
et_pal_init();
260+
TensorFactory<ScalarType::Float> tf;
261+
TensorFactory<ScalarType::Int> tf_i;
262+
TensorFactory<ScalarType::Long> tf_l;
263+
264+
int64_t quant_min = 0;
265+
int64_t quant_max = 255;
266+
267+
Tensor weight_scales = tf.make({2}, {0.5, 1.0});
268+
Tensor weight_zero_points = tf.make({2}, {1, 5});
269+
TensorFactory<ScalarType::Byte> tfo;
270+
Tensor qweight =
271+
tfo.make({3, 4}, {8, 10, 12, 14, 10, 12, 12, 14, 8, 9, 10, 12});
272+
273+
Tensor indices = tf_l.make({3}, {0, 2, 1});
274+
275+
Tensor out = tf.zeros({3, 4});
276+
ET_EXPECT_DEATH(
277+
quantized_embedding_byte_out(
278+
qweight,
279+
weight_scales,
280+
weight_zero_points,
281+
quant_min,
282+
quant_max,
283+
indices,
284+
out),
285+
"");
286+
}
287+
288+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
289+
et_pal_init();
290+
TensorFactory<ScalarType::Float> tf;
291+
TensorFactory<ScalarType::Int> tf_i;
292+
TensorFactory<ScalarType::Long> tf_l;
293+
294+
int64_t quant_min = 0;
295+
int64_t quant_max = 255;
296+
297+
Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
298+
Tensor weight_zero_points = tf.make({3, 2}, {1, 5, 7, 9, 11, 13});
299+
TensorFactory<ScalarType::Byte> tfo;
300+
Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});
301+
302+
Tensor indices = tf_l.make({3}, {0, 2, 1});
303+
304+
Tensor out = tf.zeros({3, 3});
305+
ET_EXPECT_DEATH(
306+
quantized_embedding_byte_out(
307+
qweight,
308+
weight_scales,
309+
weight_zero_points,
310+
quant_min,
311+
quant_max,
312+
indices,
313+
out),
314+
"");
315+
}
316+
317+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
318+
et_pal_init();
319+
TensorFactory<ScalarType::Float> tf;
320+
TensorFactory<ScalarType::Int> tf_i;
321+
TensorFactory<ScalarType::Long> tf_l;
322+
323+
int64_t quant_min = 0;
324+
int64_t quant_max = 255;
325+
326+
Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
327+
Tensor weight_zero_points = tf.make({3}, {1, 5, 7});
328+
TensorFactory<ScalarType::Byte> tfo;
329+
Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});
330+
331+
Tensor indices = tf_l.make({3}, {0, 2, 1});
332+
333+
Tensor out = tf.zeros({3, 3});
334+
ET_EXPECT_DEATH(
335+
quantized_embedding_byte_out(
336+
qweight,
337+
weight_scales,
338+
weight_zero_points,
339+
quant_min,
340+
quant_max,
341+
indices,
342+
out),
343+
"");
344+
}
345+
346+
TEST(OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
347+
et_pal_init();
348+
TensorFactory<ScalarType::Float> tf;
349+
TensorFactory<ScalarType::Int> tf_i;
350+
TensorFactory<ScalarType::Long> tf_l;
351+
352+
int64_t quant_min = 0;
353+
int64_t quant_max = 255;
354+
355+
Tensor weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.5, 3.5, 3.5});
356+
Tensor weight_zero_points = tf.make({3, 3}, {1, 5, 7, 1, 5, 7, 1, 5, 7});
357+
TensorFactory<ScalarType::Byte> tfo;
358+
Tensor qweight = tfo.make({3, 3}, {8, 10, 12, 14, 10, 12, 12, 14, 8});
359+
360+
Tensor indices = tf_l.make({3}, {0, 2, 1});
361+
362+
Tensor out = tf.zeros({3, 3});
363+
ET_EXPECT_DEATH(
364+
quantized_embedding_byte_out(
365+
qweight,
366+
weight_scales,
367+
weight_zero_points,
368+
quant_min,
369+
quant_max,
370+
indices,
371+
out),
372+
"");
373+
}

0 commit comments

Comments
 (0)