10
10
package org .elasticsearch .index .codec .vectors .cluster ;
11
11
12
12
import org .apache .lucene .index .FloatVectorValues ;
13
+ import org .apache .lucene .util .FixedBitSet ;
13
14
import org .apache .lucene .util .VectorUtil ;
14
15
import org .apache .lucene .util .hnsw .IntToIntFunction ;
15
16
import org .elasticsearch .index .codec .vectors .SampleReader ;
@@ -70,17 +71,14 @@ private static boolean stepLloyd(
70
71
FloatVectorValues vectors ,
71
72
IntToIntFunction translateOrd ,
72
73
float [][] centroids ,
73
- float [][] nextCentroids ,
74
+ FixedBitSet centroidChanged ,
75
+ int [] centroidCounts ,
74
76
int [] assignments ,
75
77
NeighborHood [] neighborhoods
76
78
) throws IOException {
77
79
boolean changed = false ;
78
80
int dim = vectors .dimension ();
79
- int [] centroidCounts = new int [centroids .length ];
80
-
81
- for (float [] nextCentroid : nextCentroids ) {
82
- Arrays .fill (nextCentroid , 0.0f );
83
- }
81
+ centroidChanged .clear ();
84
82
final float [] distances = new float [4 ];
85
83
for (int idx = 0 ; idx < vectors .size (); idx ++) {
86
84
float [] vector = vectors .vectorValue (idx );
@@ -93,20 +91,39 @@ private static boolean stepLloyd(
93
91
bestCentroidOffset = getBestCentroid (centroids , vector , distances );
94
92
}
95
93
if (assignment != bestCentroidOffset ) {
94
+ if (assignment != -1 ) {
95
+ centroidChanged .set (assignment );
96
+ }
97
+ centroidChanged .set (bestCentroidOffset );
96
98
assignments [vectorOrd ] = bestCentroidOffset ;
97
99
changed = true ;
98
100
}
99
- centroidCounts [bestCentroidOffset ]++;
100
- for (int d = 0 ; d < dim ; d ++) {
101
- nextCentroids [bestCentroidOffset ][d ] += vector [d ];
102
- }
103
101
}
102
+ if (changed ) {
103
+ Arrays .fill (centroidCounts , 0 );
104
+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
105
+ final int assignment = assignments [translateOrd .apply (idx )];
106
+ if (centroidChanged .get (assignment )) {
107
+ float [] centroid = centroids [assignment ];
108
+ if (centroidCounts [assignment ]++ == 0 ) {
109
+ Arrays .fill (centroid , 0.0f );
110
+ }
111
+ float [] vector = vectors .vectorValue (idx );
112
+ for (int d = 0 ; d < dim ; d ++) {
113
+ centroid [d ] += vector [d ];
114
+ }
115
+ }
116
+ }
104
117
105
- for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
106
- if (centroidCounts [clusterIdx ] > 0 ) {
107
- float countF = (float ) centroidCounts [clusterIdx ];
108
- for (int d = 0 ; d < dim ; d ++) {
109
- centroids [clusterIdx ][d ] = nextCentroids [clusterIdx ][d ] / countF ;
118
+ for (int clusterIdx = 0 ; clusterIdx < centroids .length ; clusterIdx ++) {
119
+ if (centroidChanged .get (clusterIdx )) {
120
+ float count = (float ) centroidCounts [clusterIdx ];
121
+ if (count > 0 ) {
122
+ float [] centroid = centroids [clusterIdx ];
123
+ for (int d = 0 ; d < dim ; d ++) {
124
+ centroid [d ] /= count ;
125
+ }
126
+ }
110
127
}
111
128
}
112
129
}
@@ -420,17 +437,18 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
420
437
}
421
438
422
439
assert assignments .length == n ;
423
- float [][] nextCentroids = new float [centroids .length ][vectors .dimension ()];
440
+ FixedBitSet centroidChanged = new FixedBitSet (centroids .length );
441
+ int [] centroidCounts = new int [centroids .length ];
424
442
for (int i = 0 ; i < maxIterations ; i ++) {
425
443
// This is potentially sampled, so we need to translate ordinals
426
- if (stepLloyd (sampledVectors , translateOrd , centroids , nextCentroids , assignments , neighborhoods ) == false ) {
444
+ if (stepLloyd (sampledVectors , translateOrd , centroids , centroidChanged , centroidCounts , assignments , neighborhoods ) == false ) {
427
445
break ;
428
446
}
429
447
}
430
448
// If we were sampled, do a once over the full set of vectors to finalize the centroids
431
449
if (sampleSize < n || maxIterations == 0 ) {
432
450
// No ordinal translation needed here, we are using the full set of vectors
433
- stepLloyd (vectors , i -> i , centroids , nextCentroids , assignments , neighborhoods );
451
+ stepLloyd (vectors , i -> i , centroids , centroidChanged , centroidCounts , assignments , neighborhoods );
434
452
}
435
453
}
436
454
0 commit comments