Skip to content

Commit 5910259

Browse files
committed
iterated on better mechanism for utilizing parent centroids
1 parent 2f5bfd3 commit 5910259

File tree

2 files changed

+46
-48
lines changed

2 files changed

+46
-48
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -126,18 +126,12 @@ public float[] centroid(int centroidOrdinal) throws IOException {
126126
return centroid;
127127
}
128128

129-
@Override
130-
public void bulkScore(NeighborQueue queue) throws IOException {
131-
// TODO: bulk score centroids like we do with posting lists
132-
centroids.seek(quantizedCentroidsOffset);
133-
for (int i = 0; i < numCentroids; i++) {
134-
queue.add(i, score());
135-
}
136-
}
137-
138129
@Override
139130
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
140131
// TODO: bulk score centroids like we do with posting lists
132+
assert start > 0;
133+
assert end > 0;
134+
assert start + end <= numCentroids;
141135
centroids.seek(quantizedCentroidsOffset + quantizedVectorByteSize * start);
142136
for (int i = start; i < end; i++) {
143137
queue.add(i, score());
@@ -222,17 +216,11 @@ public int getChildCount(int centroidOrdinal) throws IOException {
222216
return childCount;
223217
}
224218

225-
@Override
226-
public void bulkScore(NeighborQueue queue) throws IOException {
227-
// TODO: bulk score centroids like we do with posting lists
228-
centroids.seek(0L);
229-
for (int i = 0; i < numParentCentroids; i++) {
230-
queue.add(i, score());
231-
}
232-
}
233-
234219
@Override
235220
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
221+
assert start > 0;
222+
assert end > 0;
223+
assert start + end <= numParentCentroids;
236224
// TODO: bulk score centroids like we do with posting lists
237225
centroids.seek(parentNodeByteSize * start);
238226
for (int i = start; i < end; i++) {

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -290,45 +290,35 @@ public final void search(String field, float[] target, KnnCollector knnCollector
290290
parentCentroidQueue.add(-1, 0.f);
291291
}
292292

293-
int bulkParentThreshold = (int) Math.ceil(parentCentroidQueue.size() * 0.5);
294-
295293
while (parentCentroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
296294
NeighborQueue centroidQueue = new NeighborQueue(centroidQueryScorer.size(), true);
297-
int parentsToExplore = 0;
298-
while (parentCentroidQueue.size() > 0 && parentsToExplore < bulkParentThreshold) {
299-
int parentCentroidOrdinal = parentCentroidQueue.pop();
300-
301-
int childCentroidOrdinal;
302-
int childCentroidCount;
303-
if (parentCentroidOrdinal == -1) {
304-
// score all centroids
305-
childCentroidOrdinal = 0;
306-
childCentroidCount = centroidQueryScorer.size();
307-
} else {
308-
childCentroidOrdinal = parentCentroidQueryScorer.getChildCentroidStart(parentCentroidOrdinal);
309-
childCentroidCount = parentCentroidQueryScorer.getChildCount(parentCentroidOrdinal);
310-
}
311-
// FIXME: modify scorePostingLists to take a queue instead of creating one
312-
centroidQueryScorer.bulkScore(centroidQueue, childCentroidOrdinal, childCentroidOrdinal + childCentroidCount);
313-
314-
if (parentCentroidOrdinal == -1) {
315-
break;
316-
}
317-
318-
parentsToExplore++;
319-
}
295+
updateCentroidQueueWNextParent(parentCentroidQueryScorer, parentCentroidQueue, centroidQueryScorer, centroidQueue);
320296

321297
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
322298
// initially we visit only the "centroids to search"
323299
// Note, numCollected is doing the bare minimum here.
324300
// TODO do we need to handle nested doc counts similarly to how we handle
325301
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
302+
float nextParentDistance = Float.MAX_VALUE;
303+
if (parentCentroidQueue.size() > 0) {
304+
nextParentDistance = parentCentroidQueue.topScore();
305+
}
326306
while (centroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
327307
++centroidsVisited;
328-
// todo do we actually need to know the score???
308+
float centroidDistance = centroidQueue.topScore();
309+
// the next parent likely contains centroids we need to evaluate prior to evaluating this next centroid
310+
while (parentCentroidQueue.size() > 0 && centroidDistance > nextParentDistance) {
311+
updateCentroidQueueWNextParent(parentCentroidQueryScorer, parentCentroidQueue, centroidQueryScorer, centroidQueue);
312+
if (parentCentroidQueue.size() > 0) {
313+
nextParentDistance = parentCentroidQueue.topScore();
314+
} else {
315+
nextParentDistance = Float.MAX_VALUE;
316+
}
317+
centroidDistance = centroidQueue.topScore();
318+
}
319+
329320
int centroidOrdinal = centroidQueue.pop();
330-
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
331-
// is enough?
321+
// TODO: do we need direct access to the raw centroid? It is used for quantizing; maybe hydrating and quantizing is enough.
332322
expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal));
333323
actualDocs += scorer.visit(knnCollector);
334324
}
@@ -346,6 +336,28 @@ public final void search(String field, float[] target, KnnCollector knnCollector
346336
}
347337
}
348338

339+
/**
 * Pops the top entry of {@code parentCentroidQueue} and bulk-scores the corresponding range of
 * centroids into {@code centroidQueue}.
 *
 * <p>If the popped ordinal is {@code -1}, every centroid known to {@code centroidQueryScorer}
 * is scored (range {@code [0, centroidQueryScorer.size())}). Otherwise the range is taken from
 * {@code parentCentroidQueryScorer}: it starts at {@code getChildCentroidStart(ordinal)} and
 * spans {@code getChildCount(ordinal)} centroids.
 *
 * <p>This method does not check for an empty queue before popping; the call sites visible in
 * this change guard the call with {@code parentCentroidQueue.size() > 0}.
 *
 * @param parentCentroidQueryScorer used to look up the child range of the popped parent ordinal
 * @param parentCentroidQueue queue of parent centroid ordinals; one entry is removed per call
 * @param centroidQueryScorer scores centroids in the selected range against the query
 * @param centroidQueue receives the scored centroids (passed straight to {@code bulkScore})
 * @throws IOException propagated from the scorer calls
 */
private static void updateCentroidQueueWNextParent(
340+
CentroidWChildrenQueryScorer parentCentroidQueryScorer,
341+
NeighborQueue parentCentroidQueue,
342+
CentroidQueryScorer centroidQueryScorer,
343+
NeighborQueue centroidQueue
344+
) throws IOException {
345+
int parentCentroidOrdinal = parentCentroidQueue.pop();
346+
347+
int childCentroidOrdinal;
348+
int childCentroidCount;
349+
if (parentCentroidOrdinal == -1) {
350+
// score all centroids
// NOTE(review): -1 looks like a sentinel meaning "no parent level"; the caller enqueues
// (-1, 0.f) in a branch whose condition is not visible here - confirm.
351+
childCentroidOrdinal = 0;
352+
childCentroidCount = centroidQueryScorer.size();
353+
} else {
354+
childCentroidOrdinal = parentCentroidQueryScorer.getChildCentroidStart(parentCentroidOrdinal);
355+
childCentroidCount = parentCentroidQueryScorer.getChildCount(parentCentroidOrdinal);
356+
}
357+
// TODO: add back scorePostingLists? seems like it's not doing anything at this point
// The third argument is an exclusive end index (start + count), not a count.
// NOTE(review): start is 0 in the -1 branch, but the bulkScore implementations in this
// commit assert "start > 0" and "start + end <= numCentroids" - confirm those asserts
// accept start == 0 and treat end as an index, otherwise this call trips them under -ea.
358+
centroidQueryScorer.bulkScore(centroidQueue, childCentroidOrdinal, childCentroidOrdinal + childCentroidCount);
359+
}
360+
349361
@Override
350362
public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException {
351363
final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field);
@@ -409,8 +421,6 @@ interface CentroidQueryScorer {
409421

410422
float[] centroid(int centroidOrdinal) throws IOException;
411423

412-
void bulkScore(NeighborQueue queue) throws IOException;
413-
414424
void bulkScore(NeighborQueue queue, int start, int end) throws IOException;
415425
}
416426

0 commit comments

Comments
 (0)