Skip to content

Commit 8d25046

Browse files
committed
iter
1 parent 470245b commit 8d25046

File tree

7 files changed

+183
-193
lines changed

7 files changed

+183
-193
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/AssignmentArraySorter.java

Lines changed: 66 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,70 +13,70 @@
1313

1414
class AssignmentArraySorter extends TimSorter {
1515

16-
private final float[][] centroids;
17-
private final float[][] tmpC;
18-
19-
private final int[] centroidsOrds;
20-
private final int[] tmpA;
21-
22-
private final int[] sortOrdering;
23-
private final int[] tmpSort;
24-
25-
AssignmentArraySorter(float[][] centroids, int[] centroidsOrds, int[] sortOrdering) {
26-
super(centroids.length / 64);
27-
this.centroids = centroids;
28-
this.centroidsOrds = centroidsOrds;
29-
this.sortOrdering = sortOrdering;
30-
31-
int maxTempSlots = centroids.length / 64;
32-
this.tmpC = new float[maxTempSlots][];
33-
this.tmpA = new int[maxTempSlots];
34-
this.tmpSort = new int[maxTempSlots];
35-
}
36-
37-
@Override
38-
protected int compare(int i, int j) {
39-
return Integer.compare(sortOrdering[i], sortOrdering[j]);
40-
}
41-
42-
@Override
43-
protected void swap(int i, int j) {
44-
final float[] tmpC = centroids[i];
45-
centroids[i] = centroids[j];
46-
centroids[j] = tmpC;
47-
48-
final int tmpA = centroidsOrds[i];
49-
centroidsOrds[i] = centroidsOrds[j];
50-
centroidsOrds[j] = tmpA;
51-
52-
final int tmpSort = sortOrdering[i];
53-
sortOrdering[i] = sortOrdering[j];
54-
sortOrdering[j] = tmpSort;
55-
}
56-
57-
@Override
58-
protected void copy(int src, int dest) {
59-
centroids[dest] = centroids[src];
60-
centroidsOrds[dest] = centroidsOrds[src];
61-
sortOrdering[dest] = sortOrdering[src];
62-
}
63-
64-
@Override
65-
protected void save(int start, int len) {
66-
System.arraycopy(centroids, start, tmpC, 0, len);
67-
System.arraycopy(centroidsOrds, start, tmpA, 0, len);
68-
System.arraycopy(sortOrdering, start, tmpSort, 0, len);
69-
}
70-
71-
@Override
72-
protected void restore(int src, int dest) {
73-
centroids[dest] = tmpC[src];
74-
centroidsOrds[dest] = tmpA[src];
75-
sortOrdering[dest] = tmpSort[src];
76-
}
77-
78-
@Override
79-
protected int compareSaved(int i, int j) {
80-
return Integer.compare(tmpSort[i], sortOrdering[j]);
81-
}
16+
private final float[][] centroids;
17+
private final float[][] tmpC;
18+
19+
private final int[] centroidsOrds;
20+
private final int[] tmpA;
21+
22+
private final int[] sortOrdering;
23+
private final int[] tmpSort;
24+
25+
AssignmentArraySorter(float[][] centroids, int[] centroidsOrds, int[] sortOrdering) {
26+
super(centroids.length / 64);
27+
this.centroids = centroids;
28+
this.centroidsOrds = centroidsOrds;
29+
this.sortOrdering = sortOrdering;
30+
31+
int maxTempSlots = centroids.length / 64;
32+
this.tmpC = new float[maxTempSlots][];
33+
this.tmpA = new int[maxTempSlots];
34+
this.tmpSort = new int[maxTempSlots];
35+
}
36+
37+
@Override
38+
protected int compare(int i, int j) {
39+
return Integer.compare(sortOrdering[i], sortOrdering[j]);
40+
}
41+
42+
@Override
43+
protected void swap(int i, int j) {
44+
final float[] tmpC = centroids[i];
45+
centroids[i] = centroids[j];
46+
centroids[j] = tmpC;
47+
48+
final int tmpA = centroidsOrds[i];
49+
centroidsOrds[i] = centroidsOrds[j];
50+
centroidsOrds[j] = tmpA;
51+
52+
final int tmpSort = sortOrdering[i];
53+
sortOrdering[i] = sortOrdering[j];
54+
sortOrdering[j] = tmpSort;
55+
}
56+
57+
@Override
58+
protected void copy(int src, int dest) {
59+
centroids[dest] = centroids[src];
60+
centroidsOrds[dest] = centroidsOrds[src];
61+
sortOrdering[dest] = sortOrdering[src];
62+
}
63+
64+
@Override
65+
protected void save(int start, int len) {
66+
System.arraycopy(centroids, start, tmpC, 0, len);
67+
System.arraycopy(centroidsOrds, start, tmpA, 0, len);
68+
System.arraycopy(sortOrdering, start, tmpSort, 0, len);
69+
}
70+
71+
@Override
72+
protected void restore(int src, int dest) {
73+
centroids[dest] = tmpC[src];
74+
centroidsOrds[dest] = tmpA[src];
75+
sortOrdering[dest] = tmpSort[src];
76+
}
77+
78+
@Override
79+
protected int compareSaved(int i, int j) {
80+
return Integer.compare(tmpSort[i], sortOrdering[j]);
81+
}
8282
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,13 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4747
}
4848

4949
@Override
50-
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numParentCentroids,
51-
int numCentroids, IndexInput centroids, float[] targetQuery)
52-
throws IOException {
50+
CentroidQueryScorer getCentroidScorer(
51+
FieldInfo fieldInfo,
52+
int numParentCentroids,
53+
int numCentroids,
54+
IndexInput centroids,
55+
float[] targetQuery
56+
) throws IOException {
5357
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
5458
final float globalCentroidDp = fieldEntry.globalCentroidDp();
5559
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
@@ -153,12 +157,8 @@ private float int4QuantizedScore(
153157

154158
// FIXME: clean up duplicative code between the scorers
155159
@Override
156-
CentroidQueryScorerWChildren getCentroidScorerWChildren(
157-
FieldInfo fieldInfo,
158-
int numCentroids,
159-
IndexInput centroids,
160-
float[] targetQuery
161-
) throws IOException {
160+
ParentCentroidQueryScorer getParentCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
161+
throws IOException {
162162
FieldEntry fieldEntry = fields.get(fieldInfo.number);
163163
float[] globalCentroid = fieldEntry.globalCentroid();
164164
float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -172,7 +172,7 @@ CentroidQueryScorerWChildren getCentroidScorerWChildren(
172172
globalCentroid
173173
);
174174
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
175-
return new CentroidQueryScorerWChildren() {
175+
return new ParentCentroidQueryScorer() {
176176
int currentCentroid = -1;
177177
private final float[] centroidCorrectiveValues = new float[3];
178178
private final long quantizedVectorByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES;

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 49 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
import java.nio.ByteBuffer;
2929
import java.nio.ByteOrder;
3030
import java.util.ArrayList;
31-
import java.util.Arrays;
3231
import java.util.List;
3332

3433
import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS;
3534
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
3635
import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
36+
import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.DEFAULT_VECTORS_PER_CLUSTER;
3737

3838
/**
3939
* Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to
@@ -167,14 +167,23 @@ private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput,
167167
}
168168

169169
@Override
170-
CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numParentCentroids,
171-
int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) {
170+
CentroidSupplier createCentroidSupplier(
171+
IndexInput centroidsInput,
172+
int numParentCentroids,
173+
int numCentroids,
174+
FieldInfo fieldInfo,
175+
float[] globalCentroid
176+
) {
172177
return new OffHeapCentroidSupplier(centroidsInput, numParentCentroids, numCentroids, fieldInfo);
173178
}
174179

175-
static void writeCentroidsAndPartitions(List<CentroidPartition> centroidPartitions, float[][] centroids,
176-
FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
177-
throws IOException {
180+
static void writeCentroidsAndPartitions(
181+
List<CentroidPartition> centroidPartitions,
182+
float[][] centroids,
183+
FieldInfo fieldInfo,
184+
float[] globalCentroid,
185+
IndexOutput centroidOutput
186+
) throws IOException {
178187
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
179188
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
180189
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
@@ -208,7 +217,7 @@ static void writeCentroidsAndPartitions(List<CentroidPartition> centroidPartitio
208217
writeQuantizedValue(centroidOutput, quantizedScratch, result);
209218
}
210219

211-
//write the raw float vectors so we can quantize the query vector relative to the centroid on read
220+
// write the raw float vectors so we can quantize the query vector relative to the centroid on read
212221
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
213222
for (float[] centroid : centroids) {
214223
buffer.asFloatBuffer().put(centroid);
@@ -281,41 +290,44 @@ CentroidAssignments calculateAndWriteCentroids(
281290
// TODO: sort while constructing the hkmeans structure
282291
// we do this so we don't have to sort the assignments which is much more expensive
283292
int[] centroidOrds = new int[centroids.length];
284-
for(int i = 0; i < centroidOrds.length; i++) {
293+
for (int i = 0; i < centroidOrds.length; i++) {
285294
centroidOrds[i] = i;
286295
}
287296

288-
// TODO: sort by global centroids as well
289-
// TODO: have this take a function instead of just an int[] for sorting
290-
AssignmentArraySorter sorter = new AssignmentArraySorter(centroids, centroidOrds, kMeansResult.parentLayer());
291-
sorter.sort(0, centroids.length);
292-
293297
List<CentroidPartition> centroidPartitions = new ArrayList<>();
294-
for(int i = 0; i < kMeansResult.parentLayer().length;) {
295-
// for any layer that was not partitioned we treat it duplicatively as a parent and child
296-
if(kMeansResult.parentLayer()[i] == -1) {
297-
centroidPartitions.add(new CentroidPartition(centroids[i], i, 1));
298-
i++;
299-
} else {
300-
int label = kMeansResult.parentLayer()[i];
301-
int totalCentroids = 0;
302-
float[] parentPartitionCentroid = new float[fieldInfo.getVectorDimension()];
303-
int j = i;
304-
for (; j < kMeansResult.parentLayer().length; j++) {
305-
if(kMeansResult.parentLayer()[j] != label) {
306-
break;
298+
299+
if (centroids.length > DEFAULT_VECTORS_PER_CLUSTER) {
300+
// TODO: sort by global centroids as well
301+
// TODO: have this take a function instead of just an int[] for sorting
302+
AssignmentArraySorter sorter = new AssignmentArraySorter(centroids, centroidOrds, kMeansResult.parentLayer());
303+
sorter.sort(0, centroids.length);
304+
305+
for (int i = 0; i < kMeansResult.parentLayer().length;) {
306+
// for any layer that was not partitioned we treat it duplicatively as a parent and child
307+
if (kMeansResult.parentLayer()[i] == -1) {
308+
centroidPartitions.add(new CentroidPartition(centroids[i], i, 1));
309+
i++;
310+
} else {
311+
int label = kMeansResult.parentLayer()[i];
312+
int totalCentroids = 0;
313+
float[] parentPartitionCentroid = new float[fieldInfo.getVectorDimension()];
314+
int j = i;
315+
for (; j < kMeansResult.parentLayer().length; j++) {
316+
if (kMeansResult.parentLayer()[j] != label) {
317+
break;
318+
}
319+
for (int k = 0; k < parentPartitionCentroid.length; k++) {
320+
parentPartitionCentroid[k] += centroids[i][k];
321+
}
322+
totalCentroids++;
307323
}
308-
for (int k = 0; k < parentPartitionCentroid.length; k++) {
309-
parentPartitionCentroid[k] += centroids[i][k];
324+
int childOrdinal = i;
325+
i = j;
326+
for (int d = 0; d < parentPartitionCentroid.length; d++) {
327+
parentPartitionCentroid[d] /= totalCentroids;
310328
}
311-
totalCentroids++;
312-
}
313-
int childOrdinal = i;
314-
i = j;
315-
for (int d = 0; d < parentPartitionCentroid.length; d++) {
316-
parentPartitionCentroid[d] /= totalCentroids;
329+
centroidPartitions.add(new CentroidPartition(parentPartitionCentroid, childOrdinal, totalCentroids));
317330
}
318-
centroidPartitions.add(new CentroidPartition(parentPartitionCentroid, childOrdinal, totalCentroids));
319331
}
320332
}
321333

@@ -330,7 +342,7 @@ CentroidAssignments calculateAndWriteCentroids(
330342
for (int c = 0; c < assignmentsByCluster.length; c++) {
331343
IntArrayList cluster = new IntArrayList(vectorPerCluster);
332344
for (int j = 0; j < assignments.length; j++) {
333-
if(assignments[j] == -1) {
345+
if (assignments[j] == -1) {
334346
continue;
335347
}
336348
if (assignments[j] == centroidOrds[c]) {
@@ -339,7 +351,7 @@ CentroidAssignments calculateAndWriteCentroids(
339351
}
340352

341353
for (int j = 0; j < soarAssignments.length; j++) {
342-
if(soarAssignments[j] == -1) {
354+
if (soarAssignments[j] == -1) {
343355
continue;
344356
}
345357
if (soarAssignments[j] == centroidOrds[c]) {

0 commit comments

Comments
 (0)