21
21
import java .nio .ByteBuffer ;
22
22
import java .util .ArrayList ;
23
23
import java .util .Collection ;
24
- import java .util .Collections ;
25
24
import java .util .HashMap ;
26
- import java .util .Iterator ;
27
25
import java .util .List ;
28
26
import java .util .Map ;
29
27
30
28
import javax .annotation .Nullable ;
31
29
30
+ import io .github .jbellis .jvector .graph .NodeQueue ;
31
+ import io .github .jbellis .jvector .util .BoundedLongHeap ;
32
32
import org .apache .cassandra .db .memtable .Memtable ;
33
33
import org .apache .cassandra .db .rows .Cell ;
34
34
import org .apache .cassandra .index .sai .IndexContext ;
35
35
import org .apache .cassandra .index .sai .analyzer .AbstractAnalyzer ;
36
36
import org .apache .cassandra .io .sstable .SSTableId ;
37
37
import org .apache .cassandra .io .util .FileUtils ;
38
+ import org .apache .cassandra .utils .AbstractIterator ;
38
39
import org .apache .cassandra .utils .CloseableIterator ;
39
40
40
41
public class BM25Utils
@@ -60,15 +61,28 @@ public DocStats(Map<ByteBuffer, Long> frequencies, long docCount)
60
61
}
61
62
62
63
/**
63
- * Term frequencies within a single document. All instances of a term are counted.
64
+ * Term frequencies within a single document. All instances of a term are counted. Allows us to optimize for
65
+ * the sstable use case, which is able to skip some reads from disk as well as some memory allocations.
64
66
*/
65
- public static class DocTF
67
+ public interface DocTF
68
+ {
69
+ int getTermFrequency (ByteBuffer term );
70
+ int termCount ();
71
+ PrimaryKeyWithSortKey primaryKey (IndexContext context , Memtable source , float score );
72
+ PrimaryKeyWithSortKey primaryKey (IndexContext context , SSTableId <?> source , float score );
73
+ }
74
+
75
+ /**
76
+ * Term frequencies within a single document. All instances of a term are counted. It is eager in that the
77
+ * PrimaryKey is already created.
78
+ */
79
+ public static class EagerDocTF implements DocTF
66
80
{
67
81
private final PrimaryKey pk ;
68
82
private final Map <ByteBuffer , Integer > frequencies ;
69
83
private final int termCount ;
70
84
71
- public DocTF (PrimaryKey pk , int termCount , Map <ByteBuffer , Integer > frequencies )
85
+ public EagerDocTF (PrimaryKey pk , int termCount , Map <ByteBuffer , Integer > frequencies )
72
86
{
73
87
this .pk = pk ;
74
88
this .frequencies = frequencies ;
@@ -80,6 +94,21 @@ public int getTermFrequency(ByteBuffer term)
80
94
return frequencies .getOrDefault (term , 0 );
81
95
}
82
96
97
+ public int termCount ()
98
+ {
99
+ return termCount ;
100
+ }
101
+
102
+ public PrimaryKeyWithSortKey primaryKey (IndexContext context , Memtable source , float score )
103
+ {
104
+ return new PrimaryKeyWithScore (context , source , pk , score );
105
+ }
106
+
107
+ public PrimaryKeyWithSortKey primaryKey (IndexContext context , SSTableId <?> source , float score )
108
+ {
109
+ return new PrimaryKeyWithScore (context , source , pk , score );
110
+ }
111
+
83
112
@ Nullable
84
113
public static DocTF createFromDocument (PrimaryKey pk ,
85
114
Cell <?> cell ,
@@ -111,7 +140,7 @@ public static DocTF createFromDocument(PrimaryKey pk,
111
140
if (queryTerms .size () > frequencies .size ())
112
141
return null ;
113
142
114
- return new DocTF (pk , count , frequencies );
143
+ return new EagerDocTF (pk , count , frequencies );
115
144
}
116
145
}
117
146
@@ -121,6 +150,8 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
121
150
IndexContext indexContext ,
122
151
Object source )
123
152
{
153
+ assert source instanceof Memtable || source instanceof SSTableId : "Invalid source " + source .getClass ();
154
+
124
155
// data structures for document stats and frequencies
125
156
ArrayList <DocTF > documents = new ArrayList <>();
126
157
double totalTermCount = 0 ;
@@ -130,18 +161,20 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
130
161
{
131
162
var tf = docIterator .next ();
132
163
documents .add (tf );
133
- totalTermCount += tf .termCount ;
164
+ totalTermCount += tf .termCount () ;
134
165
}
166
+
135
167
if (documents .isEmpty ())
136
168
return CloseableIterator .emptyIterator ();
137
169
138
170
// Calculate average document length
139
171
double avgDocLength = totalTermCount / documents .size ();
140
172
141
- // Calculate BM25 scores
142
- var scoredDocs = new ArrayList < PrimaryKeyWithScore >( documents .size ());
143
- for (var doc : documents )
173
+ // Calculate BM25 scores. Uses a nodequeue that avoids additional allocations and has heap time complexity
174
+ var nodeQueue = new NodeQueue ( new BoundedLongHeap ( documents .size ()), NodeQueue . Order . MAX_HEAP );
175
+ for (int i = 0 ; i < documents . size (); i ++ )
144
176
{
177
+ var doc = documents .get (i );
145
178
double score = 0.0 ;
146
179
for (var queryTerm : queryTerms )
147
180
{
@@ -150,45 +183,55 @@ public static CloseableIterator<PrimaryKeyWithSortKey> computeScores(CloseableIt
150
183
// we shouldn't have more hits for a term than we counted total documents
151
184
assert df <= docStats .docCount : String .format ("df=%d, totalDocs=%d" , df , docStats .docCount );
152
185
153
- double normalizedTf = tf / (tf + K1 * (1 - B + B * doc .termCount / avgDocLength ));
186
+ double normalizedTf = tf / (tf + K1 * (1 - B + B * doc .termCount () / avgDocLength ));
154
187
double idf = Math .log (1 + (docStats .docCount - df + 0.5 ) / (df + 0.5 ));
155
188
double deltaScore = normalizedTf * idf ;
156
189
assert deltaScore >= 0 : String .format ("BM25 score for tf=%d, df=%d, tc=%d, totalDocs=%d is %f" ,
157
- tf , df , doc .termCount , docStats .docCount , deltaScore );
190
+ tf , df , doc .termCount () , docStats .docCount , deltaScore );
158
191
score += deltaScore ;
159
192
}
160
- if (source instanceof Memtable )
161
- scoredDocs .add (new PrimaryKeyWithScore (indexContext , (Memtable ) source , doc .pk , (float ) score ));
162
- else if (source instanceof SSTableId )
163
- scoredDocs .add (new PrimaryKeyWithScore (indexContext , (SSTableId ) source , doc .pk , (float ) score ));
164
- else
165
- throw new IllegalArgumentException ("Invalid source " + source .getClass ());
193
+ nodeQueue .push (i , (float ) score );
166
194
}
167
195
168
- // sort by score (PKWS implements Comparator correctly for us)
169
- Collections . sort ( scoredDocs );
196
+ return new NodeQueueDocTFIterator ( nodeQueue , documents , indexContext , source , docIterator );
197
+ }
170
198
171
- return new CloseableIterator <>()
199
+ private static class NodeQueueDocTFIterator extends AbstractIterator <PrimaryKeyWithSortKey >
200
+ {
201
+ private final NodeQueue nodeQueue ;
202
+ private final List <DocTF > documents ;
203
+ private final IndexContext indexContext ;
204
+ private final Object source ;
205
+ private final CloseableIterator <DocTF > docIterator ;
206
+
207
+ NodeQueueDocTFIterator (NodeQueue nodeQueue , List <DocTF > documents , IndexContext indexContext , Object source , CloseableIterator <DocTF > docIterator )
172
208
{
173
- private final Iterator <PrimaryKeyWithScore > iterator = scoredDocs .iterator ();
209
+ this .nodeQueue = nodeQueue ;
210
+ this .documents = documents ;
211
+ this .indexContext = indexContext ;
212
+ this .source = source ;
213
+ this .docIterator = docIterator ;
214
+ }
174
215
175
- @ Override
176
- public boolean hasNext ()
177
- {
178
- return iterator . hasNext ();
179
- }
216
+ @ Override
217
+ protected PrimaryKeyWithSortKey computeNext ()
218
+ {
219
+ if ( nodeQueue . size () == 0 )
220
+ return endOfData ();
180
221
181
- @ Override
182
- public PrimaryKeyWithSortKey next ()
183
- {
184
- return iterator .next ();
185
- }
222
+ var score = nodeQueue .topScore ();
223
+ var node = nodeQueue .pop ();
224
+ var doc = documents .get (node );
225
+ if (source instanceof Memtable )
226
+ return doc .primaryKey (indexContext , (Memtable ) source , score );
227
+ else
228
+ return doc .primaryKey (indexContext , (SSTableId <?>) source , score );
229
+ }
186
230
187
- @ Override
188
- public void close ()
189
- {
190
- FileUtils .closeQuietly (docIterator );
191
- }
192
- };
231
+ @ Override
232
+ public void close ()
233
+ {
234
+ FileUtils .closeQuietly (docIterator );
235
+ }
193
236
}
194
237
}
0 commit comments