Skip to content

Commit 44ecd39

Browse files
committed
Remove intermediate class
1 parent 115507a commit 44ecd39

File tree

3 files changed

+51
-95
lines changed

3 files changed

+51
-95
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93HnswScalarQuantizedVectorsFormat.java

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,12 @@
1111

1212
import org.apache.lucene.codecs.KnnVectorsReader;
1313
import org.apache.lucene.codecs.KnnVectorsWriter;
14+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1415
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
16+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
1517
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
18+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsReader;
19+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsWriter;
1620
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
1721
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
1822
import org.apache.lucene.index.SegmentReadState;
@@ -26,16 +30,17 @@ public class ES93HnswScalarQuantizedVectorsFormat extends AbstractHnswVectorsFor
2630

2731
static final String NAME = "ES93HnswScalarQuantizedVectorsFormat";
2832

29-
/** The format for storing, reading, merging vectors on disk */
30-
private final FlatVectorsFormat flatVectorsFormat;
33+
static final Lucene104ScalarQuantizedVectorScorer flatVectorScorer = new Lucene104ScalarQuantizedVectorScorer(
34+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
35+
);
36+
37+
private final Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding;
38+
private final FlatVectorsFormat rawVectorFormat;
3139

3240
public ES93HnswScalarQuantizedVectorsFormat() {
3341
super(NAME);
34-
this.flatVectorsFormat = new ES93ScalarQuantizedFlatVectorsFormat(
35-
Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT,
36-
ES93GenericFlatVectorsFormat.ElementType.STANDARD,
37-
false
38-
);
42+
this.encoding = Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT;
43+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(ES93GenericFlatVectorsFormat.ElementType.STANDARD, false);
3944
}
4045

4146
public ES93HnswScalarQuantizedVectorsFormat(
@@ -46,7 +51,8 @@ public ES93HnswScalarQuantizedVectorsFormat(
4651
boolean useDirectIO
4752
) {
4853
super(NAME, maxConn, beamWidth);
49-
this.flatVectorsFormat = new ES93ScalarQuantizedFlatVectorsFormat(encoding, elementType, useDirectIO);
54+
this.encoding = encoding;
55+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
5056
}
5157

5258
public ES93HnswScalarQuantizedVectorsFormat(
@@ -59,12 +65,13 @@ public ES93HnswScalarQuantizedVectorsFormat(
5965
ExecutorService mergeExec
6066
) {
6167
super(NAME, maxConn, beamWidth, numMergeWorkers, mergeExec);
62-
this.flatVectorsFormat = new ES93ScalarQuantizedFlatVectorsFormat(encoding, elementType, useDirectIO);
68+
this.encoding = encoding;
69+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, useDirectIO);
6370
}
6471

6572
@Override
6673
protected FlatVectorsFormat flatVectorsFormat() {
67-
return flatVectorsFormat;
74+
return rawVectorFormat;
6875
}
6976

7077
@Override
@@ -73,7 +80,8 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
7380
state,
7481
maxConn,
7582
beamWidth,
76-
flatVectorsFormat.fieldsWriter(state),
83+
new Lucene104ScalarQuantizedVectorsWriter(state, encoding, rawVectorFormat.fieldsWriter(state), flatVectorScorer) {
84+
},
7785
numMergeWorkers,
7886
mergeExec,
7987
0
@@ -82,6 +90,9 @@ public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException
8290

8391
@Override
8492
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
85-
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
93+
return new Lucene99HnswVectorsReader(
94+
state,
95+
new Lucene104ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), flatVectorScorer)
96+
);
8697
}
8798
}

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedFlatVectorsFormat.java

Lines changed: 0 additions & 78 deletions
This file was deleted.

server/src/main/java/org/elasticsearch/index/codec/vectors/es93/ES93ScalarQuantizedVectorsFormat.java

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
import org.apache.lucene.codecs.KnnVectorsFormat;
1313
import org.apache.lucene.codecs.KnnVectorsReader;
1414
import org.apache.lucene.codecs.KnnVectorsWriter;
15+
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
1516
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
1617
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
18+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorScorer;
1719
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsFormat;
20+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsReader;
21+
import org.apache.lucene.codecs.lucene104.Lucene104ScalarQuantizedVectorsWriter;
1822
import org.apache.lucene.index.ByteVectorValues;
1923
import org.apache.lucene.index.FieldInfo;
2024
import org.apache.lucene.index.FloatVectorValues;
@@ -35,7 +39,12 @@ public class ES93ScalarQuantizedVectorsFormat extends KnnVectorsFormat {
3539

3640
static final String NAME = "ES93ScalarQuantizedVectorsFormat";
3741

38-
private final FlatVectorsFormat format;
42+
static final Lucene104ScalarQuantizedVectorScorer flatVectorScorer = new Lucene104ScalarQuantizedVectorScorer(
43+
FlatVectorScorerUtil.getLucene99FlatVectorsScorer()
44+
);
45+
46+
private final Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding;
47+
private final FlatVectorsFormat rawVectorFormat;
3948

4049
public ES93ScalarQuantizedVectorsFormat() {
4150
this(ES93GenericFlatVectorsFormat.ElementType.STANDARD, Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding.SEVEN_BIT);
@@ -50,17 +59,22 @@ public ES93ScalarQuantizedVectorsFormat(
5059
Lucene104ScalarQuantizedVectorsFormat.ScalarEncoding encoding
5160
) {
5261
super(NAME);
53-
this.format = new ES93ScalarQuantizedFlatVectorsFormat(encoding, elementType, false);
62+
assert elementType != ES93GenericFlatVectorsFormat.ElementType.BIT : "BIT should not be used with scalar quantization";
63+
this.encoding = encoding;
64+
this.rawVectorFormat = new ES93GenericFlatVectorsFormat(elementType, false);
5465
}
5566

5667
@Override
5768
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
58-
return format.fieldsWriter(state);
69+
return new Lucene104ScalarQuantizedVectorsWriter(state, encoding, rawVectorFormat.fieldsWriter(state), flatVectorScorer) {
70+
};
5971
}
6072

6173
@Override
6274
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
63-
return new ES93FlatVectorsReader(format.fieldsReader(state));
75+
return new ES93FlatVectorsReader(
76+
new Lucene104ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state), flatVectorScorer)
77+
);
6478
}
6579

6680
@Override
@@ -70,7 +84,16 @@ public int getMaxDimensions(String fieldName) {
7084

7185
@Override
7286
public String toString() {
73-
return NAME + "(name=" + NAME + ", innerFormat=" + format + ")";
87+
return NAME
88+
+ "(name="
89+
+ NAME
90+
+ ", encoding="
91+
+ encoding
92+
+ ", flatVectorScorer="
93+
+ flatVectorScorer
94+
+ ", rawVectorFormat="
95+
+ rawVectorFormat
96+
+ ")";
7497
}
7598

7699
public static class ES93FlatVectorsReader extends KnnVectorsReader {

0 commit comments

Comments
 (0)