|
20 | 20 | import java.util.Set; |
21 | 21 |
|
22 | 22 | import static org.hamcrest.Matchers.containsString; |
23 | | -import static org.hamcrest.Matchers.greaterThan; |
| 23 | +import static org.hamcrest.Matchers.greaterThanOrEqualTo; |
24 | 24 |
|
25 | 25 | public class KMeansLocalTests extends ESTestCase { |
26 | 26 |
|
@@ -146,23 +146,23 @@ private static FloatVectorValues generateData(int nSamples, int nDims, int nClus |
146 | 146 | } |
147 | 147 |
|
148 | 148 | public void testComputeNeighbours() throws IOException { |
149 | | - int numCentroids = randomIntBetween(100, 10000); |
| 149 | + int numCentroids = randomIntBetween(100, 1000); |
150 | 150 | int dims = randomIntBetween(10, 200); |
151 | 151 | float[][] vectors = new float[numCentroids][dims]; |
152 | 152 | for (int i = 0; i < numCentroids; i++) { |
153 | 153 | for (int j = 0; j < dims; j++) { |
154 | 154 | vectors[i][j] = randomFloat(); |
155 | 155 | } |
156 | 156 | } |
157 | | - int clustersPerNeighbour = randomIntBetween(6, 32); |
158 | | - KMeansLocal.NeighborHood[] neighborHoodsGraph = KMeansLocal.computeNeighborhoodsGraph(vectors, clustersPerNeighbour); |
159 | | - KMeansLocal.NeighborHood[] neighborHoodsBruteForce = KMeansLocal.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour); |
| 157 | + int clustersPerNeighbour = randomIntBetween(32, 64); |
| 158 | + NeighborHood[] neighborHoodsGraph = NeighborHood.computeNeighborhoodsGraph(vectors, clustersPerNeighbour); |
| 159 | + NeighborHood[] neighborHoodsBruteForce = NeighborHood.computeNeighborhoodsBruteForce(vectors, clustersPerNeighbour); |
160 | 160 | assertEquals(neighborHoodsGraph.length, neighborHoodsBruteForce.length); |
161 | 161 | for (int i = 0; i < neighborHoodsGraph.length; i++) { |
162 | 162 | assertEquals(neighborHoodsBruteForce[i].neighbors().length, neighborHoodsGraph[i].neighbors().length); |
163 | 163 | int matched = compareNN(i, neighborHoodsBruteForce[i].neighbors(), neighborHoodsGraph[i].neighbors()); |
164 | 164 | double recall = (double) matched / neighborHoodsGraph[i].neighbors().length; |
165 | | - assertThat(recall, greaterThan(0.4)); |
| 165 | + assertThat(recall, greaterThanOrEqualTo(0.7)); |
166 | 166 | if (recall == 1.0) { |
167 | 167 | // we cannot assert on array equality as there can be small differences due to numerical errors |
168 | 168 | assertEquals(neighborHoodsBruteForce[i].maxIntraDistance(), neighborHoodsGraph[i].maxIntraDistance(), 1e-5f); |
|
0 commit comments