Skip to content

Commit 0287db7

Browse files
committed
fix test
1 parent 5360d70 commit 0287db7

File tree

1 file changed

+26
-22
lines changed

1 file changed

+26
-22
lines changed

kernels/quantized/test/op_embedding2b_test.cpp

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,19 +36,19 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
3636
int64_t quant_max = 1;
3737

3838
Tensor weight_scales = tf.make({3}, {0.5, 1.0, 1.5});
39-
Tensor weight_zero_points = tf.make({3}, {1, -5, 0});
39+
Tensor weight_zero_points = tf.make({3}, {1, -2, 0});
4040

41-
// -3, 1, 6, 7,
42-
// 2, -5, -4, 0,
43-
// -8, 3, -1, 6,
41+
// -2, 1, 0, 1, -> 0, 3, 2, 3 -> 00 11 10 11 -> 59
42+
// 0, -1, -2, 0, -> 2, 1, 0, 2 -> 10 01 00 10 -> 146
43+
// -2, -1, 0, 1, -> 0, 1, 2, 3 -> 00 01 10 11 -> 27
4444

45-
Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
45+
Tensor qweight = tfb.make({3, 1}, {59, 146, 27});
4646

4747
Tensor indices = tfl.make({3}, {0, 2, 1});
4848

4949
Tensor out = tf.zeros({3, 4});
50-
Tensor expected = tf.make(
51-
{3, 4}, {-2.0, 0.0, 2.5, 3.0, -12.0, 4.5, -1.5, 9.0, 7.0, 0.0, 1.0, 5.0});
50+
Tensor expected = tf.make(
51+
{3, 4}, {-1.5, 0.0, -0.5, 0.0, -3.0, -1.5, 0.0, 1.5, -2.0, -3.0, -4.0, -2.0});
5252

5353
quantized_embedding_2bit_out(
5454
qweight,
@@ -76,18 +76,21 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbedding) {
7676
EXPECT_TENSOR_EQ(out, expected);
7777

7878
// Groupwise quantization. groupsize = 2
79+
7980
weight_scales = tf.make({3, 2}, {0.5, 1.0, 1.5, 2.0, 2.5, 3.0});
80-
weight_zero_points = tf.make({3, 2}, {1, -5, 0, 2, -3, -1});
81-
/*
82-
fp_weight = [-2.0, 0.0, 11.0, 12.0,
83-
3.0, -7.5, -12.0, -4.0,
84-
-12.5, 15.0, 0.0, 21.0]
85-
*/
81+
weight_zero_points = tf.make({3, 2}, {1, -2, 0, 1, -2, -1});
8682

87-
out = tf.zeros({3, 4});
88-
expected = tf.make(
89-
{3, 4},
90-
{-2.0, 0.0, 11.0, 12.0, -12.5, 15.0, 0.0, 21.0, 3.0, -7.5, -12.0, -4.0});
83+
// -2, 1, 0, 1, -> 0, 3, 2, 3 -> 00 11 10 11 -> 59
84+
// 0, -1, -2, 0, -> 2, 1, 0, 2 -> 10 01 00 10 -> 146
85+
// -2, -1, 0, 1, -> 0, 1, 2, 3 -> 00 01 10 11 -> 27
86+
87+
Tensor qweight = tfb.make({3, 1}, {59, 146, 27});
88+
89+
Tensor indices = tfl.make({3}, {0, 2, 1});
90+
91+
Tensor out = tf.zeros({3, 4});
92+
Tensor expected = tf.make(
93+
{3, 4}, {-1.5, 0.0, -2.0, -1.0, 0.0, 2.5, 3.0, 6.0, 0.0, -1.5, -6.0, -2.0});
9194

9295
quantized_embedding_2bit_out(
9396
qweight,
@@ -111,11 +114,11 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath1) {
111114
int64_t quant_max = 1;
112115

113116
Tensor weight_scales = tf.make({4}, {0.5, 1.0, 1.5, 3.3});
114-
Tensor weight_zero_points = tf.make({4}, {1, 5, 7, 5});
115-
Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
117+
Tensor weight_zero_points = tf.make({4}, {1, -2, 1, 0});
118+
Tensor qweight = tfb.make({3, 1}, {59, 146, 27});
116119
Tensor indices = tfl.make({3}, {0, 2, 1});
117-
118120
Tensor out = tf.zeros({3, 4});
121+
119122
ET_EXPECT_DEATH(
120123
quantized_embedding_2bit_out(
121124
qweight,
@@ -138,9 +141,10 @@ TEST(OpQuantizedEmbedding2bTest, TestGroupWiseQuantizedEmbeddingDeath2) {
138141
int64_t quant_max = 1;
139142

140143
Tensor weight_scales = tf.make({2}, {0.5, 1.0});
141-
Tensor weight_zero_points = tf.make({2}, {1, 5});
142-
Tensor qweight = tfb.make({3, 2}, {89, 239, 163, 72, 11, 126});
144+
Tensor weight_zero_points = tf.make({2}, {1, -2});
145+
Tensor qweight = tfb.make({3, 1}, {59, 146, 27});
143146
Tensor indices = tfl.make({3}, {0, 2, 1});
147+
Tensor out = tf.zeros({3, 4});
144148

145149
Tensor out = tf.zeros({3, 4});
146150
ET_EXPECT_DEATH(

0 commit comments

Comments
 (0)