Skip to content

Commit 69f52cb

Browse files
committed
Turn the generic format into a proper format
1 parent 5ef4cf1 commit 69f52cb

File tree

6 files changed

+51
-32
lines changed

6 files changed

+51
-32
lines changed

server/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,7 @@
464464
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat,
465465
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat,
466466
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat,
467+
org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat,
467468
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat,
468469
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat;
469470

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormat.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
2626
import org.apache.lucene.index.SegmentReadState;
2727
import org.apache.lucene.index.SegmentWriteState;
28+
import org.elasticsearch.index.codec.vectors.AbstractFlatVectorsFormat;
2829
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2930
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryFlatVectorsScorer;
3031
import org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsReader;
@@ -85,20 +86,23 @@
8586
* <li>The sparse vector information, if required, mapping vector ordinal to doc ID
8687
* </ul>
8788
*/
88-
public class ES93BinaryQuantizedVectorsFormat extends ES93GenericFlatVectorsFormat {
89+
public class ES93BinaryQuantizedVectorsFormat extends AbstractFlatVectorsFormat {
8990

9091
public static final String NAME = "ES93BinaryQuantizedVectorsFormat";
9192

9293
private static final ES818BinaryFlatVectorsScorer scorer = new ES818BinaryFlatVectorsScorer(
9394
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
9495
);
9596

97+
private final ES93GenericFlatVectorsFormat rawFormat;
98+
9699
public ES93BinaryQuantizedVectorsFormat() {
97100
this(false, false);
98101
}
99102

100103
public ES93BinaryQuantizedVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
101-
super(NAME, useBFloat16, useDirectIO);
104+
super(NAME);
105+
rawFormat = new ES93GenericFlatVectorsFormat(useBFloat16, useDirectIO);
102106
}
103107

104108
@Override
@@ -108,11 +112,16 @@ protected FlatVectorsScorer flatVectorsScorer() {
108112

109113
@Override
110114
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
111-
return new ES818BinaryQuantizedVectorsWriter(scorer, super.fieldsWriter(state), state);
115+
return new ES818BinaryQuantizedVectorsWriter(scorer, rawFormat.fieldsWriter(state), state);
112116
}
113117

114118
@Override
115119
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
116-
return new ES818BinaryQuantizedVectorsReader(state, super.fieldsReader(state), scorer);
120+
return new ES818BinaryQuantizedVectorsReader(state, rawFormat.fieldsReader(state), scorer);
121+
}
122+
123+
@Override
124+
public String toString() {
125+
return getName() + "(name=" + getName() + ", rawVectorFormat=" + rawFormat + ", scorer=" + scorer + ")";
117126
}
118127
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93GenericFlatVectorsFormat.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1313
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
14+
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
1415
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1516
import org.apache.lucene.index.SegmentReadState;
1617
import org.apache.lucene.index.SegmentWriteState;
@@ -20,8 +21,9 @@
2021
import java.io.IOException;
2122
import java.util.Map;
2223

23-
public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {
24+
public class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFormat {
2425

26+
static final String NAME = "ES93GenericFlatVectorsFormat";
2527
static final String VECTOR_FORMAT_INFO_EXTENSION = "vfi";
2628
static final String META_CODEC_NAME = "ES93GenericFlatVectorsFormatMeta";
2729

@@ -35,12 +37,11 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo
3537
VERSION_CURRENT
3638
);
3739

38-
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(
39-
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
40-
);
41-
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(
42-
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
43-
);
40+
private static final FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
41+
42+
private static final DirectIOCapableFlatVectorsFormat float32VectorFormat = new DirectIOCapableLucene99FlatVectorsFormat(scorer);
43+
// TODO: a separate scorer for bfloat16
44+
private static final DirectIOCapableFlatVectorsFormat bfloat16VectorFormat = new ES93BFloat16FlatVectorsFormat(scorer);
4445

4546
private static final Map<String, DirectIOCapableFlatVectorsFormat> supportedFormats = Map.of(
4647
float32VectorFormat.getName(),
@@ -52,12 +53,21 @@ public abstract class ES93GenericFlatVectorsFormat extends AbstractFlatVectorsFo
5253
private final DirectIOCapableFlatVectorsFormat writeFormat;
5354
private final boolean useDirectIO;
5455

55-
public ES93GenericFlatVectorsFormat(String name, boolean useBFloat16, boolean useDirectIO) {
56-
super(name);
56+
public ES93GenericFlatVectorsFormat() {
57+
this(false, false);
58+
}
59+
60+
public ES93GenericFlatVectorsFormat(boolean useBFloat16, boolean useDirectIO) {
61+
super(NAME);
5762
writeFormat = useBFloat16 ? bfloat16VectorFormat : float32VectorFormat;
5863
this.useDirectIO = useDirectIO;
5964
}
6065

66+
@Override
67+
protected FlatVectorsScorer flatVectorsScorer() {
68+
return scorer;
69+
}
70+
6171
@Override
6272
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
6373
return new ES93GenericFlatVectorsWriter(META, writeFormat.getName(), useDirectIO, state, writeFormat.fieldsWriter(state));
@@ -74,6 +84,6 @@ public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException
7484

7585
@Override
7686
public String toString() {
77-
return getName() + "(name=" + getName() + ", writeFlatVectorFormat=" + writeFormat + ")";
87+
return getName() + "(name=" + getName() + ", format=" + writeFormat + ")";
7888
}
7989
}

server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,6 @@ org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat
99
org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat
1010
org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat
1111
org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat
12+
org.elasticsearch.index.codec.vectors.es93.ES93GenericFlatVectorsFormat
1213
org.elasticsearch.index.codec.vectors.es93.ES93BinaryQuantizedVectorsFormat
1314
org.elasticsearch.index.codec.vectors.es93.ES93HnswBinaryQuantizedVectorsFormat

server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93BinaryQuantizedVectorsFormatTests.java

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,9 @@
6262
import java.util.ArrayList;
6363
import java.util.Arrays;
6464
import java.util.List;
65-
import java.util.Locale;
6665

67-
import static java.lang.String.format;
6866
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
69-
import static org.hamcrest.Matchers.either;
70-
import static org.hamcrest.Matchers.startsWith;
67+
import static org.hamcrest.Matchers.oneOf;
7168

7269
public class ES93BinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
7370

@@ -196,11 +193,12 @@ public KnnVectorsFormat knnVectorsFormat() {
196193
}
197194
};
198195
String expectedPattern = "ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat,"
199-
+ " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat,"
200-
+ " flatVectorScorer=%s())";
201-
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
202-
var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
203-
assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer)));
196+
+ " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat,"
197+
+ " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}())),"
198+
+ " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}()))";
199+
var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer");
200+
var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer");
201+
assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer));
204202
}
205203

206204
@Override

server/src/test/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswBinaryQuantizedVectorsFormatTests.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,13 @@
5050

5151
import java.io.IOException;
5252
import java.util.Arrays;
53-
import java.util.Locale;
5453

5554
import static java.lang.String.format;
5655
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
5756
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
5857
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
5958
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
60-
import static org.hamcrest.Matchers.either;
61-
import static org.hamcrest.Matchers.startsWith;
59+
import static org.hamcrest.Matchers.oneOf;
6260

6361
public class ES93HnswBinaryQuantizedVectorsFormatTests extends BaseKnnVectorsFormatTestCase {
6462

@@ -91,14 +89,16 @@ public KnnVectorsFormat knnVectorsFormat() {
9189
return new ES93HnswBinaryQuantizedVectorsFormat(10, 20, false, false, 1, null);
9290
}
9391
};
94-
String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat, maxConn=10, beamWidth=20,"
92+
String expectedPattern = "ES93HnswBinaryQuantizedVectorsFormat(name=ES93HnswBinaryQuantizedVectorsFormat,"
93+
+ " maxConn=10, beamWidth=20,"
9594
+ " flatVectorFormat=ES93BinaryQuantizedVectorsFormat(name=ES93BinaryQuantizedVectorsFormat,"
96-
+ " writeFlatVectorFormat=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat,"
97-
+ " flatVectorScorer=%s())";
95+
+ " rawVectorFormat=ES93GenericFlatVectorsFormat(name=ES93GenericFlatVectorsFormat,"
96+
+ " format=Lucene99FlatVectorsFormat(name=Lucene99FlatVectorsFormat, flatVectorScorer={}())),"
97+
+ " scorer=ES818BinaryFlatVectorsScorer(nonQuantizedDelegate={}())))";
9898

99-
var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer");
100-
var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer");
101-
assertThat(customCodec.knnVectorsFormat().toString(), either(startsWith(defaultScorer)).or(startsWith(memSegScorer)));
99+
var defaultScorer = expectedPattern.replaceAll("\\{}", "DefaultFlatVectorScorer");
100+
var memSegScorer = expectedPattern.replaceAll("\\{}", "Lucene99MemorySegmentFlatVectorsScorer");
101+
assertThat(customCodec.knnVectorsFormat().toString(), oneOf(defaultScorer, memSegScorer));
102102
}
103103

104104
public void testSingleVectorCase() throws Exception {

0 commit comments

Comments
 (0)