8
8
package org .elasticsearch .xpack .rank .rrf ;
9
9
10
10
import org .apache .lucene .search .ScoreDoc ;
11
+ import org .elasticsearch .action .ActionRequestValidationException ;
12
+ import org .elasticsearch .action .ResolvedIndices ;
11
13
import org .elasticsearch .common .ParsingException ;
12
14
import org .elasticsearch .common .util .Maps ;
13
15
import org .elasticsearch .features .NodeFeature ;
16
+ import org .elasticsearch .index .query .MatchNoneQueryBuilder ;
14
17
import org .elasticsearch .index .query .QueryBuilder ;
18
+ import org .elasticsearch .index .query .QueryRewriteContext ;
15
19
import org .elasticsearch .license .LicenseUtils ;
20
+ import org .elasticsearch .search .builder .SearchSourceBuilder ;
16
21
import org .elasticsearch .search .rank .RankBuilder ;
17
22
import org .elasticsearch .search .rank .RankDoc ;
18
23
import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
19
24
import org .elasticsearch .search .retriever .RetrieverBuilder ;
20
25
import org .elasticsearch .search .retriever .RetrieverParserContext ;
26
+ import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
21
27
import org .elasticsearch .xcontent .ConstructingObjectParser ;
22
28
import org .elasticsearch .xcontent .ParseField ;
23
29
import org .elasticsearch .xcontent .XContentBuilder ;
24
30
import org .elasticsearch .xcontent .XContentParser ;
25
31
import org .elasticsearch .xpack .core .XPackPlugin ;
32
+ import org .elasticsearch .xpack .rank .MultiFieldsInnerRetrieverUtils ;
26
33
27
34
import java .io .IOException ;
28
35
import java .util .ArrayList ;
31
38
import java .util .Map ;
32
39
import java .util .Objects ;
33
40
34
- import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
35
41
import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
36
42
37
43
/**
42
48
* formula.
43
49
*/
44
50
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder <RRFRetrieverBuilder > {
51
+ public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature ("rrf_retriever.multi_fields_query_format_support" );
45
52
46
53
public static final String NAME = "rrf" ;
47
54
public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature ("rrf_retriever_supported" , true );
48
55
public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature ("rrf_retriever_composition_supported" , true );
49
56
50
57
public static final ParseField RETRIEVERS_FIELD = new ParseField ("retrievers" );
51
58
public static final ParseField RANK_CONSTANT_FIELD = new ParseField ("rank_constant" );
59
+ public static final ParseField FIELDS_FIELD = new ParseField ("fields" );
60
+ public static final ParseField QUERY_FIELD = new ParseField ("query" );
52
61
53
62
public static final int DEFAULT_RANK_CONSTANT = 60 ;
54
63
@ SuppressWarnings ("unchecked" )
@@ -57,22 +66,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
57
66
false ,
58
67
args -> {
59
68
List <RetrieverBuilder > childRetrievers = (List <RetrieverBuilder >) args [0 ];
60
- List <RetrieverSource > innerRetrievers = childRetrievers .stream ().map (RetrieverSource ::from ).toList ();
61
- int rankWindowSize = args [1 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [1 ];
62
- int rankConstant = args [2 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [2 ];
63
- return new RRFRetrieverBuilder (innerRetrievers , rankWindowSize , rankConstant );
69
+ List <String > fields = (List <String >) args [1 ];
70
+ String query = (String ) args [2 ];
71
+ int rankWindowSize = args [3 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
72
+ int rankConstant = args [4 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [4 ];
73
+
74
+ List <RetrieverSource > innerRetrievers = childRetrievers != null
75
+ ? childRetrievers .stream ().map (RetrieverSource ::from ).toList ()
76
+ : List .of ();
77
+ return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
64
78
}
65
79
);
66
80
67
81
static {
68
- PARSER .declareObjectArray (constructorArg (), (p , c ) -> {
82
+ PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
69
83
p .nextToken ();
70
84
String name = p .currentName ();
71
85
RetrieverBuilder retrieverBuilder = p .namedObject (RetrieverBuilder .class , name , c );
72
86
c .trackRetrieverUsage (retrieverBuilder .getName ());
73
87
p .nextToken ();
74
88
return retrieverBuilder ;
75
89
}, RETRIEVERS_FIELD );
90
+ PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
91
+ PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
76
92
PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
77
93
PARSER .declareInt (optionalConstructorArg (), RANK_CONSTANT_FIELD );
78
94
RetrieverBuilder .declareBaseParserFields (NAME , PARSER );
@@ -91,25 +107,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
91
107
return PARSER .apply (parser , context );
92
108
}
93
109
110
+ private final List <String > fields ;
111
+ private final String query ;
94
112
private final int rankConstant ;
95
113
96
- public RRFRetrieverBuilder (int rankWindowSize , int rankConstant ) {
97
- this (new ArrayList <>() , rankWindowSize , rankConstant );
114
+ public RRFRetrieverBuilder (List < RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
115
+ this (childRetrievers , null , null , rankWindowSize , rankConstant );
98
116
}
99
117
100
- RRFRetrieverBuilder (List <RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
101
- super (childRetrievers , rankWindowSize );
118
+ public RRFRetrieverBuilder (
119
+ List <RetrieverSource > childRetrievers ,
120
+ List <String > fields ,
121
+ String query ,
122
+ int rankWindowSize ,
123
+ int rankConstant
124
+ ) {
125
+ // Use a mutable list for childRetrievers so that we can use addChild
126
+ super (childRetrievers == null ? new ArrayList <>() : new ArrayList <>(childRetrievers ), rankWindowSize );
127
+ this .fields = fields == null ? List .of () : List .copyOf (fields );
128
+ this .query = query ;
102
129
this .rankConstant = rankConstant ;
103
130
}
104
131
132
+ public int rankConstant () {
133
+ return rankConstant ;
134
+ }
135
+
105
136
@ Override
106
137
public String getName () {
107
138
return NAME ;
108
139
}
109
140
141
+ @ Override
142
+ public ActionRequestValidationException validate (
143
+ SearchSourceBuilder source ,
144
+ ActionRequestValidationException validationException ,
145
+ boolean isScroll ,
146
+ boolean allowPartialSearchResults
147
+ ) {
148
+ validationException = super .validate (source , validationException , isScroll , allowPartialSearchResults );
149
+ return MultiFieldsInnerRetrieverUtils .validateParams (
150
+ innerRetrievers ,
151
+ fields ,
152
+ query ,
153
+ getName (),
154
+ RETRIEVERS_FIELD .getPreferredName (),
155
+ FIELDS_FIELD .getPreferredName (),
156
+ QUERY_FIELD .getPreferredName (),
157
+ validationException
158
+ );
159
+ }
160
+
110
161
@ Override
111
162
protected RRFRetrieverBuilder clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
112
- RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .rankWindowSize , this .rankConstant );
163
+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .fields , this . query , this . rankWindowSize , this .rankConstant );
113
164
clone .preFilterQueryBuilders = newPreFilterQueryBuilders ;
114
165
clone .retrieverName = retrieverName ;
115
166
return clone ;
@@ -172,17 +223,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
172
223
return topResults ;
173
224
}
174
225
226
+ @ Override
227
+ protected RetrieverBuilder doRewrite (QueryRewriteContext ctx ) {
228
+ RetrieverBuilder rewritten = this ;
229
+
230
+ ResolvedIndices resolvedIndices = ctx .getResolvedIndices ();
231
+ if (resolvedIndices != null && query != null ) {
232
+ // TODO: Refactor duplicate code
233
+ // Using the multi-fields query format
234
+ var localIndicesMetadata = resolvedIndices .getConcreteLocalIndicesMetadata ();
235
+ if (localIndicesMetadata .size () > 1 ) {
236
+ throw new IllegalArgumentException (
237
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying multiple indices"
238
+ );
239
+ } else if (resolvedIndices .getRemoteClusterIndices ().isEmpty () == false ) {
240
+ throw new IllegalArgumentException (
241
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying remote indices"
242
+ );
243
+ }
244
+
245
+ List <RetrieverSource > fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils .generateInnerRetrievers (
246
+ fields ,
247
+ query ,
248
+ localIndicesMetadata .values (),
249
+ r -> {
250
+ List <RetrieverSource > retrievers = r .stream ()
251
+ .map (MultiFieldsInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
252
+ .toList ();
253
+ return new RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
254
+ },
255
+ w -> {
256
+ if (w != 1.0f ) {
257
+ throw new IllegalArgumentException (
258
+ "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD .getPreferredName () + "]"
259
+ );
260
+ }
261
+ }
262
+ ).stream ().map (RetrieverSource ::from ).toList ();
263
+
264
+ if (fieldsInnerRetrievers .isEmpty () == false ) {
265
+ // TODO: This is a incomplete solution as it does not address other incomplete copy issues
266
+ // (such as dropping the retriever name and min score)
267
+ rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
268
+ rewritten .getPreFilterQueryBuilders ().addAll (preFilterQueryBuilders );
269
+ } else {
270
+ // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
271
+ rewritten = new StandardRetrieverBuilder (new MatchNoneQueryBuilder ());
272
+ }
273
+ }
274
+
275
+ return rewritten ;
276
+ }
277
+
175
278
// ---- FOR TESTING XCONTENT PARSING ----
176
279
177
280
@ Override
178
281
public boolean doEquals (Object o ) {
179
282
RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
180
- return super .doEquals (o ) && rankConstant == that .rankConstant ;
283
+ return super .doEquals (o )
284
+ && Objects .equals (fields , that .fields )
285
+ && Objects .equals (query , that .query )
286
+ && rankConstant == that .rankConstant ;
181
287
}
182
288
183
289
@ Override
184
290
public int doHashCode () {
185
- return Objects .hash (super .doHashCode (), rankConstant );
291
+ return Objects .hash (super .doHashCode (), fields , query , rankConstant );
186
292
}
187
293
188
294
@ Override
@@ -196,6 +302,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
196
302
builder .endArray ();
197
303
}
198
304
305
+ if (fields .isEmpty () == false ) {
306
+ builder .startArray (FIELDS_FIELD .getPreferredName ());
307
+ for (String field : fields ) {
308
+ builder .value (field );
309
+ }
310
+ builder .endArray ();
311
+ }
312
+ if (query != null ) {
313
+ builder .field (QUERY_FIELD .getPreferredName (), query );
314
+ }
315
+
199
316
builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
200
317
builder .field (RANK_CONSTANT_FIELD .getPreferredName (), rankConstant );
201
318
}
0 commit comments