Skip to content

Commit d305146

Browse files
committed
Update tests
1 parent 2ceaa36 commit d305146

File tree

2 files changed

+124
-8
lines changed

2 files changed

+124
-8
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.es93;
11+
12+
import org.apache.lucene.index.VectorEncoding;
13+
14+
import java.util.regex.Matcher;
15+
import java.util.regex.Pattern;
16+
17+
import static org.hamcrest.Matchers.closeTo;
18+
19+
public class ES93HnswBinaryQuantizedBFloat16VectorsFormatTests extends ES93HnswBinaryQuantizedVectorsFormatTests {
20+
21+
@Override
22+
boolean useBFloat16() {
23+
return true;
24+
}
25+
26+
@Override
27+
protected VectorEncoding randomVectorEncoding() {
28+
return VectorEncoding.FLOAT32;
29+
}
30+
31+
@Override
32+
public void testEmptyByteVectorData() throws Exception {
33+
// no bytes
34+
}
35+
36+
@Override
37+
public void testMergingWithDifferentByteKnnFields() throws Exception {
38+
// no bytes
39+
}
40+
41+
@Override
42+
public void testByteVectorScorerIteration() throws Exception {
43+
// no bytes
44+
}
45+
46+
@Override
47+
public void testSortedIndexBytes() throws Exception {
48+
// no bytes
49+
}
50+
51+
@Override
52+
public void testMismatchedFields() throws Exception {
53+
// no bytes
54+
}
55+
56+
@Override
57+
public void testRandomBytes() throws Exception {
58+
// no bytes
59+
}
60+
61+
@Override
62+
public void testWriterRamEstimate() throws Exception {
63+
// estimate is different due to bfloat16
64+
}
65+
66+
@Override
67+
public void testSingleVectorCase() throws Exception {
68+
AssertionError err = expectThrows(AssertionError.class, super::testSingleVectorCase);
69+
assertFloatsWithinBounds(err);
70+
}
71+
72+
@Override
73+
public void testRandom() throws Exception {
74+
AssertionError err = expectThrows(AssertionError.class, super::testRandom);
75+
assertFloatsWithinBounds(err);
76+
}
77+
78+
@Override
79+
public void testRandomWithUpdatesAndGraph() throws Exception {
80+
AssertionError err = expectThrows(AssertionError.class, super::testRandomWithUpdatesAndGraph);
81+
assertFloatsWithinBounds(err);
82+
}
83+
84+
@Override
85+
public void testSparseVectors() throws Exception {
86+
AssertionError err = expectThrows(AssertionError.class, super::testSparseVectors);
87+
assertFloatsWithinBounds(err);
88+
}
89+
90+
@Override
91+
public void testVectorValuesReportCorrectDocs() throws Exception {
92+
AssertionError err = expectThrows(AssertionError.class, super::testVectorValuesReportCorrectDocs);
93+
assertFloatsWithinBounds(err);
94+
}
95+
96+
private static final Pattern FLOAT_ASSERTION_FAILURE = Pattern.compile(".*expected:<([0-9.-]+)> but was:<([0-9.-]+)>");
97+
98+
private static void assertFloatsWithinBounds(AssertionError error) {
99+
Matcher m = FLOAT_ASSERTION_FAILURE.matcher(error.getMessage());
100+
if (m.matches() == false) {
101+
throw error; // nothing to do with us, just rethrow
102+
}
103+
104+
// numbers just need to be in the same vicinity
105+
double expected = Double.parseDouble(m.group(1));
106+
double actual = Double.parseDouble(m.group(2));
107+
double allowedError = expected * 0.01; // within 1%
108+
assertThat(error.getMessage(), actual, closeTo(expected, allowedError));
109+
}
110+
}

server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import org.apache.lucene.util.SameThreadExecutorService;
4747
import org.apache.lucene.util.VectorUtil;
4848
import org.elasticsearch.common.logging.LogConfigurator;
49+
import org.elasticsearch.index.codec.vectors.BFloat16;
4950

5051
import java.io.IOException;
5152
import java.util.Arrays;
@@ -68,9 +69,13 @@ public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFor
6869

6970
private KnnVectorsFormat format;
7071

72+
boolean useBFloat16() {
73+
return false;
74+
}
75+
7176
@Override
7277
public void setUp() throws Exception {
73-
format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean());
78+
format = new ES93HnswBinaryQuantizedVectorsFormat(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, random().nextBoolean(), useBFloat16());
7479
super.setUp();
7580
}
7681

@@ -137,12 +142,12 @@ public void testSingleVectorCase() throws Exception {
137142
}
138143

139144
public void testLimits() {
140-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false));
141-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false));
142-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false));
143-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false));
144-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false));
145-
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false));
145+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(-1, 20, false, false));
146+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(0, 20, false, false));
147+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 0, false, false));
148+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, -1, false, false));
149+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(512 + 1, 20, false, false));
150+
expectThrows(IllegalArgumentException.class, () -> new ES93HnswBinaryQuantizedVectorsFormat(20, 3201, false, false));
146151
expectThrows(
147152
IllegalArgumentException.class,
148153
() -> new ES93HnswBinaryQuantizedVectorsFormat(20, 100, false, false, 1, new SameThreadExecutorService())
@@ -189,7 +194,8 @@ public void testSimpleOffHeapSizeImpl(Directory dir, IndexWriterConfig config, b
189194
assertEquals(1L, (long) offHeap.get("vex"));
190195
assertTrue(offHeap.get("veb") > 0L);
191196
if (expectVecOffHeap) {
192-
assertEquals(vector.length * Float.BYTES, (long) offHeap.get("vec"));
197+
int bytes = useBFloat16() ? BFloat16.BYTES : Float.BYTES;
198+
assertEquals(vector.length * bytes, (long) offHeap.get("vec"));
193199
}
194200
}
195201
}

0 commit comments

Comments
 (0)