Skip to content

Commit f59ff92

Browse files
authored
feat: implement asBulkSimScorer on FeatureFields's SimScorers (#15137)
1 parent 839425e commit f59ff92

File tree

3 files changed

+149
-8
lines changed

3 files changed

+149
-8
lines changed

lucene/CHANGES.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,10 @@ Optimizations
270270

271271
* GITHUB#15045: Use FixedBitSet#cardinality for counting liveDocs in CheckIndex (Zhang Chao)
272272

273+
274+
* GITHUB#15117: Score computations are now more reliably vectorized in FeatureField Scorer's
275+
(Aditya Teltia)
276+
273277
* GITHUB#15116, GITHUB#15138: Rewrite of the GroupVInt optimization without lambdas, varhandles
274278
and no code in subclasses. The optimization now auto-detects if an IndexInput supports random access
275279
and uses an optimized branchless approach. Any subclasses that have implemented the optimized method

lucene/core/src/java/org/apache/lucene/document/FeatureField.java

Lines changed: 80 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.lucene.search.Query;
3535
import org.apache.lucene.search.SortField;
3636
import org.apache.lucene.search.similarities.BM25Similarity;
37+
import org.apache.lucene.search.similarities.Similarity;
3738
import org.apache.lucene.search.similarities.Similarity.SimScorer;
3839

3940
/**
@@ -262,9 +263,27 @@ static final class LinearFunction extends FeatureFunction {
262263
@Override
263264
SimScorer scorer(float w) {
264265
return new SimScorer() {
266+
private float doScore(float f) {
267+
return (w * f);
268+
}
269+
265270
@Override
266271
public float score(float freq, long norm) {
267-
return (w * decodeFeatureValue(freq));
272+
float f = decodeFeatureValue(freq);
273+
return doScore(f);
274+
}
275+
276+
@Override
277+
public Similarity.BulkSimScorer asBulkSimScorer() {
278+
return new Similarity.BulkSimScorer() {
279+
@Override
280+
public void score(int size, float[] freqs, long[] norms, float[] scores) {
281+
for (int i = 0; i < size; ++i) {
282+
float f = decodeFeatureValue(freqs[i]);
283+
scores[i] = doScore(f);
284+
}
285+
}
286+
};
268287
}
269288
};
270289
}
@@ -333,9 +352,27 @@ public String toString() {
333352
@Override
334353
SimScorer scorer(float weight) {
335354
return new SimScorer() {
355+
private float doScore(float f) {
356+
return (float) (weight * Math.log(scalingFactor + f));
357+
}
358+
336359
@Override
337360
public float score(float freq, long norm) {
338-
return (float) (weight * Math.log(scalingFactor + decodeFeatureValue(freq)));
361+
float f = decodeFeatureValue(freq);
362+
return doScore(f);
363+
}
364+
365+
@Override
366+
public Similarity.BulkSimScorer asBulkSimScorer() {
367+
return new Similarity.BulkSimScorer() {
368+
@Override
369+
public void score(int size, float[] freqs, long[] norms, float[] scores) {
370+
for (int i = 0; i < size; ++i) {
371+
float f = decodeFeatureValue(freqs[i]);
372+
scores[i] = doScore(f);
373+
}
374+
}
375+
};
339376
}
340377
};
341378
}
@@ -405,14 +442,32 @@ SimScorer scorer(float weight) {
405442
}
406443
final float pivot = this.pivot; // unbox
407444
return new SimScorer() {
408-
@Override
409-
public float score(float freq, long norm) {
410-
float f = decodeFeatureValue(freq);
445+
446+
private float doScore(float f) {
411447
// should be f / (f + k) but we rewrite it to
412448
// 1 - k / (f + k) to make sure it doesn't decrease
413449
// with f in spite of rounding
414450
return weight * (1 - pivot / (f + pivot));
415451
}
452+
453+
@Override
454+
public float score(float freq, long norm) {
455+
float f = decodeFeatureValue(freq);
456+
return doScore(f);
457+
}
458+
459+
@Override
460+
public Similarity.BulkSimScorer asBulkSimScorer() {
461+
return new Similarity.BulkSimScorer() {
462+
@Override
463+
public void score(int size, float[] freqs, long[] norms, float[] scores) {
464+
for (int i = 0; i < size; ++i) {
465+
float f = decodeFeatureValue(freqs[i]);
466+
scores[i] = doScore(f);
467+
}
468+
}
469+
};
470+
}
416471
};
417472
}
418473

@@ -469,14 +524,31 @@ public String toString() {
469524
@Override
470525
SimScorer scorer(float weight) {
471526
return new SimScorer() {
472-
@Override
473-
public float score(float freq, long norm) {
474-
float f = decodeFeatureValue(freq);
527+
private float doScore(float f) {
475528
// should be f^a / (f^a + k^a) but we rewrite it to
476529
// 1 - k^a / (f + k^a) to make sure it doesn't decrease
477530
// with f in spite of rounding
478531
return (float) (weight * (1 - pivotPa / (Math.pow(f, a) + pivotPa)));
479532
}
533+
534+
@Override
535+
public float score(float freq, long norm) {
536+
float f = decodeFeatureValue(freq);
537+
return doScore(f);
538+
}
539+
540+
@Override
541+
public Similarity.BulkSimScorer asBulkSimScorer() {
542+
return new Similarity.BulkSimScorer() {
543+
@Override
544+
public void score(int size, float[] freqs, long[] norms, float[] scores) {
545+
for (int i = 0; i < size; ++i) {
546+
float f = decodeFeatureValue(freqs[i]);
547+
scores[i] = doScore(f);
548+
}
549+
}
550+
};
551+
}
480552
};
481553
}
482554

lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import java.io.IOException;
2424
import java.util.List;
25+
import java.util.Random;
2526
import org.apache.lucene.document.Field.Store;
2627
import org.apache.lucene.index.DirectoryReader;
2728
import org.apache.lucene.index.LeafReaderContext;
@@ -37,11 +38,14 @@
3738
import org.apache.lucene.search.TopDocs;
3839
import org.apache.lucene.search.Weight;
3940
import org.apache.lucene.search.similarities.BM25Similarity;
41+
import org.apache.lucene.search.similarities.Similarity;
4042
import org.apache.lucene.search.similarities.Similarity.SimScorer;
4143
import org.apache.lucene.store.Directory;
4244
import org.apache.lucene.tests.index.RandomIndexWriter;
4345
import org.apache.lucene.tests.search.QueryUtils;
4446
import org.apache.lucene.tests.util.LuceneTestCase;
47+
import org.apache.lucene.tests.util.TestUtil;
48+
import org.apache.lucene.util.ArrayUtil;
4549
import org.apache.lucene.util.BytesRef;
4650
import org.apache.lucene.util.IOUtils;
4751

@@ -528,4 +532,65 @@ public void testStoreTermVectors() throws Exception {
528532

529533
IOUtils.close(reader, dir);
530534
}
535+
536+
public void testLinearBulkScorer() {
537+
FeatureField.LinearFunction func = new FeatureField.LinearFunction();
538+
SimScorer scorer = func.scorer(2f); // weight = 2
539+
Similarity.BulkSimScorer bulkScorer = scorer.asBulkSimScorer();
540+
doTestBulkScorer(scorer, bulkScorer);
541+
}
542+
543+
public void testLogBulkScorer() {
544+
FeatureField.LogFunction func = new FeatureField.LogFunction(4.5f);
545+
SimScorer scorer = func.scorer(3f); // weight = 3
546+
Similarity.BulkSimScorer bulkScorer = scorer.asBulkSimScorer();
547+
doTestBulkScorer(scorer, bulkScorer);
548+
}
549+
550+
public void testSaturationBulkScorer() {
551+
FeatureField.SaturationFunction func = new FeatureField.SaturationFunction("foo", "bar", 4.5f);
552+
SimScorer scorer = func.scorer(3f);
553+
Similarity.BulkSimScorer bulkScorer = scorer.asBulkSimScorer();
554+
doTestBulkScorer(scorer, bulkScorer);
555+
}
556+
557+
public void testSigmoidBulkScorer() {
558+
FeatureField.SigmoidFunction func = new FeatureField.SigmoidFunction(4.5f, 0.6f);
559+
SimScorer scorer = func.scorer(3f);
560+
Similarity.BulkSimScorer bulkScorer = scorer.asBulkSimScorer();
561+
doTestBulkScorer(scorer, bulkScorer);
562+
}
563+
564+
private void doTestBulkScorer(SimScorer scorer, Similarity.BulkSimScorer bulkScorer) {
565+
Random random = random();
566+
int iters = atLeast(3);
567+
float[] freqs = new float[0];
568+
long[] norms = new long[0];
569+
float[] scores = new float[0];
570+
571+
for (int iter = 0; iter < iters; ++iter) {
572+
int size = TestUtil.nextInt(random, 0, 200);
573+
if (size > freqs.length) {
574+
freqs = new float[ArrayUtil.oversize(size, Float.BYTES)];
575+
norms = new long[freqs.length];
576+
scores = new float[freqs.length];
577+
}
578+
for (int i = 0; i < size; ++i) {
579+
freqs[i] = TestUtil.nextInt(random, 1, 1000); // freq values
580+
norms[i] = TestUtil.nextLong(random, 1, 255); // norms in byte range
581+
}
582+
583+
float[] expected = new float[size];
584+
for (int i = 0; i < size; ++i) {
585+
expected[i] = scorer.score(freqs[i], norms[i]);
586+
}
587+
588+
bulkScorer.score(size, freqs, norms, scores);
589+
590+
assertArrayEquals(
591+
ArrayUtil.copyOfSubArray(expected, 0, size),
592+
ArrayUtil.copyOfSubArray(scores, 0, size),
593+
0f);
594+
}
595+
}
531596
}

0 commit comments

Comments
 (0)