|
11 | 11 |
|
12 | 12 | import org.apache.lucene.index.VectorSimilarityFunction; |
13 | 13 | import org.apache.lucene.search.KnnCollector; |
14 | | -import org.apache.lucene.search.ScoreDoc; |
15 | | -import org.apache.lucene.util.Bits; |
| 14 | +import org.apache.lucene.search.TopDocs; |
| 15 | +import org.apache.lucene.search.knn.KnnSearchStrategy; |
16 | 16 | import org.apache.lucene.util.VectorUtil; |
17 | 17 | import org.apache.lucene.util.hnsw.HnswGraphBuilder; |
18 | 18 | import org.apache.lucene.util.hnsw.HnswGraphSearcher; |
|
25 | 25 |
|
26 | 26 | public record NeighborHood(int[] neighbors, float maxIntraDistance) { |
27 | 27 |
|
| 28 | + private static final int M = 8; |
| 29 | + private static final int EF_CONSTRUCTION = 150; |
| 30 | + |
28 | 31 | static final NeighborHood EMPTY = new NeighborHood(new int[0], Float.POSITIVE_INFINITY); |
29 | 32 |
|
| 33 | + public static NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNeighborhood) throws IOException { |
| 34 | + assert centers.length > clustersPerNeighborhood; |
| 35 | + // experiments show that below 15k centers brute force is faster; above that, HNSW gives us a nice speed-up |
| 36 | + if (centers.length < 15_000) { |
| 37 | + return computeNeighborhoodsBruteForce(centers, clustersPerNeighborhood); |
| 38 | + } else { |
| 39 | + return computeNeighborhoodsGraph(centers, clustersPerNeighborhood); |
| 40 | + } |
| 41 | + } |
| 42 | + |
| 43 | + public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { |
| 44 | + int k = centers.length; |
| 45 | + NeighborQueue[] neighborQueues = new NeighborQueue[k]; |
| 46 | + for (int i = 0; i < k; i++) { |
| 47 | + neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); |
| 48 | + } |
| 49 | + final float[] scores = new float[4]; |
| 50 | + final int limit = k - 3; |
| 51 | + for (int i = 0; i < k - 1; i++) { |
| 52 | + float[] center = centers[i]; |
| 53 | + int j = i + 1; |
| 54 | + for (; j < limit; j += 4) { |
| 55 | + ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); |
| 56 | + for (int h = 0; h < 4; h++) { |
| 57 | + neighborQueues[j + h].insertWithOverflow(i, scores[h]); |
| 58 | + neighborQueues[i].insertWithOverflow(j + h, scores[h]); |
| 59 | + } |
| 60 | + } |
| 61 | + for (; j < k; j++) { |
| 62 | + float dsq = VectorUtil.squareDistance(center, centers[j]); |
| 63 | + neighborQueues[j].insertWithOverflow(i, dsq); |
| 64 | + neighborQueues[i].insertWithOverflow(j, dsq); |
| 65 | + } |
| 66 | + } |
| 67 | + |
| 68 | + NeighborHood[] neighborhoods = new NeighborHood[k]; |
| 69 | + for (int i = 0; i < k; i++) { |
| 70 | + NeighborQueue queue = neighborQueues[i]; |
| 71 | + if (queue.size() == 0) { |
| 72 | + // no neighbors, skip |
| 73 | + neighborhoods[i] = NeighborHood.EMPTY; |
| 74 | + continue; |
| 75 | + } |
| 76 | + // consume the queue into the neighbors array and get the maximum intra-cluster distance |
| 77 | + int[] neighbors = new int[queue.size()]; |
| 78 | + float maxIntraDistance = queue.topScore(); |
| 79 | + int iter = 0; |
| 80 | + while (queue.size() > 0) { |
| 81 | + neighbors[neighbors.length - ++iter] = queue.pop(); |
| 82 | + } |
| 83 | + neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); |
| 84 | + } |
| 85 | + return neighborhoods; |
| 86 | + } |
| 87 | + |
30 | 88 | public static NeighborHood[] computeNeighborhoodsGraph(float[][] centers, int clustersPerNeighborhood) throws IOException { |
31 | 89 | final UpdateableRandomVectorScorer scorer = new UpdateableRandomVectorScorer() { |
32 | 90 | int scoringOrdinal; |
@@ -57,92 +115,97 @@ public RandomVectorScorerSupplier copy() { |
57 | 115 | return this; |
58 | 116 | } |
59 | 117 | }; |
60 | | - final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, 16, 100, 42L).build(centers.length); |
| 118 | + final OnHeapHnswGraph graph = HnswGraphBuilder.create(supplier, M, EF_CONSTRUCTION, 42L).build(centers.length); |
61 | 119 | final NeighborHood[] neighborhoods = new NeighborHood[centers.length]; |
62 | | - final SingleBit singleBit = new SingleBit(centers.length); |
| 120 | + // oversample the number of neighbors we collect to improve recall |
| 121 | + final ReusableKnnCollector collector = new ReusableKnnCollector(2 * clustersPerNeighborhood); |
63 | 122 | for (int i = 0; i < centers.length; i++) { |
| 123 | + collector.reset(i); |
64 | 124 | scorer.setScoringOrdinal(i); |
65 | | - singleBit.indexSet = i; |
66 | | - final KnnCollector collector = HnswGraphSearcher.search(scorer, clustersPerNeighborhood, graph, singleBit, Integer.MAX_VALUE); |
67 | | - final ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs; |
68 | | - if (scoreDocs.length == 0) { |
| 125 | + HnswGraphSearcher.search(scorer, collector, graph, null); |
| 126 | + NeighborQueue queue = collector.queue; |
| 127 | + if (queue.size() == 0) { |
69 | 128 | // no neighbors, skip |
70 | 129 | neighborhoods[i] = NeighborHood.EMPTY; |
71 | 130 | continue; |
72 | 131 | } |
73 | | - final int[] neighbors = new int[scoreDocs.length]; |
74 | | - for (int j = 0; j < neighbors.length; j++) { |
75 | | - neighbors[j] = scoreDocs[j].doc; |
76 | | - assert neighbors[j] != i; |
| 132 | + while (queue.size() > clustersPerNeighborhood) { |
| 133 | + queue.pop(); |
| 134 | + } |
| 135 | + final float minScore = queue.topScore(); |
| 136 | + final int[] neighbors = new int[queue.size()]; |
| 137 | + for (int j = 1; j <= neighbors.length; j++) { |
| 138 | + neighbors[neighbors.length - j] = queue.pop(); |
77 | 139 | } |
78 | | - final float minCompetitiveSimilarity = (1f / scoreDocs[neighbors.length - 1].score) - 1; |
79 | | - neighborhoods[i] = new NeighborHood(neighbors, minCompetitiveSimilarity); |
| 140 | + neighborhoods[i] = new NeighborHood(neighbors, (1f / minScore) - 1); |
80 | 141 | } |
81 | 142 | return neighborhoods; |
82 | 143 | } |
83 | 144 |
|
84 | | - private static class SingleBit implements Bits { |
| 145 | + private static class ReusableKnnCollector implements KnnCollector { |
85 | 146 |
|
86 | | - private final int length; |
87 | | - private int indexSet; |
| 147 | + private final NeighborQueue queue; |
| 148 | + private final int k; |
| 149 | + int visitedCount; |
| 150 | + int currenOrd; |
88 | 151 |
|
89 | | - SingleBit(int length) { |
90 | | - this.length = length; |
| 152 | + ReusableKnnCollector(int k) { |
| 153 | + this.k = k; |
| 154 | + this.queue = new NeighborQueue(k, false); |
| 155 | + } |
| 156 | + |
| 157 | + void reset(int ord) { |
| 158 | + queue.clear(); |
| 159 | + visitedCount = 0; |
| 160 | + currenOrd = ord; |
91 | 161 | } |
92 | 162 |
|
93 | 163 | @Override |
94 | | - public boolean get(int index) { |
95 | | - return index != indexSet; |
| 164 | + public boolean earlyTerminated() { |
| 165 | + return false; |
96 | 166 | } |
97 | 167 |
|
98 | 168 | @Override |
99 | | - public int length() { |
100 | | - return length; |
| 169 | + public void incVisitedCount(int count) { |
| 170 | + visitedCount += count; |
101 | 171 | } |
102 | | - } |
103 | 172 |
|
104 | | - public static NeighborHood[] computeNeighborhoodsBruteForce(float[][] centers, int clustersPerNeighborhood) { |
105 | | - int k = centers.length; |
106 | | - NeighborQueue[] neighborQueues = new NeighborQueue[k]; |
107 | | - for (int i = 0; i < k; i++) { |
108 | | - neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); |
| 173 | + @Override |
| 174 | + public long visitedCount() { |
| 175 | + return visitedCount; |
109 | 176 | } |
110 | | - final float[] scores = new float[4]; |
111 | | - final int limit = k - 3; |
112 | | - for (int i = 0; i < k - 1; i++) { |
113 | | - float[] center = centers[i]; |
114 | | - int j = i + 1; |
115 | | - for (; j < limit; j += 4) { |
116 | | - ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); |
117 | | - for (int h = 0; h < 4; h++) { |
118 | | - neighborQueues[j + h].insertWithOverflow(i, scores[h]); |
119 | | - neighborQueues[i].insertWithOverflow(j + h, scores[h]); |
120 | | - } |
121 | | - } |
122 | | - for (; j < k; j++) { |
123 | | - float dsq = VectorUtil.squareDistance(center, centers[j]); |
124 | | - neighborQueues[j].insertWithOverflow(i, dsq); |
125 | | - neighborQueues[i].insertWithOverflow(j, dsq); |
126 | | - } |
| 177 | + |
| 178 | + @Override |
| 179 | + public long visitLimit() { |
| 180 | + return Integer.MAX_VALUE; |
127 | 181 | } |
128 | 182 |
|
129 | | - NeighborHood[] neighborhoods = new NeighborHood[k]; |
130 | | - for (int i = 0; i < k; i++) { |
131 | | - NeighborQueue queue = neighborQueues[i]; |
132 | | - if (queue.size() == 0) { |
133 | | - // no neighbors, skip |
134 | | - neighborhoods[i] = NeighborHood.EMPTY; |
135 | | - continue; |
136 | | - } |
137 | | - // consume the queue into the neighbors array and get the maximum intra-cluster distance |
138 | | - int[] neighbors = new int[queue.size()]; |
139 | | - float maxIntraDistance = queue.topScore(); |
140 | | - int iter = 0; |
141 | | - while (queue.size() > 0) { |
142 | | - neighbors[neighbors.length - ++iter] = queue.pop(); |
| 183 | + @Override |
| 184 | + public int k() { |
| 185 | + return k; |
| 186 | + } |
| 187 | + |
| 188 | + @Override |
| 189 | + public boolean collect(int docId, float similarity) { |
| 190 | + if (currenOrd != docId) { |
| 191 | + return queue.insertWithOverflow(docId, similarity); |
143 | 192 | } |
144 | | - neighborhoods[i] = new NeighborHood(neighbors, maxIntraDistance); |
| 193 | + return false; |
| 194 | + } |
| 195 | + |
| 196 | + @Override |
| 197 | + public float minCompetitiveSimilarity() { |
| 198 | + return queue.size() >= k() ? queue.topScore() : Float.NEGATIVE_INFINITY; |
| 199 | + } |
| 200 | + |
| 201 | + @Override |
| 202 | + public TopDocs topDocs() { |
| 203 | + throw new UnsupportedOperationException(); |
| 204 | + } |
| 205 | + |
| 206 | + @Override |
| 207 | + public KnnSearchStrategy getSearchStrategy() { |
| 208 | + return null; |
145 | 209 | } |
146 | | - return neighborhoods; |
147 | 210 | } |
148 | 211 | } |
0 commit comments