Skip to content

Commit b1f9ae4

Browse files
committed
migrated from short to int and fixed IOUtils copy/paste errors
1 parent 5ca53d3 commit b1f9ae4

File tree

8 files changed

+60
-65
lines changed

8 files changed

+60
-65
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99

1010
package org.elasticsearch.index.codec.vectors;
1111

12-
record CentroidAssignments(int numCentroids, float[][] cachedCentroids, short[] assignments, short[] soarAssignments) {
12+
record CentroidAssignments(int numCentroids, float[][] cachedCentroids, int[] assignments, int[] soarAssignments) {
1313

14-
CentroidAssignments(float[][] centroids, short[] assignments, short[] soarAssignments) {
14+
CentroidAssignments(float[][] centroids, int[] assignments, int[] soarAssignments) {
1515
this(centroids.length, centroids, assignments, soarAssignments);
1616
}
1717

18-
CentroidAssignments(int numCentroids, short[] assignments, short[] soarAssignments) {
18+
CentroidAssignments(int numCentroids, int[] assignments, int[] soarAssignments) {
1919
this(numCentroids, null, assignments, soarAssignments);
2020
}
2121
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ long[] buildAndWritePostingsLists(
6464
BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer);
6565
DocIdsWriter docIdsWriter = new DocIdsWriter();
6666

67-
short[] assignments = centroidAssignments.assignments();
68-
short[] soarAssignments = centroidAssignments.soarAssignments();
67+
int[] assignments = centroidAssignments.assignments();
68+
int[] soarAssignments = centroidAssignments.soarAssignments();
6969

7070
int[][] clustersForMetrics = null;
7171
if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) {
@@ -298,8 +298,8 @@ CentroidAssignments calculateAndWriteCentroids(
298298
// TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
299299
KMeansResult kMeansResult = new HierarchicalKMeans().cluster(floatVectorValues, vectorPerCluster);
300300
float[][] centroids = kMeansResult.centroids();
301-
short[] assignments = kMeansResult.assignments();
302-
short[] soarAssignments = kMeansResult.soarAssignments();
301+
int[] assignments = kMeansResult.assignments();
302+
int[] soarAssignments = kMeansResult.soarAssignments();
303303

304304
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
305305
// preliminary tests suggest recall is good using only centroids but need to do further evaluation

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
292292
success = true;
293293
} finally {
294294
if (success == false && centroidTempName != null) {
295-
org.apache.lucene.util.IOUtils.closeWhileHandlingException(centroidTemp);
295+
IOUtils.closeWhileHandlingException(centroidTemp);
296296
org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName);
297297
}
298298
}
@@ -301,11 +301,11 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
301301
centroidOffset = ivfCentroids.getFilePointer();
302302
writeMeta(fieldInfo, centroidOffset, 0, new long[0], null);
303303
CodecUtil.writeFooter(centroidTemp);
304-
org.apache.lucene.util.IOUtils.close(centroidTemp);
304+
IOUtils.close(centroidTemp);
305305
return;
306306
}
307307
CodecUtil.writeFooter(centroidTemp);
308-
org.apache.lucene.util.IOUtils.close(centroidTemp);
308+
IOUtils.close(centroidTemp);
309309
centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES);
310310
try (IndexInput centroidsInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) {
311311
ivfCentroids.copyBytes(centroidsInput, centroidsInput.length() - CodecUtil.footerLength());

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,13 @@ public class HierarchicalKMeans {
2525

2626
final int maxIterations;
2727
final int samplesPerCluster;
28-
final short clustersPerNeighborhood;
28+
final int clustersPerNeighborhood;
2929

3030
public HierarchicalKMeans() {
31-
this(MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, (short) MAXK);
31+
this(MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK);
3232
}
3333

34-
HierarchicalKMeans(int maxIterations, int samplesPerCluster, short clustersPerNeighborhood) {
34+
HierarchicalKMeans(int maxIterations, int samplesPerCluster, int clustersPerNeighborhood) {
3535
this.maxIterations = maxIterations;
3636
this.samplesPerCluster = samplesPerCluster;
3737
this.clustersPerNeighborhood = clustersPerNeighborhood;
@@ -56,7 +56,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
5656
if (vectors.size() < targetSize) {
5757
float[] centroid = new float[vectors.dimension()];
5858
System.arraycopy(vectors.vectorValue(0), 0, centroid, 0, vectors.dimension());
59-
return new KMeansResult(new float[][] { centroid }, new short[vectors.size()]);
59+
return new KMeansResult(new float[][] { centroid }, new int[vectors.size()]);
6060
}
6161

6262
// partition the space
@@ -80,7 +80,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
8080
int m = Math.min(k * samplesPerCluster, vectors.size());
8181

8282
// TODO: instead of creating a sub-cluster assignments reuse the parent array each time
83-
short[] assignments = new short[vectors.size()];
83+
int[] assignments = new int[vectors.size()];
8484

8585
KMeans kmeans = new KMeans(m, maxIterations);
8686
float[][] centroids = KMeans.pickInitialCentroids(vectors, m, k);
@@ -95,9 +95,9 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
9595
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
9696
for (int i = 0; i < vectors.size(); i++) {
9797
float smallest = Float.MAX_VALUE;
98-
short centroidIdx = -1;
98+
int centroidIdx = -1;
9999
float[] vector = vectors.vectorValue(i);
100-
for (short j = 0; j < centroids.length; j++) {
100+
for (int j = 0; j < centroids.length; j++) {
101101
float[] centroid = centroids[j];
102102
float d = VectorUtil.squareDistance(vector, centroid);
103103
if (d < smallest) {
@@ -122,7 +122,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
122122
}
123123
}
124124

125-
short effectiveK = 0;
125+
int effectiveK = 0;
126126
for (int i = 0; i < clusterSizes.length; i++) {
127127
if (clusterSizes[i] > 0) {
128128
effectiveK++;
@@ -138,7 +138,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
138138
return kMeansResult;
139139
}
140140

141-
for (short c = 0; c < clusterSizes.length; c++) {
141+
for (int c = 0; c < clusterSizes.length; c++) {
142142
// Recurse for each cluster which is larger than targetSize
143143
// Give ourselves 30% margin for the target size
144144
if (100 * clusterSizes[c] > 134 * targetSize) {
@@ -152,7 +152,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
152152
return kMeansResult;
153153
}
154154

155-
static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, FloatVectorValuesSlice vectors, short[] assignments) {
155+
static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, FloatVectorValuesSlice vectors, int[] assignments) {
156156
int[] slice = new int[clusterSize];
157157
int idx = 0;
158158
for (int i = 0; i < assignments.length; i++) {
@@ -165,7 +165,7 @@ static FloatVectorValuesSlice createClusterSlice(int clusterSize, int cluster, F
165165
return new FloatVectorValuesSlice(vectors, slice);
166166
}
167167

168-
static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short cluster, KMeansResult splitClusters) {
168+
static void updateAssignmentsWithRecursiveSplit(KMeansResult current, int cluster, KMeansResult splitClusters) {
169169
int orgCentroidsSize = current.centroids().length;
170170

171171
// update based on the outcomes from the split clusters recursion
@@ -175,7 +175,7 @@ static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short clus
175175
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
176176

177177
// replace the original cluster
178-
short origCentroidOrd = 0;
178+
int origCentroidOrd = 0;
179179
newCentroids[cluster] = splitClusters.centroids()[0];
180180

181181
// append the remainder
@@ -188,7 +188,7 @@ static void updateAssignmentsWithRecursiveSplit(KMeansResult current, short clus
188188
if (splitClusters.assignments()[i] != origCentroidOrd) {
189189
int parentOrd = splitClusters.assignmentOrds()[i];
190190
assert current.assignments()[parentOrd] == cluster;
191-
current.assignments()[parentOrd] = (short) (splitClusters.assignments()[i] + orgCentroidsSize - 1);
191+
current.assignments()[parentOrd] = splitClusters.assignments()[i] + orgCentroidsSize - 1;
192192
}
193193
}
194194
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeans.java

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,16 @@ public static float[][] pickInitialCentroids(FloatVectorValues vectors, int samp
7070
return centroids;
7171
}
7272

73-
private boolean stepLloyd(
74-
FloatVectorValues vectors,
75-
float[][] centroids,
76-
short[] assignments,
77-
int sampleSize,
78-
ClusteringAugment augment
79-
) throws IOException {
73+
private boolean stepLloyd(FloatVectorValues vectors, float[][] centroids, int[] assignments, int sampleSize, ClusteringAugment augment)
74+
throws IOException {
8075
boolean changed = false;
8176
int dim = vectors.dimension();
8277
long[] centroidCounts = new long[centroids.length];
8378
float[][] nextCentroids = new float[centroids.length][dim];
8479

8580
for (int i = 0; i < sampleSize; i++) {
8681
float[] vector = vectors.vectorValue(i);
87-
short bestCentroidOffset = getBestCentroidOffset(centroids, vector, i, augment);
82+
int bestCentroidOffset = getBestCentroidOffset(centroids, vector, i, augment);
8883
if (assignments[i] != bestCentroidOffset) {
8984
changed = true;
9085
}
@@ -98,7 +93,7 @@ private boolean stepLloyd(
9893
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
9994
if (centroidCounts[clusterIdx] > 0) {
10095
float countF = (float) centroidCounts[clusterIdx];
101-
for (int d = 0; d < dim; d++) {
96+
for (int d = 0; d < dim; d++) { // int counter: dim is an int, a short index would wrap negative past Short.MAX_VALUE
10297
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
10398
}
10499
}
@@ -107,10 +102,10 @@ private boolean stepLloyd(
107102
return changed;
108103
}
109104

110-
short getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
111-
short bestCentroidOffset = -1;
105+
int getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
106+
int bestCentroidOffset = -1;
112107
float minDsq = Float.MAX_VALUE;
113-
for (short j = 0; j < centroids.length; j++) {
108+
for (int j = 0; j < centroids.length; j++) {
114109
float dsq = VectorUtil.squareDistance(vector, centroids[j]);
115110
if (dsq < minDsq) {
116111
minDsq = dsq;
@@ -142,7 +137,7 @@ void cluster(FloatVectorValues vectors, KMeansResult kMeansResult, ClusteringAug
142137
return;
143138
}
144139

145-
short[] assignments = new short[n];
140+
int[] assignments = new int[n];
146141
for (int i = 0; i < maxIterations; i++) {
147142
if (stepLloyd(vectors, centroids, assignments, sampleSize, augment) == false) {
148143
break;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
*/
2626
class KMeansLocal extends KMeans {
2727

28-
final short clustersPerNeighborhood;
28+
final int clustersPerNeighborhood;
2929

30-
KMeansLocal(int sampleSize, int maxIterations, short clustersPerNeighborhood) {
30+
KMeansLocal(int sampleSize, int maxIterations, int clustersPerNeighborhood) {
3131
super(sampleSize, maxIterations);
3232
this.clustersPerNeighborhood = clustersPerNeighborhood;
3333
}
@@ -66,27 +66,27 @@ private void computeNeighborhoods(
6666
}
6767

6868
@Override
69-
short getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
69+
int getBestCentroidOffset(float[][] centroids, float[] vector, int vectorIdx, ClusteringAugment augment) {
7070
assert augment instanceof NeighborsClusteringAugment;
7171

72-
short centroidIdx = ((NeighborsClusteringAugment) augment).getCentroidIdx(vectorIdx);
72+
int centroidIdx = ((NeighborsClusteringAugment) augment).getCentroidIdx(vectorIdx);
7373
List<int[]> neighborhoods = ((NeighborsClusteringAugment) augment).neighborhoods;
7474

75-
short bestCentroidOffset = centroidIdx;
75+
int bestCentroidOffset = centroidIdx;
7676
float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
7777

7878
int[] neighborOffsets = neighborhoods.get(centroidIdx);
7979
for (int neighborOffset : neighborOffsets) {
8080
float dsq = VectorUtil.squareDistance(vector, centroids[neighborOffset]);
8181
if (dsq < minDsq) {
8282
minDsq = dsq;
83-
bestCentroidOffset = (short) neighborOffset;
83+
bestCentroidOffset = neighborOffset;
8484
}
8585
}
8686
return bestCentroidOffset;
8787
}
8888

89-
private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, short[] assignments)
89+
private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods, float[][] centroids, int[] assignments)
9090
throws IOException {
9191
// SOAR uses an adjusted distance for assigning spilled documents which is
9292
// given by:
@@ -97,15 +97,15 @@ private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoo
9797
// centroid the document was assigned to. The document is assigned to the
9898
// cluster with the smallest soar(x, c).
9999

100-
short[] spilledAssignments = new short[assignments.length];
100+
int[] spilledAssignments = new int[assignments.length];
101101

102102
float[] diffs = new float[vectors.dimension()];
103103
for (int i = 0; i < vectors.size(); i++) {
104104
float[] vector = vectors.vectorValue(i);
105105

106-
short currAssignment = assignments[i];
106+
int currAssignment = assignments[i];
107107
float[] currentCentroid = centroids[currAssignment];
108-
for (int j = 0; j < vectors.dimension(); j++) {
108+
for (int j = 0; j < vectors.dimension(); j++) { // int counter: dimension() returns int, a short index would wrap negative past Short.MAX_VALUE
109109
float diff = vector[j] - currentCentroid[j];
110110
diffs[j] = diff;
111111
}
@@ -128,7 +128,7 @@ private short[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoo
128128
}
129129
}
130130

131-
spilledAssignments[i] = (short) bestAssignment;
131+
spilledAssignments[i] = bestAssignment;
132132
}
133133

134134
return spilledAssignments;
@@ -156,7 +156,7 @@ private float distanceSoar(float[] residual, float[] vector, float[] centroid, f
156156
@Override
157157
void cluster(FloatVectorValues vectors, KMeansResult kMeansResult) throws IOException {
158158
float[][] centroids = kMeansResult.centroids();
159-
short[] assignments = kMeansResult.assignments();
159+
int[] assignments = kMeansResult.assignments();
160160

161161
assert assignments != null;
162162
assert assignments.length == vectors.size();
@@ -174,14 +174,14 @@ void cluster(FloatVectorValues vectors, KMeansResult kMeansResult) throws IOExce
174174

175175
static class NeighborsClusteringAugment extends ClusteringAugment {
176176
final List<int[]> neighborhoods;
177-
final short[] assignments;
177+
final int[] assignments;
178178

179-
NeighborsClusteringAugment(short[] assignments, List<int[]> neighborhoods) {
179+
NeighborsClusteringAugment(int[] assignments, List<int[]> neighborhoods) {
180180
this.neighborhoods = neighborhoods;
181181
this.assignments = assignments;
182182
}
183183

184-
public short getCentroidIdx(int vectorIdx) {
184+
public int getCentroidIdx(int vectorIdx) {
185185
return this.assignments[vectorIdx];
186186
}
187187
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansResult.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
*/
1515
public class KMeansResult {
1616
private float[][] centroids;
17-
private final short[] assignments;
17+
private final int[] assignments;
1818
private final int[] assignmentOrds;
19-
private short[] soarAssignments;
19+
private int[] soarAssignments;
2020

21-
KMeansResult(float[][] centroids, short[] assignments, int[] assignmentOrds, short[] soarAssignments) {
21+
KMeansResult(float[][] centroids, int[] assignments, int[] assignmentOrds, int[] soarAssignments) {
2222
assert centroids != null;
2323
assert assignments != null;
2424
assert assignmentOrds != null;
@@ -29,20 +29,20 @@ public class KMeansResult {
2929
this.soarAssignments = soarAssignments;
3030
}
3131

32-
KMeansResult(float[][] centroids, short[] assignments, int[] assignmentOrdinals) {
33-
this(centroids, assignments, assignmentOrdinals, new short[0]);
32+
KMeansResult(float[][] centroids, int[] assignments, int[] assignmentOrdinals) {
33+
this(centroids, assignments, assignmentOrdinals, new int[0]);
3434
}
3535

3636
KMeansResult() {
37-
this(new float[0][0], new short[0], new int[0], new short[0]);
37+
this(new float[0][0], new int[0], new int[0], new int[0]);
3838
}
3939

4040
KMeansResult(float[][] centroids) {
41-
this(centroids, new short[0], new int[0], new short[0]);
41+
this(centroids, new int[0], new int[0], new int[0]);
4242
}
4343

44-
KMeansResult(float[][] centroids, short[] assignments) {
45-
this(centroids, assignments, new int[0], new short[0]);
44+
KMeansResult(float[][] centroids, int[] assignments) {
45+
this(centroids, assignments, new int[0], new int[0]);
4646
}
4747

4848
public float[][] centroids() {
@@ -53,19 +53,19 @@ public void setCentroids(float[][] centroids) {
5353
this.centroids = centroids;
5454
}
5555

56-
public short[] assignments() {
56+
public int[] assignments() {
5757
return assignments;
5858
}
5959

6060
public int[] assignmentOrds() {
6161
return assignmentOrds;
6262
}
6363

64-
public short[] soarAssignments() {
64+
public int[] soarAssignments() {
6565
return soarAssignments;
6666
}
6767

68-
public void setSoarAssignments(short[] soarAssignments) {
68+
public void setSoarAssignments(int[] soarAssignments) {
6969
this.soarAssignments = soarAssignments;
7070
}
7171

0 commit comments

Comments
 (0)