Skip to content

Commit 28f611e

Browse files
committed
Merge remote-tracking branch 'elastic/main' into exphisto-null-sum
2 parents ff26321 + c8cc2e0 commit 28f611e

File tree

6 files changed

+102
-47
lines changed

6 files changed

+102
-47
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java

Lines changed: 67 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -272,11 +272,14 @@ private static CentroidIterator getCentroidIteratorNoParent(
272272
FixedBitSet acceptCentroids
273273
) throws IOException {
274274
final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true);
275+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
275276
score(
276277
neighborQueue,
277278
numCentroids,
278279
0,
279280
scorer,
281+
centroids,
282+
centroidQuantizeSize,
280283
quantizeQuery,
281284
queryParams,
282285
globalCentroidDp,
@@ -315,26 +318,41 @@ private static CentroidIterator getCentroidIteratorWithParents(
315318
FixedBitSet acceptCentroids
316319
) throws IOException {
317320
// build the three queues we are going to use
321+
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
318322
final NeighborQueue parentsQueue = new NeighborQueue(numParents, true);
319323
final int maxChildrenSize = centroids.readVInt();
320324
final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true);
321325
final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids);
322-
final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true);
323-
// score the parents
326+
final int numCentroidsFiltered = acceptCentroids == null ? numCentroids : acceptCentroids.cardinality();
324327
final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE];
325-
score(
326-
parentsQueue,
327-
numParents,
328-
0,
329-
scorer,
330-
quantizeQuery,
331-
queryParams,
332-
globalCentroidDp,
333-
fieldInfo.getVectorSimilarityFunction(),
334-
scores,
335-
null
336-
);
337-
final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES;
328+
final NeighborQueue neighborQueue;
329+
if (acceptCentroids != null && numCentroidsFiltered <= bufferSize) {
330+
// we are collecting every non-filter centroid, therefore we do not need to score the
331+
// parents. We give each of them the same score.
332+
neighborQueue = new NeighborQueue(numCentroidsFiltered, true);
333+
for (int i = 0; i < numParents; i++) {
334+
parentsQueue.add(i, 0.5f);
335+
}
336+
centroids.skipBytes(centroidQuantizeSize * numParents);
337+
} else {
338+
neighborQueue = new NeighborQueue(bufferSize, true);
339+
// score the parents
340+
score(
341+
parentsQueue,
342+
numParents,
343+
0,
344+
scorer,
345+
centroids,
346+
centroidQuantizeSize,
347+
quantizeQuery,
348+
queryParams,
349+
globalCentroidDp,
350+
fieldInfo.getVectorSimilarityFunction(),
351+
scores,
352+
null
353+
);
354+
}
355+
338356
final long offset = centroids.getFilePointer();
339357
final long childrenOffset = offset + (long) Long.BYTES * numParents;
340358
// populate the children's queue by reading parents one by one
@@ -429,6 +447,8 @@ private static void populateOneChildrenGroup(
429447
numChildren,
430448
childrenOrdinal,
431449
scorer,
450+
centroids,
451+
centroidQuantizeSize,
432452
quantizeQuery,
433453
queryParams,
434454
globalCentroidDp,
@@ -443,48 +463,56 @@ private static void score(
443463
int size,
444464
int scoresOffset,
445465
ES92Int7VectorsScorer scorer,
466+
IndexInput centroids,
467+
long centroidQuantizeSize,
446468
byte[] quantizeQuery,
447469
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
448470
float centroidDp,
449471
VectorSimilarityFunction similarityFunction,
450472
float[] scores,
451473
FixedBitSet acceptCentroids
452474
) throws IOException {
453-
// TODO: if accept centroids is not null, we can save some vector ops here
454475
int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1;
455476
int i = 0;
456477
for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) {
457-
scorer.scoreBulk(
458-
quantizeQuery,
459-
queryCorrections.lowerInterval(),
460-
queryCorrections.upperInterval(),
461-
queryCorrections.quantizedComponentSum(),
462-
queryCorrections.additionalCorrection(),
463-
similarityFunction,
464-
centroidDp,
465-
scores
466-
);
467-
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
468-
int centroidOrd = scoresOffset + i + j;
469-
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
470-
neighborQueue.add(centroidOrd, scores[j]);
478+
if (acceptCentroids == null
479+
|| acceptCentroids.cardinality(scoresOffset + i, scoresOffset + i + ES92Int7VectorsScorer.BULK_SIZE) > 0) {
480+
scorer.scoreBulk(
481+
quantizeQuery,
482+
queryCorrections.lowerInterval(),
483+
queryCorrections.upperInterval(),
484+
queryCorrections.quantizedComponentSum(),
485+
queryCorrections.additionalCorrection(),
486+
similarityFunction,
487+
centroidDp,
488+
scores
489+
);
490+
for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) {
491+
int centroidOrd = scoresOffset + i + j;
492+
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
493+
neighborQueue.add(centroidOrd, scores[j]);
494+
}
471495
}
496+
} else {
497+
centroids.skipBytes(ES92Int7VectorsScorer.BULK_SIZE * centroidQuantizeSize);
472498
}
473499
}
474500

475501
for (; i < size; i++) {
476-
float score = scorer.score(
477-
quantizeQuery,
478-
queryCorrections.lowerInterval(),
479-
queryCorrections.upperInterval(),
480-
queryCorrections.quantizedComponentSum(),
481-
queryCorrections.additionalCorrection(),
482-
similarityFunction,
483-
centroidDp
484-
);
485502
int centroidOrd = scoresOffset + i;
486503
if (acceptCentroids == null || acceptCentroids.get(centroidOrd)) {
504+
float score = scorer.score(
505+
quantizeQuery,
506+
queryCorrections.lowerInterval(),
507+
queryCorrections.upperInterval(),
508+
queryCorrections.quantizedComponentSum(),
509+
queryCorrections.additionalCorrection(),
510+
similarityFunction,
511+
centroidDp
512+
);
487513
neighborQueue.add(centroidOrd, score);
514+
} else {
515+
centroids.skipBytes(centroidQuantizeSize);
488516
}
489517
}
490518
}

server/src/main/java/org/elasticsearch/search/vectors/IVFKnnSearchStrategy.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ public class IVFKnnSearchStrategy extends KnnSearchStrategy {
1919
private final SetOnce<AbstractMaxScoreKnnCollector> collector = new SetOnce<>();
2020
private final LongAccumulator accumulator;
2121

22-
IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
22+
public IVFKnnSearchStrategy(float visitRatio, LongAccumulator accumulator) {
2323
this.visitRatio = visitRatio;
2424
this.accumulator = accumulator;
2525
}

server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -29,12 +29,15 @@
2929
import org.apache.lucene.index.VectorEncoding;
3030
import org.apache.lucene.index.VectorSimilarityFunction;
3131
import org.apache.lucene.search.AcceptDocs;
32+
import org.apache.lucene.search.KnnCollector;
3233
import org.apache.lucene.search.TopDocs;
34+
import org.apache.lucene.search.TopKnnCollector;
3335
import org.apache.lucene.store.Directory;
3436
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
3537
import org.apache.lucene.tests.util.TestUtil;
3638
import org.apache.lucene.util.BytesRef;
3739
import org.elasticsearch.common.logging.LogConfigurator;
40+
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3841
import org.junit.Before;
3942

4043
import java.io.IOException;
@@ -353,17 +356,27 @@ private void doRestrictiveFilter(boolean dense) throws IOException {
353356
LeafReader leafReader = getOnlyLeafReader(reader);
354357
float[] vector = randomVector(dimensions);
355358
// we might collect the same document twice because of soar assignments
356-
TopDocs topDocs = leafReader.searchNearestVectors(
359+
KnnCollector collector;
360+
if (random().nextBoolean()) {
361+
collector = new TopKnnCollector(random().nextInt(2 * matchingDocs, 3 * matchingDocs), Integer.MAX_VALUE);
362+
} else {
363+
collector = new TopKnnCollector(
364+
random().nextInt(2 * matchingDocs, 3 * matchingDocs),
365+
Integer.MAX_VALUE,
366+
new IVFKnnSearchStrategy(0.25f, null)
367+
);
368+
}
369+
leafReader.searchNearestVectors(
357370
"f",
358371
vector,
359-
random().nextInt(2 * matchingDocs, 3 * matchingDocs),
372+
collector,
360373
AcceptDocs.fromIteratorSupplier(
361374
() -> leafReader.postings(new Term("k", new BytesRef("A"))),
362375
leafReader.getLiveDocs(),
363376
leafReader.maxDoc()
364-
),
365-
Integer.MAX_VALUE
377+
)
366378
);
379+
TopDocs topDocs = collector.topDocs();
367380
Set<Integer> uniqueDocIds = new HashSet<>();
368381
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
369382
uniqueDocIds.add(topDocs.scoreDocs[i].doc);

x-pack/plugin/esql/qa/testFixtures/src/main/resources/exponential_histogram.csv-spec

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,19 @@ NULL | NULL | NULL
8080
;
8181

8282

83+
histoAsCaseValue
84+
required_capability: exponential_histogram_pre_tech_preview_v1
85+
86+
FROM exp_histo_sample
87+
| INLINE STATS p50 = PERCENTILE(responseTime, 50) BY instance, @timestamp
88+
| EVAL filteredHisto = CASE(p50 > 0.1, responseTime)
89+
| INLINE STATS filteredMax = MAX(filteredHisto) BY instance, @timestamp //MAX is null if the histogram is null
90+
| STATS filteredCount = COUNT(filteredMax)
91+
;
92+
93+
filteredCount:long
94+
3
95+
;
8396

8497
ungroupedPercentiles
8598
required_capability: exponential_histogram_pre_tech_preview_v1

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/Case.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,7 +81,8 @@ ConditionEvaluatorSupplier toEvaluator(ToEvaluator toEvaluator) {
8181
"keyword",
8282
"long",
8383
"unsigned_long",
84-
"version" },
84+
"version",
85+
"exponential_histogram" },
8586
description = """
8687
Accepts pairs of conditions and values. The function returns the value that
8788
belongs to the first condition that evaluates to `true`.
@@ -126,7 +127,8 @@ public Case(
126127
"long",
127128
"text",
128129
"unsigned_long",
129-
"version" },
130+
"version",
131+
"exponential_histogram" },
130132
description = "The value that’s returned when the corresponding condition is the first to evaluate to `true`. "
131133
+ "The default value is returned when no condition matches."
132134
) List<Expression> rest

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/conditional/CaseTests.java

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,6 @@ public class CaseTests extends AbstractScalarFunctionTestCase {
6666
if (Build.current().isSnapshot()) {
6767
t.addAll(
6868
DataType.UNDER_CONSTRUCTION.stream()
69-
.filter(type -> type != DataType.EXPONENTIAL_HISTOGRAM) // TODO(b/133393): implement
7069
.filter(type -> type != DataType.AGGREGATE_METRIC_DOUBLE && type != DataType.DENSE_VECTOR)
7170
.toList()
7271
);

0 commit comments

Comments
 (0)