Skip to content

Commit abfb265

Browse files
authored
[DiskBBQ] Introduce KmeansFloatVectorValues to reduce the number of FloatVectorValues implementations (elastic#139097)
1 parent 733eee4 commit abfb265

File tree

5 files changed

+218
-122
lines changed

5 files changed

+218
-122
lines changed
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.index.codec.vectors.cluster;
11+
12+
import org.apache.lucene.index.FloatVectorValues;
13+
import org.apache.lucene.store.IndexInput;
14+
import org.apache.lucene.store.RandomAccessInput;
15+
16+
import java.io.IOException;
17+
import java.io.UncheckedIOException;
18+
import java.util.List;
19+
20+
/**
21+
* Unified class that can represent on-heap and off-heap vector values.
22+
*/
23+
public final class KmeansFloatVectorValues extends FloatVectorValues {
24+
25+
private final VectorSupplier vectors;
26+
private final DocSupplier docs;
27+
private final int numVectors;
28+
29+
private KmeansFloatVectorValues(VectorSupplier vectors, DocSupplier docs, int numVectors) {
30+
this.vectors = vectors;
31+
this.docs = docs;
32+
this.numVectors = numVectors;
33+
}
34+
35+
/**
36+
* Build an instance from on-heap data structures.
37+
*/
38+
public static KmeansFloatVectorValues build(List<float[]> vectors, int[] docs, int dim) {
39+
VectorSupplier vectorSupplier = new OnHeapVectorSupplier(vectors, dim);
40+
DocSupplier docSupplier = docs == null ? null : new OnHeapDocSupplier(docs);
41+
return new KmeansFloatVectorValues(vectorSupplier, docSupplier, vectors.size());
42+
}
43+
44+
/**
45+
* Builds an instance from off-heap data structures. Vectors are expected to be written as
46+
* little endian floats one after the other. Docs are expected to be written as little endian ints
47+
* one after the other.
48+
*/
49+
public static KmeansFloatVectorValues build(IndexInput vectors, IndexInput docs, int numVectors, int dims) throws IOException {
50+
long vectorLength = (long) dims * Float.BYTES;
51+
float[] vector = new float[dims];
52+
VectorSupplier vectorSupplier = new OffHeapVectorSupplier(vectors, vector, vectorLength);
53+
DocSupplier docSupplier;
54+
if (docs == null) {
55+
docSupplier = null;
56+
} else {
57+
RandomAccessInput randomDocs = docs.randomAccessSlice(0, docs.length());
58+
docSupplier = new OffHeapDocSupplier(docs, randomDocs);
59+
}
60+
return new KmeansFloatVectorValues(vectorSupplier, docSupplier, numVectors);
61+
}
62+
63+
@Override
64+
public float[] vectorValue(int ord) throws IOException {
65+
return vectors.vector(ord);
66+
}
67+
68+
@Override
69+
public FloatVectorValues copy() {
70+
return new KmeansFloatVectorValues(vectors.copy(), docs != null ? docs.copy() : null, numVectors);
71+
}
72+
73+
@Override
74+
public int dimension() {
75+
return vectors.dims();
76+
}
77+
78+
@Override
79+
public int size() {
80+
return numVectors;
81+
}
82+
83+
@Override
84+
public int ordToDoc(int ord) {
85+
if (docs == null) {
86+
return ord;
87+
}
88+
return docs.ordToDoc(ord);
89+
}
90+
91+
private sealed interface VectorSupplier permits OffHeapVectorSupplier, OnHeapVectorSupplier {
92+
93+
float[] vector(int ord) throws IOException;
94+
95+
int dims();
96+
97+
VectorSupplier copy();
98+
}
99+
100+
private record OnHeapVectorSupplier(List<float[]> vectors, int dims) implements VectorSupplier {
101+
102+
@Override
103+
public float[] vector(int ord) {
104+
return vectors.get(ord);
105+
}
106+
107+
@Override
108+
public int dims() {
109+
return dims;
110+
}
111+
112+
@Override
113+
public VectorSupplier copy() {
114+
return this;
115+
}
116+
}
117+
118+
private record OffHeapVectorSupplier(IndexInput vectors, float[] vector, long vectorLength) implements VectorSupplier {
119+
120+
@Override
121+
public float[] vector(int ord) throws IOException {
122+
vectors.seek(ord * vectorLength);
123+
vectors.readFloats(vector, 0, vector.length);
124+
return vector;
125+
}
126+
127+
@Override
128+
public int dims() {
129+
return vector.length;
130+
}
131+
132+
@Override
133+
public VectorSupplier copy() {
134+
return new OffHeapVectorSupplier(vectors.clone(), vector.clone(), vectorLength);
135+
}
136+
}
137+
138+
private sealed interface DocSupplier permits OnHeapDocSupplier, OffHeapDocSupplier {
139+
int ordToDoc(int ord);
140+
141+
DocSupplier copy();
142+
}
143+
144+
private record OnHeapDocSupplier(int[] docs) implements DocSupplier {
145+
@Override
146+
public int ordToDoc(int ord) {
147+
return docs[ord];
148+
}
149+
150+
@Override
151+
public DocSupplier copy() {
152+
return this;
153+
}
154+
}
155+
156+
private record OffHeapDocSupplier(IndexInput docs, RandomAccessInput randomDocs) implements DocSupplier {
157+
@Override
158+
public int ordToDoc(int ord) {
159+
try {
160+
return randomDocs.readInt((long) ord * Integer.BYTES);
161+
} catch (IOException e) {
162+
throw new UncheckedIOException(e);
163+
}
164+
}
165+
166+
@Override
167+
public DocSupplier copy() {
168+
IndexInput docsCopy = docs.clone();
169+
try {
170+
RandomAccessInput randomDocsCopy = docsCopy.randomAccessSlice(0, docsCopy.length());
171+
return new OffHeapDocSupplier(docsCopy, randomDocsCopy);
172+
} catch (IOException e) {
173+
throw new UncheckedIOException(e);
174+
}
175+
}
176+
}
177+
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidSupplier.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,24 @@
99

1010
package org.elasticsearch.index.codec.vectors.diskbbq;
1111

12+
import org.apache.lucene.index.FloatVectorValues;
13+
import org.elasticsearch.index.codec.vectors.cluster.KmeansFloatVectorValues;
14+
1215
import java.io.IOException;
16+
import java.util.Arrays;
1317

1418
/**
1519
* An interface for classes that supply centroids.
1620
*/
1721
public interface CentroidSupplier {
18-
CentroidSupplier EMPTY = new CentroidSupplier() {
19-
@Override
20-
public int size() {
21-
return 0;
22-
}
23-
24-
@Override
25-
public float[] centroid(int centroidOrdinal) {
26-
throw new IllegalStateException("No centroids");
27-
}
28-
};
2922

3023
int size();
3124

3225
float[] centroid(int centroidOrdinal) throws IOException;
3326

34-
static CentroidSupplier fromArray(float[][] centroids) {
35-
if (centroids.length == 0) {
36-
return EMPTY;
37-
}
27+
FloatVectorValues asFloatVectorValues() throws IOException;
28+
29+
static CentroidSupplier fromArray(float[][] centroids, int dims) {
3830
return new CentroidSupplier() {
3931
@Override
4032
public int size() {
@@ -45,6 +37,11 @@ public int size() {
4537
public float[] centroid(int centroidOrdinal) {
4638
return centroids[centroidOrdinal];
4739
}
40+
41+
@Override
42+
public FloatVectorValues asFloatVectorValues() {
43+
return KmeansFloatVectorValues.build(Arrays.asList(centroids), null, dims);
44+
}
4845
};
4946
}
5047
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,16 @@
2626
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
2727
import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
2828
import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
29+
import org.elasticsearch.index.codec.vectors.cluster.KmeansFloatVectorValues;
2930
import org.elasticsearch.logging.LogManager;
3031
import org.elasticsearch.logging.Logger;
3132
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
3233
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
3334
import org.elasticsearch.simdvec.ESVectorUtil;
3435

3536
import java.io.IOException;
36-
import java.io.UncheckedIOException;
3737
import java.nio.ByteBuffer;
3838
import java.nio.ByteOrder;
39-
import java.util.AbstractList;
4039
import java.util.Arrays;
4140

4241
import static org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans.NO_SOAR_ASSIGNMENT;
@@ -415,7 +414,7 @@ private void writeCentroidsWithParents(
415414
centroidOutput.writeVInt(centroidGroups.centroids.length);
416415
centroidOutput.writeVInt(centroidGroups.maxVectorsPerCentroidLength);
417416
QuantizedCentroids parentQuantizeCentroid = new QuantizedCentroids(
418-
CentroidSupplier.fromArray(centroidGroups.centroids),
417+
CentroidSupplier.fromArray(centroidGroups.centroids, fieldInfo.getVectorDimension()),
419418
fieldInfo.getVectorDimension(),
420419
osq,
421420
globalCentroid
@@ -476,21 +475,7 @@ private void writeCentroidsWithoutParents(
476475
private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {}
477476

478477
private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException {
479-
final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() {
480-
@Override
481-
public float[] get(int index) {
482-
try {
483-
return centroidSupplier.centroid(index);
484-
} catch (IOException e) {
485-
throw new UncheckedIOException(e);
486-
}
487-
}
488-
489-
@Override
490-
public int size() {
491-
return centroidSupplier.size();
492-
}
493-
}, fieldInfo.getVectorDimension());
478+
final FloatVectorValues floatVectorValues = centroidSupplier.asFloatVectorValues();
494479
// we use the HierarchicalKMeans to partition the space of all vectors across merging segments
495480
// these are small numbers so we run it with all the centroids.
496481
final KMeansResult kMeansResult = new HierarchicalKMeans(
@@ -587,6 +572,11 @@ public float[] centroid(int centroidOrdinal) throws IOException {
587572
this.currOrd = centroidOrdinal;
588573
return scratch;
589574
}
575+
576+
@Override
577+
public FloatVectorValues asFloatVectorValues() throws IOException {
578+
return KmeansFloatVectorValues.build(centroidsInput, null, numCentroids, dimension);
579+
}
590580
}
591581

592582
static class QuantizedCentroids implements QuantizedVectorValues {

0 commit comments

Comments
 (0)