Skip to content

Commit 1a03e5b

Browse files
k-rusdriftx
authored andcommitted
CNDB-13553 aggregate document frequencies on entire node (#1802)
Aggregates document frequencies for each query term for entire node before passing to calculate BM25 scores. This removes unbalanced and inconsistent results between different splits to SSTables and MemTables. As the result the test, which was demonstrating this inconsistency, is fixed to provide the same result. The same result is moved into shared constants. For SSTables created with older version total count is not stored and thus the document frequencies are not aggregated on those SSTables. Thus, the old way of calculating the average document frequencies per segment is used. This also creates difference in the test result for EC vs newer versions. Document frequencies are calculated by searching the query terms, which gives more reliable statistics than using terms distribution histogram, but slower. Another possible approach is to get statistics from posting lists. Also search API method of MemtableIndex is cleaned from unused argument limit. Some unused imports were removed in relevant files. Fixes riptano/cndb#13553
1 parent c6a2d79 commit 1a03e5b

20 files changed

+312
-232
lines changed

src/java/org/apache/cassandra/index/sai/IndexContext.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,12 +479,12 @@ private KeyRangeIterator getNonEqIterator(QueryContext context, Collection<Memta
479479
else
480480
{
481481
Expression negExpression = expression.negated();
482-
KeyRangeIterator matchedKeys = searchMemtable(context, memtables, negExpression, keyRange, Integer.MAX_VALUE);
482+
KeyRangeIterator matchedKeys = searchMemtable(context, memtables, negExpression, keyRange);
483483
return KeyRangeAntiJoinIterator.create(allKeys, matchedKeys);
484484
}
485485
}
486486

487-
public KeyRangeIterator searchMemtable(QueryContext context, Collection<MemtableIndex> memtables, Expression expression, AbstractBounds<PartitionPosition> keyRange, int limit)
487+
public KeyRangeIterator searchMemtable(QueryContext context, Collection<MemtableIndex> memtables, Expression expression, AbstractBounds<PartitionPosition> keyRange)
488488
{
489489
if (expression.getOp().isNonEquality())
490490
{
@@ -501,9 +501,7 @@ public KeyRangeIterator searchMemtable(QueryContext context, Collection<Memtable
501501
try
502502
{
503503
for (MemtableIndex index : memtables)
504-
{
505-
builder.add(index.search(context, expression, keyRange, limit));
506-
}
504+
builder.add(index.search(context, expression, keyRange));
507505

508506
return builder.build();
509507
}

src/java/org/apache/cassandra/index/sai/SSTableIndex.java

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import org.apache.cassandra.index.sai.iterators.KeyRangeIterator;
4848
import org.apache.cassandra.index.sai.plan.Expression;
4949
import org.apache.cassandra.index.sai.plan.Orderer;
50+
import org.apache.cassandra.index.sai.utils.AbortedOperationException;
5051
import org.apache.cassandra.index.sai.utils.PrimaryKey;
5152
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
5253
import org.apache.cassandra.index.sai.utils.TypeUtil;
@@ -55,6 +56,7 @@
5556
import org.apache.cassandra.io.sstable.format.SSTableReader;
5657
import org.apache.cassandra.io.util.FileUtils;
5758
import org.apache.cassandra.utils.CloseableIterator;
59+
import org.apache.cassandra.utils.Throwables;
5860

5961
/**
6062
* SSTableIndex is created for each column index on individual sstable to track per-column indexer.
@@ -145,11 +147,41 @@ public long getApproximateTermCount()
145147
return searchableIndex.getApproximateTermCount();
146148
}
147149

150+
/**
151+
* Estimates the number of rows that would be returned by this index given the predicate using the index
152+
* histogram.
153+
* Note that this is not a guarantee of the number of rows that will actually be returned.
154+
*
155+
* @return an approximate number of the matching rows
156+
*/
148157
public long estimateMatchingRowsCount(Expression predicate, AbstractBounds<PartitionPosition> keyRange)
149158
{
150159
return searchableIndex.estimateMatchingRowsCount(predicate, keyRange);
151160
}
152161

162+
/**
163+
* Counts the number of rows that would be returned by this index given the predicate.
164+
*
165+
* @return the row count
166+
*/
167+
public long getMatchingRowsCount(Expression predicate, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext)
168+
{
169+
queryContext.checkpoint();
170+
queryContext.addSstablesHit(1);
171+
assert !isReleased();
172+
173+
try (KeyRangeIterator keyIterator = search(predicate, keyRange, queryContext, false))
174+
{
175+
return keyIterator.getMaxKeys();
176+
}
177+
catch (Throwable e)
178+
{
179+
if (logger.isDebugEnabled() && !(e instanceof AbortedOperationException))
180+
logger.debug(String.format("Failed search an index %s.", getSSTable()), e);
181+
throw Throwables.cleaned(e);
182+
}
183+
}
184+
153185
/**
154186
* @return total size of per-column SAI components, in bytes
155187
*/
@@ -228,23 +260,22 @@ private KeyRangeIterator getNonEqIterator(Expression expression,
228260
else
229261
{
230262
Expression negExpression = expression.negated();
231-
KeyRangeIterator matchedKeys = searchableIndex.search(negExpression, keyRange, context, defer, Integer.MAX_VALUE);
263+
KeyRangeIterator matchedKeys = searchableIndex.search(negExpression, keyRange, context, defer);
232264
return KeyRangeAntiJoinIterator.create(allKeys, matchedKeys);
233265
}
234266
}
235267

236268
public KeyRangeIterator search(Expression expression,
237269
AbstractBounds<PartitionPosition> keyRange,
238270
QueryContext context,
239-
boolean defer,
240-
int limit) throws IOException
271+
boolean defer) throws IOException
241272
{
242273
if (expression.getOp().isNonEquality())
243274
{
244275
return getNonEqIterator(expression, keyRange, context, defer);
245276
}
246277

247-
return searchableIndex.search(expression, keyRange, context, defer, limit);
278+
return searchableIndex.search(expression, keyRange, context, defer);
248279
}
249280

250281
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer,

src/java/org/apache/cassandra/index/sai/disk/EmptyIndex.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ public DecoratedKey maxKey()
9696
public KeyRangeIterator search(Expression expression,
9797
AbstractBounds<PartitionPosition> keyRange,
9898
QueryContext context,
99-
boolean defer,
100-
int limit) throws IOException
99+
boolean defer) throws IOException
101100
{
102101
return KeyRangeIterator.empty();
103102
}

src/java/org/apache/cassandra/index/sai/disk/SearchableIndex.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ public interface SearchableIndex extends Closeable
6969
KeyRangeIterator search(Expression expression,
7070
AbstractBounds<PartitionPosition> keyRange,
7171
QueryContext context,
72-
boolean defer, int limit) throws IOException;
72+
boolean defer) throws IOException;
7373

7474
List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(Orderer orderer,
7575
Expression slice,

src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java

Lines changed: 11 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import java.io.UncheckedIOException;
2323
import java.lang.invoke.MethodHandles;
2424
import java.nio.ByteBuffer;
25-
import java.util.HashMap;
2625
import java.util.List;
2726
import java.util.Map;
2827
import java.util.Objects;
@@ -34,7 +33,6 @@
3433
import org.slf4j.Logger;
3534
import org.slf4j.LoggerFactory;
3635

37-
import org.apache.cassandra.cql3.Operator;
3836
import org.apache.cassandra.db.PartitionPosition;
3937
import org.apache.cassandra.db.Slice;
4038
import org.apache.cassandra.db.Slices;
@@ -75,6 +73,7 @@
7573
import org.apache.cassandra.utils.bytecomparable.ByteComparable;
7674

7775
import static org.apache.cassandra.index.sai.disk.PostingList.END_OF_STREAM;
76+
import static org.apache.cassandra.index.sai.disk.v1.SegmentMetadata.INVALID_TOTAL_TERM_COUNT;
7877

7978
/**
8079
* Executes {@link Expression}s against the trie-based terms dictionary for an individual index segment.
@@ -204,8 +203,6 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
204203
var postings = reader.exactMatch(encodedTerm, listener, queryContext);
205204
return postings == null ? PostingList.EMPTY : postings;
206205
}));
207-
// extract the match count for each
208-
var documentFrequencies = postingLists.entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, e -> (long) e.getValue().size()));
209206

210207
var pkm = primaryKeyMapFactory.newPerSSTablePrimaryKeyMap();
211208
var merged = IntersectingPostingList.intersect(postingLists);
@@ -244,25 +241,12 @@ public void close()
244241
FileUtils.closeQuietly(pkm, merged, docLengthsReader);
245242
}
246243
};
247-
return bm25Internal(it, queryTerms, documentFrequencies, orderer.bm25Stats);
248-
}
249-
250-
private CloseableIterator<PrimaryKeyWithSortKey> bm25Internal(CloseableIterator<BM25Utils.DocTF> keyIterator,
251-
List<ByteBuffer> queryTerms,
252-
Map<ByteBuffer, Long> documentFrequencies,
253-
BM25Utils.AggDocsStats aggStats)
254-
{
255-
long totalRows = sstable.getTotalRows();
256-
// since doc frequencies can be an estimate from the index histogram, which does not have bounded error,
257-
// cap frequencies to total rows so that the IDF term doesn't turn negative
258-
Map<ByteBuffer, Long> cappedFrequencies = documentFrequencies.entrySet().stream()
259-
.collect(Collectors.toMap(Map.Entry::getKey, e -> Math.min(e.getValue(), totalRows)));
260-
BM25Utils.DocStats docStats = new BM25Utils.DocStats(cappedFrequencies, aggStats);
261-
return BM25Utils.computeScores(keyIterator,
244+
return BM25Utils.computeScores(it,
262245
queryTerms,
263-
docStats,
246+
orderer.bm25stats,
264247
indexContext,
265-
sstable.descriptor.id);
248+
sstable.descriptor.id,
249+
metadata.totalTermCount == INVALID_TOTAL_TERM_COUNT);
266250
}
267251

268252
@Override
@@ -278,21 +262,17 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
278262
}
279263

280264
var queryTerms = orderer.getQueryTerms();
281-
// compute documentFrequencies from either histogram or an index search
282-
var documentFrequencies = new HashMap<ByteBuffer, Long>();
283-
// any index new enough to support BM25 should also support histograms
284-
assert metadata.version.onDiskFormat().indexFeatureSet().hasTermsHistogram();
285-
for (ByteBuffer term : queryTerms)
286-
{
287-
long matches = metadata.estimateNumRowsMatching(new Expression(indexContext).add(Operator.ANALYZER_MATCHES, term));
288-
documentFrequencies.put(term, matches);
289-
}
290265
var analyzer = indexContext.getAnalyzerFactory().create();
291266
var it = keys.stream()
292267
.map(pk -> EagerDocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms))
293268
.filter(Objects::nonNull)
294269
.iterator();
295-
return bm25Internal(CloseableIterator.wrap(it), queryTerms, documentFrequencies, orderer.bm25Stats);
270+
return BM25Utils.computeScores(CloseableIterator.wrap(it),
271+
queryTerms,
272+
orderer.bm25stats,
273+
indexContext,
274+
sstable.descriptor.id,
275+
metadata.totalTermCount == INVALID_TOTAL_TERM_COUNT);
296276
}
297277

298278
@Override

src/java/org/apache/cassandra/index/sai/disk/v1/Segment.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,9 @@ public long indexFileCacheSize()
139139
* @param keyRange key range specific in read command, used by ANN index
140140
* @param context to track per sstable cache and per query metrics
141141
* @param defer create the iterator in a deferred state
142-
* @param limit the num of rows to returned, used by ANN index
143142
* @return range iterator of {@link PrimaryKey} that matches given expression
144143
*/
145-
public KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer, int limit) throws IOException
144+
public KeyRangeIterator search(Expression expression, AbstractBounds<PartitionPosition> keyRange, QueryContext context, boolean defer) throws IOException
146145
{
147146
return index.search(expression, keyRange, context, defer);
148147
}

src/java/org/apache/cassandra/index/sai/disk/v1/V1SearchableIndex.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.apache.cassandra.utils.CloseableIterator;
4949
import org.apache.cassandra.utils.Throwables;
5050

51+
import static org.apache.cassandra.index.sai.disk.v1.SegmentMetadata.INVALID_TOTAL_TERM_COUNT;
5152
import static org.apache.cassandra.index.sai.virtual.SegmentsSystemView.CELL_COUNT;
5253
import static org.apache.cassandra.index.sai.virtual.SegmentsSystemView.COLUMN_NAME;
5354
import static org.apache.cassandra.index.sai.virtual.SegmentsSystemView.COMPONENT_METADATA;
@@ -90,10 +91,13 @@ public V1SearchableIndex(SSTableContext sstableContext, IndexComponents.ForRead
9091

9192
metadatas = SegmentMetadata.load(source, indexContext, sstableContext);
9293

94+
long termCount = 0;
9395
for (SegmentMetadata metadata : metadatas)
9496
{
9597
segmentsBuilder.add(new Segment(indexContext, sstableContext, indexFiles, metadata));
98+
termCount += metadata.totalTermCount == INVALID_TOTAL_TERM_COUNT ? 0 : metadata.totalTermCount;
9699
}
100+
this.approximateTermCount = termCount;
97101

98102
segments = segmentsBuilder.build();
99103
assert !segments.isEmpty();
@@ -106,7 +110,6 @@ public V1SearchableIndex(SSTableContext sstableContext, IndexComponents.ForRead
106110
this.maxTerm = metadatas.stream().map(m -> m.maxTerm).max(TypeUtil.comparator(indexContext.getValidator(), version)).orElse(null);
107111

108112
this.numRows = metadatas.stream().mapToLong(m -> m.numRows).sum();
109-
this.approximateTermCount = metadatas.stream().mapToLong(m -> m.totalTermCount).sum();
110113

111114
this.minSSTableRowId = metadatas.get(0).minSSTableRowId;
112115
this.maxSSTableRowId = metadatas.get(metadatas.size() - 1).maxSSTableRowId;
@@ -177,8 +180,7 @@ public DecoratedKey maxKey()
177180
public KeyRangeIterator search(Expression expression,
178181
AbstractBounds<PartitionPosition> keyRange,
179182
QueryContext context,
180-
boolean defer,
181-
int limit) throws IOException
183+
boolean defer) throws IOException
182184
{
183185
KeyRangeConcatIterator.Builder rangeConcatIteratorBuilder = KeyRangeConcatIterator.builder(segments.size());
184186

@@ -188,7 +190,7 @@ public KeyRangeIterator search(Expression expression,
188190
{
189191
if (segment.intersects(keyRange))
190192
{
191-
rangeConcatIteratorBuilder.add(segment.search(expression, keyRange, context, defer, limit));
193+
rangeConcatIteratorBuilder.add(segment.search(expression, keyRange, context, defer));
192194
}
193195
}
194196

src/java/org/apache/cassandra/index/sai/disk/vector/VectorMemtableIndex.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ else if (primaryKey.compareTo(maximumKey) > 0)
195195
}
196196

197197
@Override
198-
public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBounds<PartitionPosition> keyRange, int limit)
198+
public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBounds<PartitionPosition> keyRange)
199199
{
200200
if (expr.getOp() != Expression.Op.BOUNDED_ANN)
201201
throw new IllegalArgumentException(indexContext.logMessage("Only BOUNDED_ANN is supported, received: " + expr));
@@ -217,12 +217,18 @@ public KeyRangeIterator search(QueryContext context, Expression expr, AbstractBo
217217
}
218218

219219
@Override
220-
public long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange)
220+
public long estimateMatchingRowsCountUsingFirstShard(Expression expression, AbstractBounds<PartitionPosition> keyRange)
221221
{
222222
// For BOUNDED_ANN we use the old way of estimating cardinality - by running the search.
223223
throw new UnsupportedOperationException("Cardinality estimation not supported by vector indexes");
224224
}
225225

226+
@Override
227+
public long estimateMatchingRowsCountUsingAllShards(Expression expression, AbstractBounds<PartitionPosition> keyRange)
228+
{
229+
throw new UnsupportedOperationException("Cardinality estimation not supported by vector indexes");
230+
}
231+
226232
@Override
227233
public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext context,
228234
Orderer orderer,

src/java/org/apache/cassandra/index/sai/iterators/KeyRangeTermIterator.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ private KeyRangeTermIterator(KeyRangeIterator union, Set<SSTableIndex> reference
6666

6767

6868
@SuppressWarnings("resource")
69-
public static KeyRangeTermIterator build(final Expression e, QueryView view, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit)
69+
public static KeyRangeTermIterator build(final Expression e, QueryView view, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer)
7070
{
71-
KeyRangeIterator rangeIterator = buildRangeIterator(e, view, keyRange, queryContext, defer, limit);
71+
KeyRangeIterator rangeIterator = buildRangeIterator(e, view, keyRange, queryContext, defer);
7272
return new KeyRangeTermIterator(rangeIterator, view.sstableIndexes, queryContext);
7373
}
7474

75-
private static KeyRangeIterator buildRangeIterator(final Expression e, QueryView view, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer, int limit)
75+
private static KeyRangeIterator buildRangeIterator(final Expression e, QueryView view, AbstractBounds<PartitionPosition> keyRange, QueryContext queryContext, boolean defer)
7676
{
7777
final List<KeyRangeIterator> tokens = new ArrayList<>(1 + view.sstableIndexes.size());
7878

79-
KeyRangeIterator memtableIterator = e.context.searchMemtable(queryContext, view.memtableIndexes, e, keyRange, limit);
79+
KeyRangeIterator memtableIterator = e.context.searchMemtable(queryContext, view.memtableIndexes, e, keyRange);
8080
if (memtableIterator != null)
8181
tokens.add(memtableIterator);
8282

@@ -88,7 +88,7 @@ private static KeyRangeIterator buildRangeIterator(final Expression e, QueryView
8888
queryContext.addSstablesHit(1);
8989
assert !index.isReleased();
9090

91-
KeyRangeIterator keyIterator = index.search(e, keyRange, queryContext, defer, limit);
91+
KeyRangeIterator keyIterator = index.search(e, keyRange, queryContext, defer);
9292

9393
if (keyIterator == null || !keyIterator.hasNext())
9494
continue;

src/java/org/apache/cassandra/index/sai/memory/MemtableIndex.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,29 @@ public interface MemtableIndex extends MemtableOrdering
7373
void update(DecoratedKey key, Clustering clustering, ByteBuffer oldValue, ByteBuffer newValue, Memtable memtable, OpOrder.Group opGroup);
7474
void update(DecoratedKey key, Clustering clustering, Iterator<ByteBuffer> oldValues, Iterator<ByteBuffer> newValues, Memtable memtable, OpOrder.Group opGroup);
7575

76-
KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange, int limit);
76+
KeyRangeIterator search(QueryContext queryContext, Expression expression, AbstractBounds<PartitionPosition> keyRange);
7777

78-
long estimateMatchingRowsCount(Expression expression, AbstractBounds<PartitionPosition> keyRange);
78+
/**
79+
* Estimates the number of rows that would be returned by this index given the predicate.
80+
* It is extrapolated from the first shard.
81+
* Note that this is not a guarantee of the number of rows that will actually be returned.
82+
*
83+
* @param expression predicate to match
84+
* @param keyRange the key range to search within
85+
* @return an approximate number of the matching rows
86+
*/
87+
long estimateMatchingRowsCountUsingFirstShard(Expression expression, AbstractBounds<PartitionPosition> keyRange);
88+
89+
/**
90+
* Estimates the number of rows that would be returned by this index given the predicate.
91+
* It estimates from all relevant shards individually.
92+
* Note that this is not a guarantee of the number of rows that will actually be returned.
93+
*
94+
* @param expression predicate to match
95+
* @param keyRange the key range to search within
96+
* @return an estimated number of the matching rows
97+
*/
98+
long estimateMatchingRowsCountUsingAllShards(Expression expression, AbstractBounds<PartitionPosition> keyRange);
7999

80100
Iterator<Pair<ByteComparable.Preencoded, List<MemoryIndex.PkWithFrequency>>> iterator(DecoratedKey min, DecoratedKey max);
81101

0 commit comments

Comments
 (0)