|
48 | 48 | import java.io.IOException; |
49 | 49 | import java.util.ArrayList; |
50 | 50 | import java.util.Arrays; |
| 51 | +import java.util.Collections; |
51 | 52 | import java.util.HashMap; |
52 | 53 | import java.util.List; |
53 | 54 | import java.util.Map; |
@@ -130,42 +131,54 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { |
130 | 131 | KnnCollectorManager knnCollectorManager = getKnnCollectorManager(numCands, indexSearcher); |
131 | 132 | TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); |
132 | 133 | List<LeafReaderContext> leafReaderContexts = reader.leaves(); |
| 134 | + List<Callable<TopDocs>> tasks; |
| 135 | + if (leafReaderContexts.isEmpty() == false) { |
| 136 | + |
| 137 | + // calculate the affinity of each segment to the query vector |
| 138 | + // (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.) |
| 139 | + List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector()); |
| 140 | + |
| 141 | + // TODO: sort segments by affinity score in descending order, and cut the long tail? |
| 142 | + double[] affinityScores = segmentAffinities.stream().mapToDouble(SegmentAffinity::affinityScore).toArray(); |
| 143 | + |
| 144 | + // average and max affinity across all segments, used to derive the probe thresholds below |
| 145 | + double averageAffinity = Arrays.stream(affinityScores).average().orElse(Double.NaN); |
| 146 | + double maxAffinity = Arrays.stream(affinityScores).max().orElse(Double.NaN); |
| 147 | + double lowerAffinity = (maxAffinity + averageAffinity) * 0.5; |
| 148 | + double cutoffAffinity = lowerAffinity * 0.5; // minimum affinity score for a segment to be considered |
| 149 | + double affinityThreshold = (maxAffinity + lowerAffinity) * 0.66; // min affinity for increasing nProbe |
| 150 | + int maxAdjustments = (int) (nProbe * 1.5); |
| 151 | + |
| 152 | + if (Double.isNaN(maxAffinity) || Double.isNaN(averageAffinity)) { |
| 153 | + tasks = new ArrayList<>(leafReaderContexts.size()); |
| 154 | + for (LeafReaderContext context : leafReaderContexts) { |
| 155 | + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager, nProbe)); |
| 156 | + } |
| 157 | + } else { |
| 158 | + Map<LeafReaderContext, Integer> segmentNProbeMap = new HashMap<>(); |
| 159 | + // process segments based on their affinity scores |
| 160 | + for (SegmentAffinity affinity : segmentAffinities) { |
| 161 | + double score = affinity.affinityScore(); |
| 162 | + |
| 163 | + // skip segments with very low affinity |
| 164 | + if (score < cutoffAffinity) { |
| 165 | + continue; |
| 166 | + } |
133 | 167 |
|
134 | | - // calculate the affinity of each segment to the query vector |
135 | | - // (need information from each segment: no. of clusters, global centroid, density, parent centroids' scores, etc.) |
136 | | - List<SegmentAffinity> segmentAffinities = calculateSegmentAffinities(leafReaderContexts, getQueryVector()); |
137 | | - |
138 | | - // TODO: sort segments by affinity score in descending order, and cut the long tail ? |
139 | | - double[] affinityScores = segmentAffinities.stream().map(SegmentAffinity::affinityScore).mapToDouble(Double::doubleValue).toArray(); |
140 | | - |
141 | | - // max affinity for decreasing nProbe |
142 | | - double average = Arrays.stream(affinityScores).average().orElse(0.0); |
143 | | - double maxAffinity = Arrays.stream(affinityScores).max().orElse(0.0); |
144 | | - double lowerAffinity = (maxAffinity + average) * 0.5; |
145 | | - double cutoffAffinity = lowerAffinity * 0.5; // minimum affinity score for a segment to be considered |
146 | | - double affinityTreshold = (maxAffinity + lowerAffinity) * 0.66; // min affinity for increasing nProbe |
147 | | - int maxAdjustments = (int) (nProbe * 1.5); |
| 168 | + // adjust nProbe based on affinity score: higher affinity increases nProbe, lower affinity decreases it |
| 169 | + int adjustedNProbe = adjustNProbeForSegment(score, affinityThreshold, maxAdjustments); |
148 | 170 |
|
149 | | - Map<LeafReaderContext, Integer> segmentNProbeMap = new HashMap<>(); |
150 | | - // process segments based on their affinity scores |
151 | | - for (SegmentAffinity affinity : segmentAffinities) { |
152 | | - double score = affinity.affinityScore(); |
| 171 | + // store the adjusted nProbe value for this segment |
| 172 | + segmentNProbeMap.put(affinity.context(), adjustedNProbe); |
| 173 | + } |
153 | 174 |
|
154 | | - // skip segments with very low affinity |
155 | | - if (score < cutoffAffinity) { |
156 | | - continue; |
| 175 | + tasks = new ArrayList<>(segmentNProbeMap.size()); |
| 176 | + for (Map.Entry<LeafReaderContext, Integer> entry : segmentNProbeMap.entrySet()) { |
| 177 | + tasks.add(() -> searchLeaf(entry.getKey(), filterWeight, knnCollectorManager, entry.getValue())); |
| 178 | + } |
157 | 179 | } |
158 | | - |
159 | | - // sdjust nProbe based on affinity score, with larger affinity we increase nprobe (and viceversa) |
160 | | - int adjustedNProbe = adjustNProbeForSegment(score, affinityTreshold, maxAdjustments); |
161 | | - |
162 | | - // store the adjusted nProbe value for this segment |
163 | | - segmentNProbeMap.put(affinity.context(), adjustedNProbe); |
164 | | - } |
165 | | - |
166 | | - List<Callable<TopDocs>> tasks = new ArrayList<>(segmentNProbeMap.size()); |
167 | | - for (Map.Entry<LeafReaderContext, Integer> entry : segmentNProbeMap.entrySet()) { |
168 | | - tasks.add(() -> searchLeaf(entry.getKey(), filterWeight, knnCollectorManager, entry.getValue())); |
| 180 | + } else { |
| 181 | + tasks = Collections.emptyList(); |
169 | 182 | } |
170 | 183 | TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); |
171 | 184 |
|
|