diff --git a/docs/changelog/125103.yaml b/docs/changelog/125103.yaml new file mode 100644 index 0000000000000..da5d025e77869 --- /dev/null +++ b/docs/changelog/125103.yaml @@ -0,0 +1,5 @@ +pr: 125103 +summary: Fix LTR query feature with phrases (and two-phase) queries +area: Ranking +type: bug +issues: [] diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java index 26d5125c94c32..2942a1286c443 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractor.java @@ -15,7 +15,6 @@ import org.apache.lucene.search.Weight; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -25,11 +24,11 @@ * respective feature name. */ public class QueryFeatureExtractor implements FeatureExtractor { - private final List featureNames; private final List weights; - private final List scorers; - private DisjunctionDISIApproximation rankerIterator; + + private final DisiPriorityQueue subScorers; + private DisjunctionDISIApproximation approximation; public QueryFeatureExtractor(List featureNames, List weights) { if (featureNames.size() != weights.size()) { @@ -37,40 +36,40 @@ public QueryFeatureExtractor(List featureNames, List weights) { } this.featureNames = featureNames; this.weights = weights; - this.scorers = new ArrayList<>(weights.size()); + this.subScorers = new DisiPriorityQueue(weights.size()); } @Override public void setNextReader(LeafReaderContext segmentContext) throws IOException { - DisiPriorityQueue disiPriorityQueue = new DisiPriorityQueue(weights.size()); - scorers.clear(); - for (Weight weight : weights) { + subScorers.clear(); + for (int i = 0; i < weights.size(); i++) { + var weight = weights.get(i); if (weight == null) { - scorers.add(null); continue; } Scorer scorer = weight.scorer(segmentContext); if (scorer != null) { - disiPriorityQueue.add(new DisiWrapper(scorer, false)); + subScorers.add(new FeatureDisiWrapper(scorer, featureNames.get(i))); } - scorers.add(scorer); } - - rankerIterator = disiPriorityQueue.size() > 0 ? new DisjunctionDISIApproximation(disiPriorityQueue) : null; + approximation = subScorers.size() > 0 ? new DisjunctionDISIApproximation(subScorers) : null; } @Override public void addFeatures(Map featureMap, int docId) throws IOException { - if (rankerIterator == null) { + if (approximation == null || approximation.docID() > docId) { return; } - - rankerIterator.advance(docId); - for (int i = 0; i < featureNames.size(); i++) { - Scorer scorer = scorers.get(i); - // Do we have a scorer, and does it match the provided document? - if (scorer != null && scorer.docID() == docId) { - featureMap.put(featureNames.get(i), scorer.score()); + if (approximation.docID() < docId) { + approximation.advance(docId); + } + if (approximation.docID() != docId) { + return; + } + var w = (FeatureDisiWrapper) subScorers.topList(); + for (; w != null; w = (FeatureDisiWrapper) w.next) { + if (w.twoPhaseView == null || w.twoPhaseView.matches()) { + featureMap.put(w.featureName, w.scorable.score()); } } } @@ -80,4 +79,12 @@ public List featureNames() { return featureNames; } + private static class FeatureDisiWrapper extends DisiWrapper { + final String featureName; + + FeatureDisiWrapper(Scorer scorer, String featureName) { + super(scorer, false); + this.featureName = featureName; + } + } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java index 3b25a266bf412..fc935ba1ae0e6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java @@ -12,7 +12,7 @@ import org.apache.lucene.document.IntField; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Weight; @@ -43,13 +43,11 @@ public class QueryFeatureExtractorTests extends AbstractBuilderTestCase { - private Directory dir; - private IndexReader reader; - private IndexSearcher searcher; - - private void addDocs(String[] textValues, int[] numberValues) throws IOException { - dir = newDirectory(); - try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) { + private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException { + var config = newIndexWriterConfig(); + // override the merge policy to ensure that docs remain in the same ingestion order + config.setMergePolicy(newLogMergePolicy(random())); + try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) { for (int i = 0; i < textValues.length; i++) { Document doc = new Document(); doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO)); @@ -59,98 +57,119 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException indexWriter.flush(); } } - reader = indexWriter.getReader(); + return indexWriter.getReader(); } - searcher = newSearcher(reader); - searcher.setSimilarity(new ClassicSimilarity()); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/98127") public void testQueryExtractor() throws IOException { - addDocs( - new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" }, - new int[] { 5, 10, 12, 11 } - ); - QueryRewriteContext ctx = createQueryRewriteContext(); - List queryExtractorBuilders = List.of( - new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox"))) - .rewrite(ctx), - new QueryExtractorBuilder( - "number_score", - QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12)) - ).rewrite(ctx), - new QueryExtractorBuilder( - "matching_none", - QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term")) - ).rewrite(ctx), - new QueryExtractorBuilder( - "matching_missing_field", - QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox")) - ).rewrite(ctx) - ); - SearchExecutionContext dummySEC = createSearchExecutionContext(); - List weights = new ArrayList<>(); - List featureNames = new ArrayList<>(); - for (QueryExtractorBuilder qeb : queryExtractorBuilders) { - Query q = qeb.query().getParsedQuery().toQuery(dummySEC); - Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f); - weights.add(weight); - featureNames.add(qeb.featureName()); - } - QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights); - List> extractedFeatures = new ArrayList<>(); - for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) { - int maxDoc = leafReaderContext.reader().maxDoc(); - queryFeatureExtractor.setNextReader(leafReaderContext); - for (int i = 0; i < maxDoc; i++) { - Map featureMap = new HashMap<>(); - queryFeatureExtractor.addFeatures(featureMap, i); - extractedFeatures.add(featureMap); + try (var dir = newDirectory()) { + try ( + var reader = addDocs( + dir, + new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" }, + new int[] { 5, 10, 12, 11 } + ) + ) { + var searcher = newSearcher(reader); + searcher.setSimilarity(new ClassicSimilarity()); + QueryRewriteContext ctx = createQueryRewriteContext(); + List queryExtractorBuilders = List.of( + new QueryExtractorBuilder( + "text_score", + QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")) + ).rewrite(ctx), + new QueryExtractorBuilder( + "number_score", + QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12)) + ).rewrite(ctx), + new QueryExtractorBuilder( + "matching_none", + QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term")) + ).rewrite(ctx), + new QueryExtractorBuilder( + "matching_missing_field", + QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox")) + ).rewrite(ctx), + new QueryExtractorBuilder( + "phrase_score", + QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox")) + ).rewrite(ctx) + ); + SearchExecutionContext dummySEC = createSearchExecutionContext(); + List weights = new ArrayList<>(); + List featureNames = new ArrayList<>(); + for (QueryExtractorBuilder qeb : queryExtractorBuilders) { + Query q = qeb.query().getParsedQuery().toQuery(dummySEC); + Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f); + weights.add(weight); + featureNames.add(qeb.featureName()); + } + QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights); + List> extractedFeatures = new ArrayList<>(); + for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) { + int maxDoc = leafReaderContext.reader().maxDoc(); + queryFeatureExtractor.setNextReader(leafReaderContext); + for (int i = 0; i < maxDoc; i++) { + Map featureMap = new HashMap<>(); + queryFeatureExtractor.addFeatures(featureMap, i); + extractedFeatures.add(featureMap); + } + } + assertThat(extractedFeatures, hasSize(4)); + // Should never add features for queries that don't match a document or on documents where the field is missing + for (Map features : extractedFeatures) { + assertThat(features, not(hasKey("matching_none"))); + assertThat(features, not(hasKey("matching_missing_field"))); + } + // First two only match the text field + assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f)); + assertThat(extractedFeatures.get(0), not(hasKey("number_score"))); + assertThat(extractedFeatures.get(0), not(hasKey("phrase_score"))); + assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f)); + assertThat(extractedFeatures.get(1), not(hasKey("number_score"))); + assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f)); + + // Only matches the range query + assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f)); + assertThat(extractedFeatures.get(2), not(hasKey("text_score"))); + assertThat(extractedFeatures.get(2), not(hasKey("phrase_score"))); + + // No query matches + assertThat(extractedFeatures.get(3), anEmptyMap()); } } - assertThat(extractedFeatures, hasSize(4)); - // Should never add features for queries that don't match a document or on documents where the field is missing - for (Map features : extractedFeatures) { - assertThat(features, not(hasKey("matching_none"))); - assertThat(features, not(hasKey("matching_missing_field"))); - } - // First two only match the text field - assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f)); - assertThat(extractedFeatures.get(0), not(hasKey("number_score"))); - assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f)); - assertThat(extractedFeatures.get(1), not(hasKey("number_score"))); - // Only matches the range query - assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f)); - assertThat(extractedFeatures.get(2), not(hasKey("text_score"))); - // No query matches - assertThat(extractedFeatures.get(3), anEmptyMap()); - reader.close(); - dir.close(); } public void testEmptyDisiPriorityQueue() throws IOException { - addDocs( - new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" }, - new int[] { 5, 10, 12, 11 } - ); + try (var dir = newDirectory()) { + var config = newIndexWriterConfig(); + config.setMergePolicy(NoMergePolicy.INSTANCE); + try ( + var reader = addDocs( + dir, + new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" }, + new int[] { 5, 10, 12, 11 } + ) + ) { - // Scorers returned by weights are null - List featureNames = randomList(1, 10, ESTestCase::randomIdentifier); - List weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList(); + var searcher = newSearcher(reader); + searcher.setSimilarity(new ClassicSimilarity()); - QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights); + // Scorers returned by weights are null + List featureNames = randomList(1, 10, ESTestCase::randomIdentifier); + List weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList(); - for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) { - int maxDoc = leafReaderContext.reader().maxDoc(); - featureExtractor.setNextReader(leafReaderContext); - for (int i = 0; i < maxDoc; i++) { - Map featureMap = new HashMap<>(); - featureExtractor.addFeatures(featureMap, i); - assertThat(featureMap, anEmptyMap()); + QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights); + for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) { + int maxDoc = leafReaderContext.reader().maxDoc(); + featureExtractor.setNextReader(leafReaderContext); + for (int i = 0; i < maxDoc; i++) { + Map featureMap = new HashMap<>(); + featureExtractor.addFeatures(featureMap, i); + assertThat(featureMap, anEmptyMap()); + } + } } } - - reader.close(); - dir.close(); } }