Skip to content

Commit 9a2c5b7

Browse files
committed
iter
1 parent ee61553 commit 9a2c5b7

File tree

1 file changed

+10
-10
lines changed
  • server/src/main/java/org/elasticsearch/index/codec/vectors/cluster

1 file changed

+10
-10
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,15 @@ private static boolean stepLloyd(
100100
}
101101
}
102102
if (changed) {
103-
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
104-
if (centroidChanged.get(clusterIdx)) {
105-
Arrays.fill(centroids[clusterIdx], 0.0f);
106-
}
107-
}
108103
Arrays.fill(centroidCounts, 0);
109104
for (int idx = 0; idx < vectors.size(); idx++) {
110105
final int assignment = assignments[translateOrd.apply(idx)];
111106
if (centroidChanged.get(assignment)) {
112-
centroidCounts[assignment]++;
113-
float[] vector = vectors.vectorValue(idx);
114107
float[] centroid = centroids[assignment];
108+
if (centroidCounts[assignment]++ == 0) {
109+
Arrays.fill(centroid, 0.0f);
110+
}
111+
float[] vector = vectors.vectorValue(idx);
115112
for (int d = 0; d < dim; d++) {
116113
centroid[d] += vector[d];
117114
}
@@ -120,9 +117,12 @@ private static boolean stepLloyd(
120117

121118
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
122119
if (centroidChanged.get(clusterIdx)) {
123-
float countF = (float) centroidCounts[clusterIdx];
124-
for (int d = 0; d < dim; d++) {
125-
centroids[clusterIdx][d] /= countF;
120+
float count = (float) centroidCounts[clusterIdx];
121+
if (count > 0) {
122+
float[] centroid = centroids[clusterIdx];
123+
for (int d = 0; d < dim; d++) {
124+
centroid[d] /= count;
125+
}
126126
}
127127
}
128128
}

0 commit comments

Comments
 (0)