Skip to content

Commit 84b6ed7

Browse files
committed
Created PinnedQueryRetriever Builder
1 parent 0b09506 commit 84b6ed7

File tree

1 file changed

+221
-0
lines changed

1 file changed

+221
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.searchbusinessrules.retriever;
9+
10+
import org.apache.lucene.search.ScoreDoc;
11+
import org.elasticsearch.common.ParsingException;
12+
import org.elasticsearch.index.query.QueryBuilder;
13+
import org.elasticsearch.search.builder.SearchSourceBuilder;
14+
import org.elasticsearch.search.rank.RankDoc;
15+
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
16+
import org.elasticsearch.search.retriever.RetrieverBuilder;
17+
import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
18+
import org.elasticsearch.search.retriever.RetrieverParserContext;
19+
import org.elasticsearch.xcontent.ConstructingObjectParser;
20+
import org.elasticsearch.xcontent.ParseField;
21+
import org.elasticsearch.xcontent.XContentBuilder;
22+
import org.elasticsearch.xcontent.XContentParser;
23+
import org.elasticsearch.xpack.searchbusinessrules.PinnedQueryBuilder;
24+
import org.elasticsearch.xpack.searchbusinessrules.SpecifiedDocument;
25+
26+
import java.io.IOException;
27+
import java.util.ArrayList;
28+
import java.util.List;
29+
import java.util.Objects;
30+
31+
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
32+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
33+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
34+
35+
/**
36+
* A pinned retriever applies pinned documents to the underlying retriever.
37+
* This retriever will rewrite to a PinnedQueryBuilder.
38+
*/
39+
public final class PinnedRetrieverBuilder extends CompoundRetrieverBuilder<PinnedRetrieverBuilder> {
40+
41+
public static final String NAME = "pinned";
42+
43+
public static final ParseField IDS_FIELD = new ParseField("ids");
44+
public static final ParseField DOCS_FIELD = new ParseField("docs");
45+
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
46+
47+
@SuppressWarnings("unchecked")
48+
public static final ConstructingObjectParser<PinnedRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
49+
NAME,
50+
args -> {
51+
List<String> ids = (List<String>) args[0];
52+
List<SpecifiedDocument> docs = (List<SpecifiedDocument>) args[1];
53+
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[2];
54+
int rankWindowSize = args[3] == null ? DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
55+
return new PinnedRetrieverBuilder(ids, docs, retrieverBuilder, rankWindowSize);
56+
}
57+
);
58+
59+
static {
60+
PARSER.declareStringArray(optionalConstructorArg(), IDS_FIELD);
61+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
62+
String index = p.textOrNull();
63+
String id = p.text();
64+
return new SpecifiedDocument(index, id);
65+
}, DOCS_FIELD);
66+
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
67+
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
68+
c.trackRetrieverUsage(innerRetriever.getName());
69+
return innerRetriever;
70+
}, RETRIEVER_FIELD);
71+
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
72+
RetrieverBuilder.declareBaseParserFields(PARSER);
73+
}
74+
75+
public static PinnedRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
76+
try {
77+
return PARSER.apply(parser, context);
78+
} catch (Exception e) {
79+
throw new ParsingException(parser.getTokenLocation(), e.getMessage(), e);
80+
}
81+
}
82+
83+
private final List<String> ids;
84+
private final List<SpecifiedDocument> docs;
85+
86+
public PinnedRetrieverBuilder(
87+
List<String> ids,
88+
List<SpecifiedDocument> docs,
89+
RetrieverBuilder retrieverBuilder,
90+
int rankWindowSize
91+
) {
92+
super(new ArrayList<>(), rankWindowSize);
93+
this.ids = ids != null ? ids : new ArrayList<>();
94+
this.docs = docs != null ? docs : new ArrayList<>();
95+
addChild(new PinnedRetrieverBuilderWrapper(retrieverBuilder));
96+
}
97+
98+
public PinnedRetrieverBuilder(
99+
List<String> ids,
100+
List<SpecifiedDocument> docs,
101+
List<RetrieverSource> retrieverSource,
102+
int rankWindowSize,
103+
String retrieverName,
104+
List<QueryBuilder> preFilterQueryBuilders
105+
) {
106+
super(retrieverSource, rankWindowSize);
107+
this.ids = ids;
108+
this.docs = docs;
109+
this.retrieverName = retrieverName;
110+
this.preFilterQueryBuilders = preFilterQueryBuilders;
111+
}
112+
113+
@Override
114+
public String getName() {
115+
return NAME;
116+
}
117+
118+
public int rankWindowSize() {
119+
return rankWindowSize;
120+
}
121+
122+
/**
123+
* Creates a PinnedQueryBuilder with the appropriate pinned documents.
124+
* Prioritizes docs over ids if both are present.
125+
*
126+
* @param baseQuery the base query to pin documents to
127+
* @return a PinnedQueryBuilder or the original query if no pinned documents
128+
*/
129+
private QueryBuilder createPinnedQuery(QueryBuilder baseQuery) {
130+
if (!docs.isEmpty()) {
131+
return new PinnedQueryBuilder(baseQuery, docs.toArray(new SpecifiedDocument[0]));
132+
} else if (!ids.isEmpty()) {
133+
return new PinnedQueryBuilder(baseQuery, ids.toArray(new String[0]));
134+
} else {
135+
return baseQuery;
136+
}
137+
}
138+
139+
@Override
140+
protected SearchSourceBuilder finalizeSourceBuilder(SearchSourceBuilder source) {
141+
source.query(createPinnedQuery(source.query()));
142+
return source;
143+
}
144+
145+
@Override
146+
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
147+
if (ids != null && !ids.isEmpty()) {
148+
builder.array(IDS_FIELD.getPreferredName(), ids.toArray());
149+
}
150+
if (docs != null && !docs.isEmpty()) {
151+
builder.startArray(DOCS_FIELD.getPreferredName());
152+
for (SpecifiedDocument doc : docs) {
153+
builder.value(doc);
154+
}
155+
builder.endArray();
156+
}
157+
builder.field(RETRIEVER_FIELD.getPreferredName(), innerRetrievers.getFirst().retriever());
158+
builder.field(RANK_WINDOW_SIZE_FIELD.getPreferredName(), rankWindowSize);
159+
}
160+
161+
@Override
162+
protected PinnedRetrieverBuilder clone(List<RetrieverSource> newChildRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
163+
return new PinnedRetrieverBuilder(
164+
ids,
165+
docs,
166+
newChildRetrievers,
167+
rankWindowSize,
168+
retrieverName,
169+
newPreFilterQueryBuilders
170+
);
171+
}
172+
173+
@Override
174+
protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, boolean explain) {
175+
assert rankResults.size() == 1;
176+
ScoreDoc[] scoreDocs = rankResults.getFirst();
177+
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
178+
for (int i = 0; i < scoreDocs.length; i++) {
179+
ScoreDoc scoreDoc = scoreDocs[i];
180+
rankDocs[i] = new RankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex);
181+
rankDocs[i].rank = i + 1;
182+
}
183+
return rankDocs;
184+
}
185+
186+
@Override
187+
public boolean doEquals(Object o) {
188+
PinnedRetrieverBuilder that = (PinnedRetrieverBuilder) o;
189+
return super.doEquals(o) && Objects.equals(ids, that.ids) && Objects.equals(docs, that.docs);
190+
}
191+
192+
@Override
193+
public int doHashCode() {
194+
return Objects.hash(super.doHashCode(), ids, docs);
195+
}
196+
197+
/**
198+
* We need to wrap the PinnedRetrieverBuilder in order to ensure that the top docs query that is generated
199+
* by this retriever correctly generates and executes a Pinned query.
200+
*/
201+
class PinnedRetrieverBuilderWrapper extends RetrieverBuilderWrapper<PinnedRetrieverBuilderWrapper> {
202+
protected PinnedRetrieverBuilderWrapper(RetrieverBuilder in) {
203+
super(in);
204+
}
205+
206+
@Override
207+
protected PinnedRetrieverBuilderWrapper clone(RetrieverBuilder in) {
208+
return new PinnedRetrieverBuilderWrapper(in);
209+
}
210+
211+
@Override
212+
public QueryBuilder topDocsQuery() {
213+
return createPinnedQuery(in.topDocsQuery());
214+
}
215+
216+
@Override
217+
public QueryBuilder explainQuery() {
218+
return createPinnedQuery(in.explainQuery());
219+
}
220+
}
221+
}

0 commit comments

Comments
 (0)