Skip to content

Commit f1a5762

Browse files
committed
Initial refactor to compound retriever
1 parent a6a06d1 commit f1a5762

File tree

1 file changed

+93
-28
lines changed

1 file changed

+93
-28
lines changed

x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/rules/retriever/QueryRuleRetrieverBuilder.java

Lines changed: 93 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,20 @@
99

1010
package org.elasticsearch.xpack.application.rules.retriever;
1111

12+
import org.apache.lucene.search.ScoreDoc;
1213
import org.elasticsearch.common.ParsingException;
1314
import org.elasticsearch.features.NodeFeature;
15+
import org.elasticsearch.index.query.BoolQueryBuilder;
1416
import org.elasticsearch.index.query.QueryBuilder;
1517
import org.elasticsearch.license.LicenseUtils;
18+
import org.elasticsearch.search.builder.PointInTimeBuilder;
1619
import org.elasticsearch.search.builder.SearchSourceBuilder;
20+
import org.elasticsearch.search.fetch.StoredFieldsContext;
21+
import org.elasticsearch.search.rank.RankDoc;
22+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1723
import org.elasticsearch.search.retriever.RetrieverBuilder;
1824
import org.elasticsearch.search.retriever.RetrieverParserContext;
25+
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
1926
import org.elasticsearch.xcontent.ConstructingObjectParser;
2027
import org.elasticsearch.xcontent.ParseField;
2128
import org.elasticsearch.xcontent.XContentBuilder;
@@ -25,23 +32,27 @@
2532
import org.elasticsearch.xpack.core.XPackPlugin;
2633

2734
import java.io.IOException;
35+
import java.util.ArrayList;
2836
import java.util.List;
2937
import java.util.Map;
3038
import java.util.Objects;
3139

40+
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
3241
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
42+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
3343

3444
/**
3545
* A query rule retriever applies query rules defined in one or more rulesets to the underlying retriever.
3646
*/
37-
public final class QueryRuleRetrieverBuilder extends RetrieverBuilder {
47+
public final class QueryRuleRetrieverBuilder extends CompoundRetrieverBuilder<QueryRuleRetrieverBuilder> {
3848

3949
public static final String NAME = "rule";
4050
public static final NodeFeature QUERY_RULE_RETRIEVERS_SUPPORTED = new NodeFeature("query_rule_retriever_supported");
4151

4252
public static final ParseField RULESET_IDS_FIELD = new ParseField("ruleset_ids");
4353
public static final ParseField MATCH_CRITERIA_FIELD = new ParseField("match_criteria");
4454
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
55+
public static final ParseField RANK_WINDOW_SIZE_FIELD = new ParseField("rank_window_size");
4556

4657
@SuppressWarnings("unchecked")
4758
public static final ConstructingObjectParser<QueryRuleRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
@@ -50,14 +61,16 @@ public final class QueryRuleRetrieverBuilder extends RetrieverBuilder {
5061
List<String> rulesetIds = (List<String>) args[0];
5162
Map<String, Object> matchCriteria = (Map<String, Object>) args[1];
5263
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[2];
53-
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, retrieverBuilder);
64+
int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
65+
return new QueryRuleRetrieverBuilder(rulesetIds, matchCriteria, retrieverBuilder, rankWindowSize);
5466
}
5567
);
5668

5769
static {
5870
PARSER.declareStringArray(constructorArg(), RULESET_IDS_FIELD);
5971
PARSER.declareObject(constructorArg(), (p, c) -> p.map(), MATCH_CRITERIA_FIELD);
6072
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> p.namedObject(RetrieverBuilder.class, n, c), RETRIEVER_FIELD);
73+
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
6174
RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
6275
}
6376

@@ -77,36 +90,66 @@ public static QueryRuleRetrieverBuilder fromXContent(XContentParser parser, Retr
7790

7891
private final List<String> rulesetIds;
7992
private final Map<String, Object> matchCriteria;
80-
private final RetrieverBuilder retrieverBuilder;
8193

82-
public QueryRuleRetrieverBuilder(List<String> rulesetIds, Map<String, Object> matchCriteria, RetrieverBuilder retrieverBuilder) {
94+
public QueryRuleRetrieverBuilder(
95+
List<String> rulesetIds,
96+
Map<String, Object> matchCriteria,
97+
RetrieverBuilder retrieverBuilder,
98+
int rankWindowSize
99+
) {
100+
super(List.of(new RetrieverSource(retrieverBuilder, null)), rankWindowSize);
83101
this.rulesetIds = rulesetIds;
84102
this.matchCriteria = matchCriteria;
85-
this.retrieverBuilder = retrieverBuilder;
86103
}
87104

88-
@Override
89-
public String getName() {
90-
return NAME;
105+
public QueryRuleRetrieverBuilder(
106+
List<String> rulesetIds,
107+
Map<String, Object> matchCriteria,
108+
List<RetrieverSource> retrieverSource,
109+
int rankWindowSize,
110+
String retrieverName,
111+
List<QueryBuilder> preFilterQueryBuilders
112+
) {
113+
super(retrieverSource, rankWindowSize);
114+
this.rulesetIds = rulesetIds;
115+
this.matchCriteria = matchCriteria;
116+
this.retrieverName = retrieverName;
117+
this.preFilterQueryBuilders = new ArrayList<>(preFilterQueryBuilders);
91118
}
92119

93120
@Override
94-
public QueryBuilder topDocsQuery() {
95-
assert rankDocs != null : "{rankDocs} should have been materialized at this point";
96-
97-
// TODO is this correct?
98-
return retrieverBuilder.topDocsQuery();
121+
public String getName() {
122+
return NAME;
99123
}
100124

101125
@Override
102-
public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
126+
protected SearchSourceBuilder createSearchSourceBuilder(PointInTimeBuilder pit, RetrieverBuilder retrieverBuilder) {
127+
var sourceBuilder = new SearchSourceBuilder().pointInTimeBuilder(pit)
128+
.trackTotalHits(false)
129+
.storedFields(new StoredFieldsContext(false))
130+
.size(rankWindowSize);
131+
if (preFilterQueryBuilders.isEmpty() == false) {
132+
retrieverBuilder.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
133+
}
134+
retrieverBuilder.extractToSearchSourceBuilder(sourceBuilder, true);
103135

104-
// TODO throw if compoundUsed is true?
136+
QueryBuilder query = sourceBuilder.query();
137+
if (query != null && query instanceof RuleQueryBuilder == false) {
138+
QueryBuilder organicQuery = query;
139+
query = new RuleQueryBuilder(organicQuery, matchCriteria, rulesetIds);
140+
}
105141

106-
QueryBuilder organicQuery = retrieverBuilder.topDocsQuery();
107-
QueryBuilder queryBuilder = new RuleQueryBuilder(organicQuery, matchCriteria, rulesetIds);
142+
// apply the pre-filters
143+
if (preFilterQueryBuilders.size() > 0) {
144+
BoolQueryBuilder newQuery = new BoolQueryBuilder();
145+
if (query != null) {
146+
newQuery.must(query);
147+
}
148+
preFilterQueryBuilders.forEach(newQuery::filter);
149+
sourceBuilder.query(newQuery);
150+
}
108151

109-
searchSourceBuilder.query(queryBuilder);
152+
return sourceBuilder;
110153
}
111154

112155
@Override
@@ -115,24 +158,46 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
115158
builder.startObject(MATCH_CRITERIA_FIELD.getPreferredName());
116159
builder.mapContents(matchCriteria);
117160
builder.endObject();
118-
builder.startObject("retriever");
119-
builder.startObject();
120-
builder.field(retrieverBuilder.getName());
121-
retrieverBuilder.toXContent(builder, params);
122-
builder.endObject();
123-
builder.endObject();
161+
}
162+
163+
@Override
164+
protected QueryRuleRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers) {
165+
return new QueryRuleRetrieverBuilder(
166+
rulesetIds,
167+
matchCriteria,
168+
newChildRetrievers,
169+
rankWindowSize,
170+
retrieverName,
171+
preFilterQueryBuilders
172+
);
173+
}
174+
175+
@Override
176+
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults) {
177+
assert rankResults.size() == 1;
178+
ScoreDoc[] scoreDocs = rankResults.getFirst();
179+
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
180+
for (int i = 0; i < scoreDocs.length; i++) {
181+
ScoreDoc scoreDoc = scoreDocs[i];
182+
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
183+
}
184+
return rankDocs;
185+
}
186+
187+
@Override
188+
public QueryBuilder explainQuery() {
189+
// the original matching set of the QueryRuleRetriever retriever is specified by its nested retriever
190+
return new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { innerRetrievers.getFirst().retriever().explainQuery() }, true);
124191
}
125192

126193
@Override
127194
public boolean doEquals(Object o) {
128195
QueryRuleRetrieverBuilder that = (QueryRuleRetrieverBuilder) o;
129-
return Objects.equals(rulesetIds, that.rulesetIds)
130-
&& Objects.equals(matchCriteria, that.matchCriteria)
131-
&& Objects.equals(retrieverBuilder, that.retrieverBuilder);
196+
return super.doEquals(o) && Objects.equals(rulesetIds, that.rulesetIds) && Objects.equals(matchCriteria, that.matchCriteria);
132197
}
133198

134199
@Override
135200
public int doHashCode() {
136-
return Objects.hash(rulesetIds, matchCriteria, retrieverBuilder);
201+
return Objects.hash(super.doHashCode(), rulesetIds, matchCriteria);
137202
}
138203
}

0 commit comments

Comments
 (0)