@@ -167,3 +167,207 @@ TEST(OpQuantizedEmbeddingTest, ConsitencyWithReferencePattern) {
167167 EXPECT_TENSOR_EQ (out, fp_out);
168168 EXPECT_TENSOR_EQ (out, expected);
169169}
170+
171+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbedding) {
172+ et_pal_init ();
173+ TensorFactory<ScalarType::Float> tf;
174+ TensorFactory<ScalarType::Int> tf_i;
175+ TensorFactory<ScalarType::Long> tf_l;
176+
177+ int64_t quant_min = 0 ;
178+ int64_t quant_max = 255 ;
179+
180+ Tensor weight_scales = tf.make ({3 }, {0.5 , 1.0 , 1.5 });
181+ Tensor weight_zero_points = tf.make ({3 }, {1 , 5 , 7 });
182+ TensorFactory<ScalarType::Byte> tfo;
183+ Tensor qweight =
184+ tfo.make ({3 , 4 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 , 9 , 10 , 12 });
185+
186+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
187+
188+ Tensor out = tf.zeros ({3 , 4 });
189+ Tensor expected = tf.make (
190+ {3 , 4 }, {3.5 , 4.5 , 5.5 , 6.5 , 1.5 , 3.0 , 4.5 , 7.5 , 5.0 , 7.0 , 7.0 , 9.0 });
191+
192+ quantized_embedding_byte_out (
193+ qweight,
194+ weight_scales,
195+ weight_zero_points,
196+ quant_min,
197+ quant_max,
198+ indices,
199+ out);
200+
201+ EXPECT_TENSOR_EQ (out, expected);
202+
203+ // Groupwise quantization. groupsize = 2
204+ weight_scales = tf.make ({3 , 2 }, {0.5 , 1.0 , 1.5 , 2.0 , 2.5 , 3.0 });
205+ weight_zero_points = tf.make ({3 , 2 }, {1 , 5 , 7 , 9 , 11 , 13 });
206+ /*
207+ fp_weight = [3.5, 4.5, 7, 9,
208+ 4.5, 7.5, 6, 10,
209+ -7.5, -5.0, -9.0, -3.0]
210+ */
211+
212+ out = tf.zeros ({3 , 4 });
213+ expected = tf.make (
214+ {3 , 4 }, {3.5 , 4.5 , 7 , 9 , -7.5 , -5.0 , -9.0 , -3.0 , 4.5 , 7.5 , 6 , 10 });
215+
216+ quantized_embedding_byte_out (
217+ qweight,
218+ weight_scales,
219+ weight_zero_points,
220+ quant_min,
221+ quant_max,
222+ indices,
223+ out);
224+
225+ EXPECT_TENSOR_EQ (out, expected);
226+ }
227+
228+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath1) {
229+ et_pal_init ();
230+ TensorFactory<ScalarType::Float> tf;
231+ TensorFactory<ScalarType::Int> tf_i;
232+ TensorFactory<ScalarType::Long> tf_l;
233+
234+ int64_t quant_min = 0 ;
235+ int64_t quant_max = 255 ;
236+
237+ Tensor weight_scales = tf.make ({4 }, {0.5 , 1.0 , 1.5 , 3.3 });
238+ Tensor weight_zero_points = tf.make ({4 }, {1 , 5 , 7 , 5 });
239+ TensorFactory<ScalarType::Byte> tfo;
240+ Tensor qweight =
241+ tfo.make ({3 , 4 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 , 9 , 10 , 12 });
242+
243+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
244+
245+ Tensor out = tf.zeros ({3 , 4 });
246+ ET_EXPECT_DEATH (
247+ quantized_embedding_byte_out (
248+ qweight,
249+ weight_scales,
250+ weight_zero_points,
251+ quant_min,
252+ quant_max,
253+ indices,
254+ out),
255+ " " );
256+ }
257+
258+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath2) {
259+ et_pal_init ();
260+ TensorFactory<ScalarType::Float> tf;
261+ TensorFactory<ScalarType::Int> tf_i;
262+ TensorFactory<ScalarType::Long> tf_l;
263+
264+ int64_t quant_min = 0 ;
265+ int64_t quant_max = 255 ;
266+
267+ Tensor weight_scales = tf.make ({2 }, {0.5 , 1.0 });
268+ Tensor weight_zero_points = tf.make ({2 }, {1 , 5 });
269+ TensorFactory<ScalarType::Byte> tfo;
270+ Tensor qweight =
271+ tfo.make ({3 , 4 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 , 9 , 10 , 12 });
272+
273+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
274+
275+ Tensor out = tf.zeros ({3 , 4 });
276+ ET_EXPECT_DEATH (
277+ quantized_embedding_byte_out (
278+ qweight,
279+ weight_scales,
280+ weight_zero_points,
281+ quant_min,
282+ quant_max,
283+ indices,
284+ out),
285+ " " );
286+ }
287+
288+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath3) {
289+ et_pal_init ();
290+ TensorFactory<ScalarType::Float> tf;
291+ TensorFactory<ScalarType::Int> tf_i;
292+ TensorFactory<ScalarType::Long> tf_l;
293+
294+ int64_t quant_min = 0 ;
295+ int64_t quant_max = 255 ;
296+
297+ Tensor weight_scales = tf.make ({3 , 2 }, {0.5 , 1.0 , 1.5 , 2.5 , 3.5 , 3.5 });
298+ Tensor weight_zero_points = tf.make ({3 , 2 }, {1 , 5 , 7 , 9 , 11 , 13 });
299+ TensorFactory<ScalarType::Byte> tfo;
300+ Tensor qweight = tfo.make ({3 , 3 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 });
301+
302+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
303+
304+ Tensor out = tf.zeros ({3 , 3 });
305+ ET_EXPECT_DEATH (
306+ quantized_embedding_byte_out (
307+ qweight,
308+ weight_scales,
309+ weight_zero_points,
310+ quant_min,
311+ quant_max,
312+ indices,
313+ out),
314+ " " );
315+ }
316+
317+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath4) {
318+ et_pal_init ();
319+ TensorFactory<ScalarType::Float> tf;
320+ TensorFactory<ScalarType::Int> tf_i;
321+ TensorFactory<ScalarType::Long> tf_l;
322+
323+ int64_t quant_min = 0 ;
324+ int64_t quant_max = 255 ;
325+
326+ Tensor weight_scales = tf.make ({3 , 2 }, {0.5 , 1.0 , 1.5 , 2.5 , 3.5 , 3.5 });
327+ Tensor weight_zero_points = tf.make ({3 }, {1 , 5 , 7 });
328+ TensorFactory<ScalarType::Byte> tfo;
329+ Tensor qweight = tfo.make ({3 , 3 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 });
330+
331+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
332+
333+ Tensor out = tf.zeros ({3 , 3 });
334+ ET_EXPECT_DEATH (
335+ quantized_embedding_byte_out (
336+ qweight,
337+ weight_scales,
338+ weight_zero_points,
339+ quant_min,
340+ quant_max,
341+ indices,
342+ out),
343+ " " );
344+ }
345+
346+ TEST (OpQuantizedEmbeddingTest, TestGroupWiseQuantizedEmbeddingDeath5) {
347+ et_pal_init ();
348+ TensorFactory<ScalarType::Float> tf;
349+ TensorFactory<ScalarType::Int> tf_i;
350+ TensorFactory<ScalarType::Long> tf_l;
351+
352+ int64_t quant_min = 0 ;
353+ int64_t quant_max = 255 ;
354+
355+ Tensor weight_scales = tf.make ({3 , 2 }, {0.5 , 1.0 , 1.5 , 2.5 , 3.5 , 3.5 });
356+ Tensor weight_zero_points = tf.make ({3 , 3 }, {1 , 5 , 7 , 1 , 5 , 7 , 1 , 5 , 7 });
357+ TensorFactory<ScalarType::Byte> tfo;
358+ Tensor qweight = tfo.make ({3 , 3 }, {8 , 10 , 12 , 14 , 10 , 12 , 12 , 14 , 8 });
359+
360+ Tensor indices = tf_l.make ({3 }, {0 , 2 , 1 });
361+
362+ Tensor out = tf.zeros ({3 , 3 });
363+ ET_EXPECT_DEATH (
364+ quantized_embedding_byte_out (
365+ qweight,
366+ weight_scales,
367+ weight_zero_points,
368+ quant_min,
369+ quant_max,
370+ indices,
371+ out),
372+ " " );
373+ }
0 commit comments