docs/changelog/125103.yaml (5 additions, 0 deletions)

@@ -0,0 +1,5 @@
+pr: 125103
+summary: Fix LTR query feature with phrases (and two-phase) queries
+area: Ranking
+type: bug
+issues: []
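For context on the fix: phrase queries (and other two-phase queries) expose their matches in two phases. The scorer's iterator is only an approximation (for a phrase, the documents containing all of the terms), and the positional check that confirms a real match runs only when TwoPhaseIterator.matches() is called. The learning-to-rank feature extractor previously treated the approximation's docID() as a confirmed match, so a phrase feature could be reported for a document that does not actually contain the phrase. A minimal sketch of the required consumption pattern (illustrative only, not code from this PR; scoreIfMatch is a hypothetical helper):

import java.io.IOException;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;

final class TwoPhaseScoring {
    /** Returns the score for docId, or NaN when the scorer does not match it. */
    static float scoreIfMatch(Scorer scorer, int docId) throws IOException {
        TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
        // Two-phase scorers iterate an approximation of their matches.
        DocIdSetIterator approximation = twoPhase == null ? scorer.iterator() : twoPhase.approximation();
        if (approximation.docID() < docId) {
            approximation.advance(docId);
        }
        if (approximation.docID() != docId) {
            return Float.NaN;
        }
        // Landing on docId is necessary but not sufficient: matches() runs the
        // (potentially costly) second phase, e.g. checking phrase positions.
        return (twoPhase == null || twoPhase.matches()) ? scorer.score() : Float.NaN;
    }
}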
QueryFeatureExtractor.java

@@ -15,7 +15,6 @@
 import org.apache.lucene.search.Weight;
 
 import java.io.IOException;
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
@@ -25,52 +24,52 @@
  * respective feature name.
  */
 public class QueryFeatureExtractor implements FeatureExtractor {
 
     private final List<String> featureNames;
     private final List<Weight> weights;
-    private final List<Scorer> scorers;
-    private DisjunctionDISIApproximation rankerIterator;
+    private final DisiPriorityQueue subScorers;
+    private DisjunctionDISIApproximation approximation;
 
     public QueryFeatureExtractor(List<String> featureNames, List<Weight> weights) {
         if (featureNames.size() != weights.size()) {
             throw new IllegalArgumentException("[featureNames] and [weights] must be the same size.");
         }
         this.featureNames = featureNames;
         this.weights = weights;
-        this.scorers = new ArrayList<>(weights.size());
+        this.subScorers = new DisiPriorityQueue(weights.size());
     }
 
     @Override
     public void setNextReader(LeafReaderContext segmentContext) throws IOException {
-        DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size());
-        scorers.clear();
-        for (Weight weight : weights) {
+        subScorers.clear();
+        for (int i = 0; i < weights.size(); i++) {
+            var weight = weights.get(i);
             if (weight == null) {
-                scorers.add(null);
                 continue;
             }
             Scorer scorer = weight.scorer(segmentContext);
             if (scorer != null) {
-                disiPriorityQueue.add(new DisiWrapper(scorer));
+                subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i)));
             }
-            scorers.add(scorer);
         }
-
-        rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null;
+        approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null;
     }
 
     @Override
     public void addFeatures(Map<String, Object> featureMap, int docId) throws IOException {
-        if (rankerIterator == null) {
+        if (approximation == null || approximation.docID() > docId) {
             return;
         }
-
-        rankerIterator.advance(docId);
-        for (int i = 0; i < featureNames.size(); i++) {
-            Scorer scorer = scorers.get(i);
-            // Do we have a scorer, and does it match the provided document?
-            if (scorer != null && scorer.docID() == docId) {
-                featureMap.put(featureNames.get(i), scorer.score());
+        if (approximation.docID() < docId) {
+            approximation.advance(docId);
+        }
+        if (approximation.docID() != docId) {
+            return;
+        }
+        var w = (FeatureDisiWrapper) subScorers.topList();
+        for (; w != null; w = (FeatureDisiWrapper) w.next) {
+            if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
+                featureMap.put(w.featureName, w.scorer.score());
             }
         }
     }
@@ -80,4 +79,12 @@ public List<String> featureNames() {
         return featureNames;
     }
 
+    private static class FeatureDisiWrapper extends DisiWrapper {
+        final String featureName;
+
+        FeatureDisiWrapper(Scorer scorer, String featureName) {
+            super(scorer);
+            this.featureName = featureName;
+        }
+    }
 }
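A note on the Lucene plumbing above, since the diff is dense: DisiPriorityQueue orders sub-iterators by their current document, DisjunctionDISIApproximation drives them as a single disjunction, and once the disjunction is positioned on a document, topList() returns a linked list (chained through DisiWrapper.next) of exactly the sub-iterators sitting on that document. That is what lets FeatureDisiWrapper carry its feature name to the scoring loop. A self-contained sketch of the traversal pattern (illustrative; DisjunctionWalk is hypothetical, and it assumes the single-argument DisiWrapper constructor used in this PR's Lucene version):

import java.io.IOException;
import java.util.List;

import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Scorer;

final class DisjunctionWalk {
    // Visit every doc matched by at least one scorer and print per-scorer scores.
    static void walk(List<Scorer> scorers) throws IOException {
        DisiPriorityQueue queue = new DisiPriorityQueue(scorers.size());
        for (Scorer scorer : scorers) {
            queue.add(new DisiWrapper(scorer));
        }
        DocIdSetIterator disjunction = new DisjunctionDISIApproximation(queue);
        for (int doc = disjunction.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = disjunction.nextDoc()) {
            // topList() links together exactly the sub-iterators positioned on doc.
            for (DisiWrapper w = queue.topList(); w != null; w = w.next) {
                // Without this matches() check, a two-phase scorer may not truly match doc.
                if (w.twoPhaseView == null || w.twoPhaseView.matches()) {
                    System.out.println("doc " + doc + " -> " + w.scorer.score());
                }
            }
        }
    }
}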
QueryFeatureExtractorTests.java

@@ -12,7 +12,7 @@
 import org.apache.lucene.document.IntField;
 import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Weight;
@@ -43,13 +43,11 @@
 
 public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
 
-    private Directory dir;
-    private IndexReader reader;
-    private IndexSearcher searcher;
-
-    private void addDocs(String[] textValues, int[] numberValues) throws IOException {
-        dir = newDirectory();
-        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
+    private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException {
+        var config = newIndexWriterConfig();
+        // override the merge policy to ensure that docs remain in the same ingestion order
+        config.setMergePolicy(newLogMergePolicy(random()));
+        try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) {
             for (int i = 0; i < textValues.length; i++) {
                 Document doc = new Document();
                 doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
@@ -59,98 +57,119 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException
                     indexWriter.flush();
                 }
             }
-            reader = indexWriter.getReader();
+            return indexWriter.getReader();
         }
-        searcher = newSearcher(reader);
-        searcher.setSimilarity(new ClassicSimilarity());
     }
 
-    @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127")
     public void testQueryExtractor() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
-        QueryRewriteContext ctx = createQueryRewriteContext();
-        List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
-            new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
-                .rewrite(ctx),
-            new QueryExtractorBuilder(
-                "number_score",
-                QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_none",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
-            ).rewrite(ctx),
-            new QueryExtractorBuilder(
-                "matching_missing_field",
-                QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
-            ).rewrite(ctx)
-        );
-        SearchExecutionContext dummySEC = createSearchExecutionContext();
-        List<Weight> weights = new ArrayList<>();
-        List<String> featureNames = new ArrayList<>();
-        for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
-            Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
-            Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
-            weights.add(weight);
-            featureNames.add(qeb.featureName());
-        }
-        QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
-        List<Map<String, Object>> extractedFeatures = new ArrayList<>();
-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            queryFeatureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                queryFeatureExtractor.addFeatures(featureMap, i);
-                extractedFeatures.add(featureMap);
-            }
-        }
-        assertThat(extractedFeatures, hasSize(4));
-        // Should never add features for queries that don't match a document or on documents where the field is missing
-        for (Map<String, Object> features : extractedFeatures) {
-            assertThat(features, not(hasKey("matching_none")));
-            assertThat(features, not(hasKey("matching_missing_field")));
-        }
-        // First two only match the text field
-        assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
-        assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
-        assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
-        assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
-        // Only matches the range query
-        assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
-        assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
-        // No query matches
-        assertThat(extractedFeatures.get(3), anEmptyMap());
-        reader.close();
-        dir.close();
+        try (var dir = newDirectory()) {
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());
+                QueryRewriteContext ctx = createQueryRewriteContext();
+                List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
+                    new QueryExtractorBuilder(
+                        "text_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "number_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_none",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "matching_missing_field",
+                        QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
+                    ).rewrite(ctx),
+                    new QueryExtractorBuilder(
+                        "phrase_score",
+                        QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
+                    ).rewrite(ctx)
+                );
+                SearchExecutionContext dummySEC = createSearchExecutionContext();
+                List<Weight> weights = new ArrayList<>();
+                List<String> featureNames = new ArrayList<>();
+                for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
+                    Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
+                    Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
+                    weights.add(weight);
+                    featureNames.add(qeb.featureName());
+                }
+                QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                List<Map<String, Object>> extractedFeatures = new ArrayList<>();
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    queryFeatureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        queryFeatureExtractor.addFeatures(featureMap, i);
+                        extractedFeatures.add(featureMap);
+                    }
+                }
+                assertThat(extractedFeatures, hasSize(4));
+                // Should never add features for queries that don't match a document or on documents where the field is missing
+                for (Map<String, Object> features : extractedFeatures) {
+                    assertThat(features, not(hasKey("matching_none")));
+                    assertThat(features, not(hasKey("matching_missing_field")));
+                }
+                // First two only match the text field
+                assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
+                assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
+                assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
+                assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
+
+                // Only matches the range query
+                assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
+                assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
+                assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
+
+                // No query matches
+                assertThat(extractedFeatures.get(3), anEmptyMap());
+            }
+        }
     }

     public void testEmptyDisiPriorityQueue() throws IOException {
-        addDocs(
-            new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
-            new int[] { 5, 10, 12, 11 }
-        );
-
-        // Scorers returned by weights are null
-        List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
-        List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
-
-        QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
-
-        for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
-            int maxDoc = leafReaderContext.reader().maxDoc();
-            featureExtractor.setNextReader(leafReaderContext);
-            for (int i = 0; i < maxDoc; i++) {
-                Map<String, Object> featureMap = new HashMap<>();
-                featureExtractor.addFeatures(featureMap, i);
-                assertThat(featureMap, anEmptyMap());
-            }
-        }
-
-        reader.close();
-        dir.close();
+        try (var dir = newDirectory()) {
+            var config = newIndexWriterConfig();
+            config.setMergePolicy(NoMergePolicy.INSTANCE);
+            try (
+                var reader = addDocs(
+                    dir,
+                    new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
+                    new int[] { 5, 10, 12, 11 }
+                )
+            ) {
+                var searcher = newSearcher(reader);
+                searcher.setSimilarity(new ClassicSimilarity());
+
+                // Scorers returned by weights are null
+                List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
+                List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
+
+                QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
+                for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
+                    int maxDoc = leafReaderContext.reader().maxDoc();
+                    featureExtractor.setNextReader(leafReaderContext);
+                    for (int i = 0; i < maxDoc; i++) {
+                        Map<String, Object> featureMap = new HashMap<>();
+                        featureExtractor.addFeatures(featureMap, i);
+                        assertThat(featureMap, anEmptyMap());
+                    }
+                }
+            }
+        }
     }
 }
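One design consequence worth spelling out: all sub-scorers now share a single forward-only disjunction, so addFeatures must be called with ascending doc ids within a segment; the new approximation.docID() > docId guard returns early rather than rewinding, because Lucene iterators cannot move backwards. A hypothetical driver (mirroring the test loops above, not code from this PR) showing the expected call pattern:

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;

final class FeatureExtractionDriver {
    // Assumes QueryFeatureExtractor from the diff above is on the classpath.
    static void extractAll(IndexSearcher searcher, QueryFeatureExtractor extractor) throws IOException {
        for (LeafReaderContext leaf : searcher.getLeafContexts()) {
            extractor.setNextReader(leaf);
            // Visit docs in ascending order; asking about an earlier doc would silently yield no features.
            for (int doc = 0; doc < leaf.reader().maxDoc(); doc++) {
                Map<String, Object> features = new HashMap<>();
                extractor.addFeatures(features, doc);
                // ... hand the populated feature map to the ranking model
            }
        }
    }
}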