Skip to content

Commit 830613e

Browse files
CNDB-13499: Optimize simple BM25 by deferring PrK creation (#1662)
### What is the issue Fixes: riptano/cndb#13499 ### What does this PR fix and why was it fixed This commit has two key optimizations. First, we defer materializing the PrimaryKey in simple BM25 queries until we know that the PrimaryKey is among the best scored rows in the sstable. This buys us two things. The most important is that we can defer reading the PrK's token from disk. The second is that we materialize one less object per row, which saves us essentially O(n) memory. Second, we defer creating the PrimaryKeyWithSortKey objects by using the jvector NodeQueue to sort based on a long packed by an index (int) and a score (float). This is a more compact way to sort because it takes less space and uses a slightly better sort algorithm for our use case since it is unlikely that we'll need to consume all of the rows being sorted. Initial testing on a 1 million document table with shows that this optimization improves query latency by about 40 percent.
1 parent e2f79c8 commit 830613e

File tree

3 files changed

+143
-47
lines changed

3 files changed

+143
-47
lines changed

src/java/org/apache/cassandra/index/sai/disk/v1/InvertedIndexSearcher.java

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.apache.cassandra.db.PartitionPosition;
3838
import org.apache.cassandra.db.Slice;
3939
import org.apache.cassandra.db.Slices;
40+
import org.apache.cassandra.db.memtable.Memtable;
4041
import org.apache.cassandra.db.rows.Cell;
4142
import org.apache.cassandra.db.rows.Row;
4243
import org.apache.cassandra.dht.AbstractBounds;
@@ -45,6 +46,7 @@
4546
import org.apache.cassandra.index.sai.QueryContext;
4647
import org.apache.cassandra.index.sai.SSTableContext;
4748
import org.apache.cassandra.index.sai.disk.PostingList;
49+
import org.apache.cassandra.index.sai.disk.PrimaryKeyMap;
4850
import org.apache.cassandra.index.sai.disk.TermsIterator;
4951
import org.apache.cassandra.index.sai.disk.format.IndexComponentType;
5052
import org.apache.cassandra.index.sai.disk.format.Version;
@@ -55,11 +57,13 @@
5557
import org.apache.cassandra.index.sai.plan.Expression;
5658
import org.apache.cassandra.index.sai.plan.Orderer;
5759
import org.apache.cassandra.index.sai.utils.BM25Utils;
58-
import org.apache.cassandra.index.sai.utils.BM25Utils.DocTF;
60+
import org.apache.cassandra.index.sai.utils.BM25Utils.EagerDocTF;
5961
import org.apache.cassandra.index.sai.utils.PrimaryKey;
62+
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithScore;
6063
import org.apache.cassandra.index.sai.utils.PrimaryKeyWithSortKey;
6164
import org.apache.cassandra.index.sai.utils.RowIdWithByteComparable;
6265
import org.apache.cassandra.index.sai.utils.SAICodecUtils;
66+
import org.apache.cassandra.io.sstable.SSTableId;
6367
import org.apache.cassandra.io.sstable.format.SSTableReader;
6468
import org.apache.cassandra.io.sstable.format.SSTableReadsListener;
6569
import org.apache.cassandra.io.util.FileHandle;
@@ -200,20 +204,23 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderBy(Orderer orderer, Express
200204
var docLengthsReader = new DocLengthsReader(docLengths, docLengthsMeta);
201205

202206
// Wrap the iterator with resource management
203-
var it = new AbstractIterator<DocTF>() { // Anonymous class extends AbstractIterator
207+
var it = new AbstractIterator<BM25Utils.DocTF>() { // Anonymous class extends AbstractIterator
204208
private boolean closed;
205209

206210
@Override
207-
protected DocTF computeNext()
211+
protected BM25Utils.DocTF computeNext()
208212
{
209213
try
210214
{
211215
int rowId = merged.nextPosting();
212216
if (rowId == PostingList.END_OF_STREAM)
213217
return endOfData();
218+
// Reads from disk.
214219
int docLength = docLengthsReader.get(rowId); // segment-local rowid
215-
var pk = pkm.primaryKeyFromRowId(segmentRowIdOffset + rowId); // sstable-global rowid
216-
return new DocTF(pk, docLength, merged.frequencies());
220+
// We defer creating the primary key because it reads the token from disk, which is only needed
221+
// for the top rows just before they are materialized from disk, so we wait until after scoring
222+
// and sorting to read the token.
223+
return new LazyDocTF(pkm, segmentRowIdOffset + rowId, docLength, merged.frequencies());
217224
}
218225
catch (IOException e)
219226
{
@@ -232,7 +239,7 @@ public void close()
232239
return bm25Internal(it, queryTerms, documentFrequencies);
233240
}
234241

235-
private CloseableIterator<PrimaryKeyWithSortKey> bm25Internal(CloseableIterator<DocTF> keyIterator,
242+
private CloseableIterator<PrimaryKeyWithSortKey> bm25Internal(CloseableIterator<BM25Utils.DocTF> keyIterator,
236243
List<ByteBuffer> queryTerms,
237244
Map<ByteBuffer, Long> documentFrequencies)
238245
{
@@ -269,7 +276,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(SSTableReader rea
269276
}
270277
var analyzer = indexContext.getAnalyzerFactory().create();
271278
var it = keys.stream()
272-
.map(pk -> DocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms))
279+
.map(pk -> EagerDocTF.createFromDocument(pk, readColumn(sstable, pk), analyzer, queryTerms))
273280
.filter(Objects::nonNull)
274281
.iterator();
275282
return bm25Internal(CloseableIterator.wrap(it), queryTerms, documentFrequencies);
@@ -334,4 +341,50 @@ public void close()
334341
FileUtils.closeQuietly(source, currentPostingList);
335342
}
336343
}
344+
345+
/**
346+
* A {@link BM25Utils.DocTF} that is lazy in that it does not create the {@link PrimaryKey} until it is required.
347+
*/
348+
private static class LazyDocTF implements BM25Utils.DocTF
349+
{
350+
private final PrimaryKeyMap pkm;
351+
private final long sstableRowId;
352+
private final int docLength;
353+
private final Map<ByteBuffer, Integer> frequencies;
354+
355+
LazyDocTF(PrimaryKeyMap pkm, long sstableRowId, int docLength, Map<ByteBuffer, Integer> frequencies)
356+
{
357+
this.pkm = pkm;
358+
this.sstableRowId = sstableRowId;
359+
this.docLength = docLength;
360+
this.frequencies = frequencies;
361+
}
362+
363+
@Override
364+
public int getTermFrequency(ByteBuffer term)
365+
{
366+
return frequencies.getOrDefault(term, 0);
367+
}
368+
369+
@Override
370+
public int termCount()
371+
{
372+
return docLength;
373+
}
374+
375+
@Override
376+
public PrimaryKeyWithSortKey primaryKey(IndexContext context, Memtable source, float score)
377+
{
378+
// Only sstables use this class, so this should never be called
379+
throw new UnsupportedOperationException();
380+
}
381+
382+
@Override
383+
public PrimaryKeyWithSortKey primaryKey(IndexContext context, SSTableId<?> source, float score)
384+
{
385+
// We can eagerly get the token now, even though it might not technically be required until we know
386+
// we have the best score. (Perhaps this should be lazy too?)
387+
return new PrimaryKeyWithScore(context, source, pkm.primaryKeyFromRowId(sstableRowId), score);
388+
}
389+
}
337390
}

src/java/org/apache/cassandra/index/sai/memory/TrieMemtableIndex.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ public List<CloseableIterator<PrimaryKeyWithSortKey>> orderBy(QueryContext query
324324
var docStats = computeDocumentFrequencies(queryContext, queryTerms);
325325
var analyzer = indexContext.getAnalyzerFactory().create();
326326
var it = Streams.stream(intersectedIterator)
327-
.map(pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
327+
.map(pk -> BM25Utils.EagerDocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
328328
.filter(Objects::nonNull)
329329
.iterator();
330330

@@ -393,7 +393,7 @@ public CloseableIterator<PrimaryKeyWithSortKey> orderResultsBy(QueryContext quer
393393
var queryTerms = orderer.getQueryTerms();
394394
var docStats = computeDocumentFrequencies(queryContext, queryTerms);
395395
var it = keys.stream()
396-
.map(pk -> BM25Utils.DocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
396+
.map(pk -> BM25Utils.EagerDocTF.createFromDocument(pk, getCellForKey(pk), analyzer, queryTerms))
397397
.filter(Objects::nonNull)
398398
.iterator();
399399
return BM25Utils.computeScores(CloseableIterator.wrap(it),

src/java/org/apache/cassandra/index/sai/utils/BM25Utils.java

Lines changed: 81 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@
2121
import java.nio.ByteBuffer;
2222
import java.util.ArrayList;
2323
import java.util.Collection;
24-
import java.util.Collections;
2524
import java.util.HashMap;
26-
import java.util.Iterator;
2725
import java.util.List;
2826
import java.util.Map;
2927

3028
import javax.annotation.Nullable;
3129

30+
import io.github.jbellis.jvector.graph.NodeQueue;
31+
import io.github.jbellis.jvector.util.BoundedLongHeap;
3232
import org.apache.cassandra.db.memtable.Memtable;
3333
import org.apache.cassandra.db.rows.Cell;
3434
import org.apache.cassandra.index.sai.IndexContext;
3535
import org.apache.cassandra.index.sai.analyzer.AbstractAnalyzer;
3636
import org.apache.cassandra.io.sstable.SSTableId;
3737
import org.apache.cassandra.io.util.FileUtils;
38+
import org.apache.cassandra.utils.AbstractIterator;
3839
import org.apache.cassandra.utils.CloseableIterator;
3940

4041
public class BM25Utils
@@ -60,15 +61,28 @@ public DocStats(Map<ByteBuffer, Long> frequencies, long docCount)
6061
}
6162

6263
/**
63-
* Term frequencies within a single document. All instances of a term are counted.
64+
* Term frequencies within a single document. All instances of a term are counted. Allows us to optimize for
65+
* the sstable use case, which is able to skip some reads from disk as well as some memory allocations.
6466
*/
65-
public static class DocTF
67+
public interface DocTF
68+
{
69+
int getTermFrequency(ByteBuffer term);
70+
int termCount();
71+
PrimaryKeyWithSortKey primaryKey(IndexContext context, Memtable source, float score);
72+
PrimaryKeyWithSortKey primaryKey(IndexContext context, SSTableId<?> source, float score);
73+
}
74+
75+
/**
76+
* Term frequencies within a single document. All instances of a term are counted. It is eager in that the
77+
* PrimaryKey is already created.
78+
*/
79+
public static class EagerDocTF implements DocTF
6680
{
6781
private final PrimaryKey pk;
6882
private final Map<ByteBuffer, Integer> frequencies;
6983
private final int termCount;
7084

71-
public DocTF(PrimaryKey pk, int termCount, Map<ByteBuffer, Integer> frequencies)
85+
public EagerDocTF(PrimaryKey pk, int termCount, Map<ByteBuffer, Integer> frequencies)
7286
{
7387
this.pk = pk;
7488
this.frequencies = frequencies;
@@ -80,6 +94,21 @@ public int getTermFrequency(ByteBuffer term)
8094
return frequencies.getOrDefault(term, 0);
8195
}
8296

97+
public int termCount()
98+
{
99+
return termCount;
100+
}
101+
102+
public PrimaryKeyWithSortKey primaryKey(IndexContext context, Memtable source, float score)
103+
{
104+
return new PrimaryKeyWithScore(context, source, pk, score);
105+
}
106+
107+
public PrimaryKeyWithSortKey primaryKey(IndexContext context, SSTableId<?> source, float score)
108+
{
109+
return new PrimaryKeyWithScore(context, source, pk, score);
110+
}
111+
83112
@Nullable
84113
public static DocTF createFromDocument(PrimaryKey pk,
85114
Cell<?> cell,
@@ -111,7 +140,7 @@ public static DocTF createFromDocument(PrimaryKey pk,
111140
if (queryTerms.size() > frequencies.size())
112141
return null;
113142

114-
return new DocTF(pk, count, frequencies);
143+
return new EagerDocTF(pk, count, frequencies);
115144
}
116145
}
117146

@@ -121,6 +150,8 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
121150
IndexContext indexContext,
122151
Object source)
123152
{
153+
assert source instanceof Memtable || source instanceof SSTableId : "Invalid source " + source.getClass();
154+
124155
// data structures for document stats and frequencies
125156
ArrayList<DocTF> documents = new ArrayList<>();
126157
double totalTermCount = 0;
@@ -130,18 +161,20 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
130161
{
131162
var tf = docIterator.next();
132163
documents.add(tf);
133-
totalTermCount += tf.termCount;
164+
totalTermCount += tf.termCount();
134165
}
166+
135167
if (documents.isEmpty())
136168
return CloseableIterator.emptyIterator();
137169

138170
// Calculate average document length
139171
double avgDocLength = totalTermCount / documents.size();
140172

141-
// Calculate BM25 scores
142-
var scoredDocs = new ArrayList<PrimaryKeyWithScore>(documents.size());
143-
for (var doc : documents)
173+
// Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity
174+
var nodeQueue = new NodeQueue(new BoundedLongHeap(documents.size()), NodeQueue.Order.MAX_HEAP);
175+
for (int i = 0; i < documents.size(); i++)
144176
{
177+
var doc = documents.get(i);
145178
double score = 0.0;
146179
for (var queryTerm : queryTerms)
147180
{
@@ -150,45 +183,55 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
150183
// we shouldn't have more hits for a term than we counted total documents
151184
assert df <= docStats.docCount : String.format("df=%d, totalDocs=%d", df, docStats.docCount);
152185

153-
double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount / avgDocLength));
186+
double normalizedTf = tf / (tf + K1 * (1 - B + B * doc.termCount() / avgDocLength));
154187
double idf = Math.log(1 + (docStats.docCount - df + 0.5) / (df + 0.5));
155188
double deltaScore = normalizedTf * idf;
156189
assert deltaScore >= 0 : String.format("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f",
157-
tf, df, doc.termCount, docStats.docCount, deltaScore);
190+
tf, df, doc.termCount(), docStats.docCount, deltaScore);
158191
score += deltaScore;
159192
}
160-
if (source instanceof Memtable)
161-
scoredDocs.add(new PrimaryKeyWithScore(indexContext, (Memtable) source, doc.pk, (float) score));
162-
else if (source instanceof SSTableId)
163-
scoredDocs.add(new PrimaryKeyWithScore(indexContext, (SSTableId) source, doc.pk, (float) score));
164-
else
165-
throw new IllegalArgumentException("Invalid source " + source.getClass());
193+
nodeQueue.push(i, (float) score);
166194
}
167195

168-
// sort by score (PKWS implements Comparator correctly for us)
169-
Collections.sort(scoredDocs);
196+
return new NodeQueueDocTFIterator(nodeQueue, documents, indexContext, source, docIterator);
197+
}
170198

171-
return new CloseableIterator<>()
199+
private static class NodeQueueDocTFIterator extends AbstractIterator<PrimaryKeyWithSortKey>
200+
{
201+
private final NodeQueue nodeQueue;
202+
private final List<DocTF> documents;
203+
private final IndexContext indexContext;
204+
private final Object source;
205+
private final CloseableIterator<DocTF> docIterator;
206+
207+
NodeQueueDocTFIterator(NodeQueue nodeQueue, List<DocTF> documents, IndexContext indexContext, Object source, CloseableIterator<DocTF> docIterator)
172208
{
173-
private final Iterator<PrimaryKeyWithScore> iterator = scoredDocs.iterator();
209+
this.nodeQueue = nodeQueue;
210+
this.documents = documents;
211+
this.indexContext = indexContext;
212+
this.source = source;
213+
this.docIterator = docIterator;
214+
}
174215

175-
@Override
176-
public boolean hasNext()
177-
{
178-
return iterator.hasNext();
179-
}
216+
@Override
217+
protected PrimaryKeyWithSortKey computeNext()
218+
{
219+
if (nodeQueue.size() == 0)
220+
return endOfData();
180221

181-
@Override
182-
public PrimaryKeyWithSortKey next()
183-
{
184-
return iterator.next();
185-
}
222+
var score = nodeQueue.topScore();
223+
var node = nodeQueue.pop();
224+
var doc = documents.get(node);
225+
if (source instanceof Memtable)
226+
return doc.primaryKey(indexContext, (Memtable) source, score);
227+
else
228+
return doc.primaryKey(indexContext, (SSTableId<?>) source, score);
229+
}
186230

187-
@Override
188-
public void close()
189-
{
190-
FileUtils.closeQuietly(docIterator);
191-
}
192-
};
231+
@Override
232+
public void close()
233+
{
234+
FileUtils.closeQuietly(docIterator);
235+
}
193236
}
194237
}

0 commit comments

Comments
 (0)