Skip to content

Commit 05efc92

Browse files
committed
solidify test
1 parent d78ef8a commit 05efc92

File tree

1 file changed

+98
-92
lines changed

1 file changed

+98
-92
lines changed

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/QueryFeatureExtractorTests.java

Lines changed: 98 additions & 92 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
import org.apache.lucene.document.IntField;
1313
import org.apache.lucene.index.IndexReader;
1414
import org.apache.lucene.index.LeafReaderContext;
15-
import org.apache.lucene.search.IndexSearcher;
15+
import org.apache.lucene.index.NoMergePolicy;
1616
import org.apache.lucene.search.Query;
1717
import org.apache.lucene.search.ScoreMode;
1818
import org.apache.lucene.search.Weight;
@@ -43,13 +43,11 @@
4343

4444
public class QueryFeatureExtractorTests extends AbstractBuilderTestCase {
4545

46-
private Directory dir;
47-
private IndexReader reader;
48-
private IndexSearcher searcher;
49-
50-
private void addDocs(String[] textValues, int[] numberValues) throws IOException {
51-
dir = newDirectory();
52-
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir)) {
46+
private IndexReader addDocs(Directory dir, String[] textValues, int[] numberValues) throws IOException {
47+
var config = newIndexWriterConfig();
48+
// override the merge policy to ensure that docs remain in the same ingestion order
49+
config.setMergePolicy(newLogMergePolicy(random()));
50+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), dir, config)) {
5351
for (int i = 0; i < textValues.length; i++) {
5452
Document doc = new Document();
5553
doc.add(newTextField(TEXT_FIELD_NAME, textValues[i], Field.Store.NO));
@@ -59,103 +57,111 @@ private void addDocs(String[] textValues, int[] numberValues) throws IOException
5957
indexWriter.flush();
6058
}
6159
}
62-
reader = indexWriter.getReader();
60+
return indexWriter.getReader();
6361
}
64-
searcher = newSearcher(reader);
65-
searcher.setSimilarity(new ClassicSimilarity());
6662
}
6763

6864
public void testQueryExtractor() throws IOException {
69-
addDocs(
70-
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
71-
new int[] { 5, 10, 12, 11 }
72-
);
73-
QueryRewriteContext ctx = createQueryRewriteContext();
74-
List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
75-
new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
76-
.rewrite(ctx),
77-
new QueryExtractorBuilder(
78-
"number_score",
79-
QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
80-
).rewrite(ctx),
81-
new QueryExtractorBuilder(
82-
"matching_none",
83-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
84-
).rewrite(ctx),
85-
new QueryExtractorBuilder(
86-
"matching_missing_field",
87-
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
88-
).rewrite(ctx),
89-
new QueryExtractorBuilder("phrase_score", QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
90-
).rewrite(ctx)
91-
);
92-
SearchExecutionContext dummySEC = createSearchExecutionContext();
93-
List<Weight> weights = new ArrayList<>();
94-
List<String> featureNames = new ArrayList<>();
95-
for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
96-
Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
97-
Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
98-
weights.add(weight);
99-
featureNames.add(qeb.featureName());
100-
}
101-
QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
102-
List<Map<String, Object>> extractedFeatures = new ArrayList<>();
103-
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
104-
int maxDoc = leafReaderContext.reader().maxDoc();
105-
queryFeatureExtractor.setNextReader(leafReaderContext);
106-
for (int i = 0; i < maxDoc; i++) {
107-
Map<String, Object> featureMap = new HashMap<>();
108-
queryFeatureExtractor.addFeatures(featureMap, i);
109-
extractedFeatures.add(featureMap);
65+
try (var dir = newDirectory()) {
66+
try (var reader = addDocs(
67+
dir,
68+
new String[]{"the quick brown fox", "the slow brown fox", "the grey dog", "yet another string"},
69+
new int[]{5, 10, 12, 11}
70+
)) {
71+
var searcher = newSearcher(reader);
72+
searcher.setSimilarity(new ClassicSimilarity());
73+
QueryRewriteContext ctx = createQueryRewriteContext();
74+
List<QueryExtractorBuilder> queryExtractorBuilders = List.of(
75+
new QueryExtractorBuilder("text_score", QueryProvider.fromParsedQuery(QueryBuilders.matchQuery(TEXT_FIELD_NAME, "quick fox")))
76+
.rewrite(ctx),
77+
new QueryExtractorBuilder(
78+
"number_score",
79+
QueryProvider.fromParsedQuery(QueryBuilders.rangeQuery(INT_FIELD_NAME).from(12).to(12))
80+
).rewrite(ctx),
81+
new QueryExtractorBuilder(
82+
"matching_none",
83+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery(TEXT_FIELD_NAME, "never found term"))
84+
).rewrite(ctx),
85+
new QueryExtractorBuilder(
86+
"matching_missing_field",
87+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("missing_text", "quick fox"))
88+
).rewrite(ctx),
89+
new QueryExtractorBuilder("phrase_score", QueryProvider.fromParsedQuery(QueryBuilders.matchPhraseQuery(TEXT_FIELD_NAME, "slow brown fox"))
90+
).rewrite(ctx)
91+
);
92+
SearchExecutionContext dummySEC = createSearchExecutionContext();
93+
List<Weight> weights = new ArrayList<>();
94+
List<String> featureNames = new ArrayList<>();
95+
for (QueryExtractorBuilder qeb : queryExtractorBuilders) {
96+
Query q = qeb.query().getParsedQuery().toQuery(dummySEC);
97+
Weight weight = searcher.rewrite(q).createWeight(searcher, ScoreMode.COMPLETE, 1f);
98+
weights.add(weight);
99+
featureNames.add(qeb.featureName());
100+
}
101+
QueryFeatureExtractor queryFeatureExtractor = new QueryFeatureExtractor(featureNames, weights);
102+
List<Map<String, Object>> extractedFeatures = new ArrayList<>();
103+
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
104+
int maxDoc = leafReaderContext.reader().maxDoc();
105+
queryFeatureExtractor.setNextReader(leafReaderContext);
106+
for (int i = 0; i < maxDoc; i++) {
107+
Map<String, Object> featureMap = new HashMap<>();
108+
queryFeatureExtractor.addFeatures(featureMap, i);
109+
extractedFeatures.add(featureMap);
110+
}
111+
}
112+
assertThat(extractedFeatures, hasSize(4));
113+
// Should never add features for queries that don't match a document or on documents where the field is missing
114+
for (Map<String, Object> features : extractedFeatures) {
115+
assertThat(features, not(hasKey("matching_none")));
116+
assertThat(features, not(hasKey("matching_missing_field")));
117+
}
118+
// First two only match the text field
119+
assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
120+
assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
121+
assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
122+
assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
123+
assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
124+
assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
125+
126+
// Only matches the range query
127+
assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
128+
assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
129+
assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
130+
131+
// No query matches
132+
assertThat(extractedFeatures.get(3), anEmptyMap());
110133
}
111134
}
112-
assertThat(extractedFeatures, hasSize(4));
113-
// Should never add features for queries that don't match a document or on documents where the field is missing
114-
for (Map<String, Object> features : extractedFeatures) {
115-
assertThat(features, not(hasKey("matching_none")));
116-
assertThat(features, not(hasKey("matching_missing_field")));
117-
}
118-
// First two only match the text field
119-
assertThat(extractedFeatures.get(0), hasEntry("text_score", 1.7135582f));
120-
assertThat(extractedFeatures.get(0), not(hasKey("number_score")));
121-
assertThat(extractedFeatures.get(0), not(hasKey("phrase_score")));
122-
assertThat(extractedFeatures.get(1), hasEntry("text_score", 0.7554128f));
123-
assertThat(extractedFeatures.get(1), not(hasKey("number_score")));
124-
assertThat(extractedFeatures.get(1), hasEntry("phrase_score", 2.468971f));
125-
126-
// Only matches the range query
127-
assertThat(extractedFeatures.get(2), hasEntry("number_score", 1f));
128-
assertThat(extractedFeatures.get(2), not(hasKey("text_score")));
129-
assertThat(extractedFeatures.get(2), not(hasKey("phrase_score")));
130-
// No query matches
131-
assertThat(extractedFeatures.get(3), anEmptyMap());
132-
reader.close();
133-
dir.close();
134135
}
135136

136137
public void testEmptyDisiPriorityQueue() throws IOException {
137-
addDocs(
138-
new String[] { "the quick brown fox", "the slow brown fox", "the grey dog", "yet another string" },
139-
new int[] { 5, 10, 12, 11 }
140-
);
138+
try (var dir = newDirectory()) {
139+
var config = newIndexWriterConfig();
140+
config.setMergePolicy(NoMergePolicy.INSTANCE);
141+
try (var reader = addDocs(
142+
dir,
143+
new String[]{"the quick brown fox", "the slow brown fox", "the grey dog", "yet another string"},
144+
new int[]{5, 10, 12, 11}
145+
)) {
141146

142-
// Scorers returned by weights are null
143-
List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
144-
List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
147+
var searcher = newSearcher(reader);
148+
searcher.setSimilarity(new ClassicSimilarity());
145149

146-
QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
150+
// Scorers returned by weights are null
151+
List<String> featureNames = randomList(1, 10, ESTestCase::randomIdentifier);
152+
List<Weight> weights = Stream.generate(() -> mock(Weight.class)).limit(featureNames.size()).toList();
147153

148-
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
149-
int maxDoc = leafReaderContext.reader().maxDoc();
150-
featureExtractor.setNextReader(leafReaderContext);
151-
for (int i = 0; i < maxDoc; i++) {
152-
Map<String, Object> featureMap = new HashMap<>();
153-
featureExtractor.addFeatures(featureMap, i);
154-
assertThat(featureMap, anEmptyMap());
154+
QueryFeatureExtractor featureExtractor = new QueryFeatureExtractor(featureNames, weights);
155+
for (LeafReaderContext leafReaderContext : searcher.getLeafContexts()) {
156+
int maxDoc = leafReaderContext.reader().maxDoc();
157+
featureExtractor.setNextReader(leafReaderContext);
158+
for (int i = 0; i < maxDoc; i++) {
159+
Map<String, Object> featureMap = new HashMap<>();
160+
featureExtractor.addFeatures(featureMap, i);
161+
assertThat(featureMap, anEmptyMap());
162+
}
163+
}
155164
}
156165
}
157-
158-
reader.close();
159-
dir.close();
160166
}
161167
}

0 commit comments

Comments
 (0)