Skip to content

Commit d78ef8a

Browse files
committed
Fix LTR query feature extraction for phrase (and other two-phase) queries
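Query feature extraction should verify that a doc is confirmed by the scorer's two-phase iterator (not merely matched by its approximation) before recording a score for that feature.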
1 parent df7c20a commit d78ef8a

File tree

2 files changed

+33
-22
lines changed

2 files changed

+33
-22
lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.apache.lucene.search.Weight;
1616

1717
import java.io.IOException;
18-
import java.util.ArrayList;
1918
import java.util.List;
2019
import java.util.Map;
2120

@@ -25,52 +24,52 @@
2524
* respective feature name.
2625
*/
2726
public class QueryFeatureExtractor implements FeatureExtractor {
28-
2927
private final List<String> featureNames;
3028
private final List<Weight> weights;
31-
private final List<Scorer> scorers;
32-
private DisjunctionDISIApproximation rankerIterator;
29+
30+
private final DisiPriorityQueue subScorers;
31+
private DisjunctionDISIApproximation approximation;
3332

3433
public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
3534
if (featureNames.size() != weights.size()) {
3635
throw new IllegalArgumentException("[featureNames] and [weights] must be the same size.");
3736
}
3837
this.featureNames = featureNames;
3938
this.weights = weights;
40-
this.scorers = new ArrayList<>(weights.size());
39+
this.subScorers = new DisiPriorityQueue(weights.size());
4140
}
4241

4342
@Override
4443
public void setNextReader(LeafReaderContext segmentContext) throws IOException {
45-
DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
46-
scorers.clear();
47-
for (Weight weight : weights) {
44+
subScorers.clear();
45+
for (int i = 0; i < weights.size(); i++) {
46+
var weight = weights.get(i);
4847
if (weight == null) {
49-
scorers.add(null);
5048
continue;
5149
}
5250
Scorer scorer = weight.scorer(segmentContext);
5351
if (scorer != null) {
54-
disiPriorityQueue.add(new DisiWrapper(scorer, false));
52+
subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
5553
}
56-
scorers.add(scorer);
5754
}
58-
59-
rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null;
55+
approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
6056
}
6157

6258
@Override
6359
public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
64-
if (rankerIterator == null) {
60+
if (approximation == null || approximation.docID() > docId) {
6561
return;
6662
}
67-
68-
rankerIterator.advance(docId);
69-
for (int i = 0; i < featureNames.size(); i++) {
70-
Scorer scorer = scorers.get(i);
71-
// Do we have a scorer, and does it match the provided document?
72-
if (scorer != null && scorer.docID() == docId) {
73-
featureMap.put(featureNames.get(i), scorer.score());
63+
if (approximation.docID() < docId) {
64+
approximation.advance(docId);
65+
}
66+
if (approximation.docID() != docId) {
67+
return;
68+
}
69+
var w = (FeatureDisiWrapper) subScorers.topList();
70+
for (; w != null; w = (FeatureDisiWrapper) w.next) {
71+
if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
72+
featureMap.put(w.featureName, w.scorable.score());
7473
}
7574
}
7675
}
@@ -80,4 +79,11 @@ public List<String> featureNames() {
8079
return featureNames;
8180
}
8281

82+
private static class FeatureDisiWrapper extends DisiWrapper {
83+
final String featureName;
84+
FeatureDisiWrapper(Scorer scorer, String featureName) {
85+
super(scorer, false);
86+
this.featureName = featureName;
87+
}
88+
}
8389
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException
6565
searcher.setSimilarity(new ClassicSimilarity());
6666
}
6767

68-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127")
6968
public void testQueryExtractor() throws IOException {
7069
addDocs(
7170
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
@@ -86,6 +85,8 @@ public void testQueryExtractor() throws IOException {
8685
new QueryExtractorBuilder(
8786
"matching_missing_field",
8887
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
88+
).rewrite(ctx),
89+
new QueryExtractorBuilder("phrase_score", QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
8990
).rewrite(ctx)
9091
);
9192
SearchExecutionContext dummySEC = createSearchExecutionContext();
@@ -117,11 +118,15 @@ public void testQueryExtractor() throws IOException {
117118
// First two only match the text field
118119
assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
119120
assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
121+
assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
120122
assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
121123
assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
124+
assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
125+
122126
// Only matches the range query
123127
assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
124128
assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
129+
assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
125130
// No query matches
126131
assertThat(extractedFeatures.get(3), anEmptyMap());
127132
reader.close();

0 commit comments

Comments
 (0)