Skip to content

Commit 651e3bf

Browse files
committed
fix test
1 parent b153b8e commit 651e3bf

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

server/src/test/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocalTests.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import java.util.Set;
2121

2222
import static org.hamcrest.Matchers.containsString;
23-
import static org.hamcrest.Matchers.greaterThan;
23+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
2424

2525
public class KMeansLocalTests extends ESTestCase {
2626

@@ -146,23 +146,23 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus
146146
}
147147

148148
public void testComputeNeighbours() throws IOException {
149-
int numCentroids = randomIntBetween(100, 10000);
149+
int numCentroids = randomIntBetween(100, 1000);
150150
int dims = randomIntBetween(10, 200);
151151
float[][] vectors = new float[numCentroids][dims];
152152
for (int i = 0; i < numCentroids; i++) {
153153
for (int j = 0; j < dims; j++) {
154154
vectors[i][j] = randomFloat();
155155
}
156156
}
157-
int clustersPerNeighbour = randomIntBetween(6, 32);
158-
KMeansLocal.NeighborHood[] neighborHoodsGraph = KMeansLocal.computeNeighborhoodsGraph(vectors, clustersPerNeighbour);
159-
KMeansLocal.NeighborHood[] neighborHoodsBruteForce = KMeansLocal.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour);
157+
int clustersPerNeighbour = randomIntBetween(32, 64);
158+
NeighborHood[] neighborHoodsGraph = NeighborHood.computeNeighborhoodsGraph(vectors, clustersPerNeighbour);
159+
NeighborHood[] neighborHoodsBruteForce = NeighborHood.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour);
160160
assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length);
161161
for (int i = 0; i < neighborHoodsGraph.length; i++) {
162162
assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length);
163163
int matched = compareNN(i, neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors());
164164
double recall = (double) matched / neighborHoodsGraph[i].neighbors().length;
165-
assertThat(recall, greaterThan(0.4));
165+
assertThat(recall, greaterThanOrEqualTo(0.7));
166166
if (recall == 1.0) {
167167
// we cannot assert on array equality as there can be small differences due to numerical errors
168168
assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f);

0 commit comments

Comments
 (0)