Skip to content

Commit c84028e

Browse files
committed
Included PinnedRankDoc to enhance the explain query, validations for inputs anfd sorting. Fixed the tests after these changes
1 parent a75926c commit c84028e

File tree

3 files changed

+68
-84
lines changed

3 files changed

+68
-84
lines changed

x-pack/plugin/search-business-rules/src/main/java/org/elasticsearch/xpack/searchbusinessrules/retriever/PinnedRankDoc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ public String getPinnedBy() {
2424
public String toString() {
2525
return super.toString() + ", isPinned=" + isPinned + ", pinnedBy=" + pinnedBy;
2626
}
27-
}
27+
}

x-pack/plugin/search-business-rules/src/main/java/org/elasticsearch/xpack/searchbusinessrules/retriever/PinnedRetrieverBuilder.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@
99

1010
import org.apache.lucene.search.ScoreDoc;
1111
import org.elasticsearch.common.ParsingException;
12+
import org.elasticsearch.index.query.MatchAllQueryBuilder;
1213
import org.elasticsearch.index.query.QueryBuilder;
1314
import org.elasticsearch.search.builder.SearchSourceBuilder;
1415
import org.elasticsearch.search.rank.RankDoc;
1516
import org.elasticsearch.search.retriever.CompoundRetrieverBuilder;
1617
import org.elasticsearch.search.retriever.RetrieverBuilder;
1718
import org.elasticsearch.search.retriever.RetrieverBuilderWrapper;
1819
import org.elasticsearch.search.retriever.RetrieverParserContext;
20+
import org.elasticsearch.search.sort.ScoreSortBuilder;
21+
import org.elasticsearch.search.sort.SortBuilder;
1922
import org.elasticsearch.xcontent.ConstructingObjectParser;
2023
import org.elasticsearch.xcontent.ParseField;
2124
import org.elasticsearch.xcontent.XContentBuilder;
@@ -80,13 +83,14 @@ public static PinnedRetrieverBuilder fromXContent(XContentParser parser, Retriev
8083
private final List<SpecifiedDocument> docs;
8184

8285
private void validateIdsAndDocs(List<String> ids, List<SpecifiedDocument> docs) {
83-
if (ids != null && docs != null && ids.isEmpty() == false && docs.isEmpty() == false) {
86+
if ((ids != null && ids.isEmpty() == false) && (docs != null && docs.isEmpty() == false)) {
8487
throw new IllegalArgumentException("Both 'ids' and 'docs' cannot be specified at the same time");
8588
}
8689
}
8790

8891
private void validateSort(SearchSourceBuilder source) {
89-
if (source.sorts() != null && source.sorts().isEmpty() == false) {
92+
List<SortBuilder<?>> sorts = source.sorts();
93+
if (sorts != null && sorts.stream().anyMatch(sort -> sort instanceof ScoreSortBuilder == false)) {
9094
throw new IllegalArgumentException("Pinned retriever only supports sorting by score. Custom sorting is not allowed.");
9195
}
9296
}
@@ -132,6 +136,9 @@ public int rankWindowSize() {
132136
* @return a PinnedQueryBuilder or the original query if no pinned documents
133137
*/
134138
private QueryBuilder createPinnedQuery(QueryBuilder baseQuery) {
139+
if (baseQuery == null) {
140+
baseQuery = new MatchAllQueryBuilder();
141+
}
135142
if (docs.isEmpty() == false) {
136143
return new PinnedQueryBuilder(baseQuery, docs.toArray(new SpecifiedDocument[0]));
137144
} else if (ids.isEmpty() == false) {
@@ -176,8 +183,8 @@ protected RankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults, b
176183
RankDoc[] rankDocs = new RankDoc[scoreDocs.length];
177184
for (int i = 0; i < scoreDocs.length; i++) {
178185
ScoreDoc scoreDoc = scoreDocs[i];
179-
boolean isPinned = docs.stream().anyMatch(doc -> doc.id().equals(String.valueOf(scoreDoc.doc))) ||
180-
ids.contains(String.valueOf(scoreDoc.doc));
186+
boolean isPinned = docs.stream().anyMatch(doc -> doc.id().equals(String.valueOf(scoreDoc.doc)))
187+
|| ids.contains(String.valueOf(scoreDoc.doc));
181188
String pinnedBy = docs.stream().anyMatch(doc -> doc.id().equals(String.valueOf(scoreDoc.doc))) ? "docs" : "ids";
182189
rankDocs[i] = new PinnedRankDoc(scoreDoc.doc, scoreDoc.score, scoreDoc.shardIndex, isPinned, pinnedBy);
183190
rankDocs[i].rank = i + 1;

x-pack/plugin/search-business-rules/src/test/java/org/elasticsearch/xpack/searchbusinessrules/retriever/PinnedRetrieverBuilderTests.java

Lines changed: 56 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,12 @@
2727
import java.io.IOException;
2828
import java.util.ArrayList;
2929
import java.util.List;
30-
import java.util.stream.IntStream;
3130
import java.util.stream.Collectors;
31+
import java.util.stream.IntStream;
3232

3333
import static org.elasticsearch.search.rank.RankBuilder.DEFAULT_RANK_WINDOW_SIZE;
3434
import static org.hamcrest.Matchers.equalTo;
3535
import static org.hamcrest.Matchers.instanceOf;
36-
import static org.hamcrest.Matchers.containsString;
3736

3837
public class PinnedRetrieverBuilderTests extends AbstractXContentTestCase<PinnedRetrieverBuilder> {
3938

@@ -54,18 +53,13 @@ public static PinnedRetrieverBuilder createRandomPinnedRetrieverBuilder() {
5453
boolean useIds = randomBoolean();
5554
int numItems = randomIntBetween(1, 5);
5655

57-
List<String> ids = useIds ?
58-
IntStream.range(0, numItems)
59-
.mapToObj(i -> randomAlphaOfLengthBetween(5, 10))
60-
.collect(Collectors.toList()) :
61-
new ArrayList<>();
62-
List<SpecifiedDocument> docs = useIds ?
63-
new ArrayList<>() :
64-
IntStream.range(0, numItems)
65-
.mapToObj(i -> new SpecifiedDocument(
66-
randomAlphaOfLengthBetween(5, 10),
67-
randomAlphaOfLengthBetween(5, 10)
68-
))
56+
List<String> ids = useIds
57+
? IntStream.range(0, numItems).mapToObj(i -> randomAlphaOfLengthBetween(5, 10)).collect(Collectors.toList())
58+
: new ArrayList<>();
59+
List<SpecifiedDocument> docs = useIds
60+
? new ArrayList<>()
61+
: IntStream.range(0, numItems)
62+
.mapToObj(i -> new SpecifiedDocument(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLengthBetween(5, 10)))
6963
.collect(Collectors.toList());
7064
return new PinnedRetrieverBuilder(ids, docs, TestRetrieverBuilder.createRandomTestRetrieverBuilder(), randomIntBetween(1, 100));
7165
}
@@ -101,61 +95,6 @@ protected NamedXContentRegistry xContentRegistry() {
10195
return new NamedXContentRegistry(entries);
10296
}
10397

104-
public void testValidation() {
105-
List<String> ids = List.of("id1", "id2");
106-
List<SpecifiedDocument> docs = List.of(
107-
new SpecifiedDocument("id3", "index3"),
108-
new SpecifiedDocument("id4", "index4")
109-
);
110-
111-
IllegalArgumentException e = expectThrows(
112-
IllegalArgumentException.class,
113-
() -> new PinnedRetrieverBuilder(ids, docs, TestRetrieverBuilder.createRandomTestRetrieverBuilder(), 10)
114-
);
115-
assertThat(e.getMessage(), equalTo("Both 'ids' and 'docs' cannot be specified at the same time"));
116-
117-
e = expectThrows(
118-
IllegalArgumentException.class,
119-
() -> new PinnedRetrieverBuilder(new ArrayList<>(), new ArrayList<>(), TestRetrieverBuilder.createRandomTestRetrieverBuilder(), 10)
120-
);
121-
assertThat(e.getMessage(), equalTo("Either 'ids' or 'docs' must be specified"));
122-
123-
assertNotNull(new PinnedRetrieverBuilder(ids, new ArrayList<>(), TestRetrieverBuilder.createRandomTestRetrieverBuilder(), 10));
124-
assertNotNull(new PinnedRetrieverBuilder(new ArrayList<>(), docs, TestRetrieverBuilder.createRandomTestRetrieverBuilder(), 10));
125-
}
126-
127-
public void testValidateSort() {
128-
PinnedRetrieverBuilder builder = createRandomPinnedRetrieverBuilder();
129-
130-
// Test empty sort is allowed
131-
final SearchSourceBuilder emptySource = new SearchSourceBuilder();
132-
builder.finalizeSourceBuilder(emptySource);
133-
134-
// Test score sort is allowed
135-
final SearchSourceBuilder scoreSource = new SearchSourceBuilder();
136-
scoreSource.sort("_score");
137-
builder.finalizeSourceBuilder(scoreSource);
138-
139-
// Test custom sort is not allowed
140-
final SearchSourceBuilder customSortSource = new SearchSourceBuilder();
141-
customSortSource.sort("field");
142-
IllegalArgumentException e = expectThrows(
143-
IllegalArgumentException.class,
144-
() -> builder.finalizeSourceBuilder(customSortSource)
145-
);
146-
assertThat(e.getMessage(), equalTo("Pinned retriever only supports sorting by score. Custom sorting is not allowed."));
147-
148-
// Test multiple sorts including custom sort is not allowed
149-
final SearchSourceBuilder multipleSortsSource = new SearchSourceBuilder();
150-
multipleSortsSource.sort("_score");
151-
multipleSortsSource.sort("field");
152-
e = expectThrows(
153-
IllegalArgumentException.class,
154-
() -> builder.finalizeSourceBuilder(multipleSortsSource)
155-
);
156-
assertThat(e.getMessage(), equalTo("Pinned retriever only supports sorting by score. Custom sorting is not allowed."));
157-
}
158-
15998
public void testParserDefaults() throws IOException {
16099
// Inner retriever content only sent to parser
161100
String json = """
@@ -187,16 +126,6 @@ public void testPinnedRetrieverParsing() throws IOException {
187126
"id1",
188127
"id2"
189128
],
190-
"docs": [
191-
{
192-
"_index": "index1",
193-
"_id": "doc1"
194-
},
195-
{
196-
"_index": "index2",
197-
"_id": "doc2"
198-
}
199-
],
200129
"rank_window_size": 100,
201130
"_name": "my_pinned_retriever"
202131
}
@@ -221,4 +150,52 @@ public void testPinnedRetrieverParsing() throws IOException {
221150
}
222151
}
223152
}
153+
154+
public void testValidation() {
155+
expectThrows(IllegalArgumentException.class, () -> {
156+
new PinnedRetrieverBuilder(
157+
List.of("id1"),
158+
List.of(new SpecifiedDocument("id2", "index")),
159+
new TestRetrieverBuilder("test"),
160+
DEFAULT_RANK_WINDOW_SIZE
161+
);
162+
});
163+
164+
PinnedRetrieverBuilder builder = new PinnedRetrieverBuilder(
165+
List.of(),
166+
List.of(),
167+
new TestRetrieverBuilder("test"),
168+
DEFAULT_RANK_WINDOW_SIZE
169+
);
170+
assertNotNull(builder);
171+
}
172+
173+
public void testValidateSort() {
174+
PinnedRetrieverBuilder builder = new PinnedRetrieverBuilder(
175+
List.of("id1"),
176+
List.of(),
177+
new TestRetrieverBuilder("test"),
178+
DEFAULT_RANK_WINDOW_SIZE
179+
);
180+
181+
SearchSourceBuilder emptySource = new SearchSourceBuilder();
182+
builder.finalizeSourceBuilder(emptySource);
183+
assertThat(emptySource.sorts(), equalTo(null));
184+
185+
SearchSourceBuilder scoreSource = new SearchSourceBuilder();
186+
scoreSource.sort("_score");
187+
builder.finalizeSourceBuilder(scoreSource);
188+
assertThat(scoreSource.sorts().size(), equalTo(1));
189+
190+
SearchSourceBuilder customSortSource = new SearchSourceBuilder();
191+
customSortSource.sort("field1");
192+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> builder.finalizeSourceBuilder(customSortSource));
193+
assertThat(e.getMessage(), equalTo("Pinned retriever only supports sorting by score. Custom sorting is not allowed."));
194+
195+
SearchSourceBuilder multipleSortsSource = new SearchSourceBuilder();
196+
multipleSortsSource.sort("_score");
197+
multipleSortsSource.sort("field1");
198+
e = expectThrows(IllegalArgumentException.class, () -> builder.finalizeSourceBuilder(multipleSortsSource));
199+
assertThat(e.getMessage(), equalTo("Pinned retriever only supports sorting by score. Custom sorting is not allowed."));
200+
}
224201
}

0 commit comments

Comments
 (0)