8
8
package org .elasticsearch .xpack .rank .rrf ;
9
9
10
10
import org .apache .lucene .search .ScoreDoc ;
11
+ import org .elasticsearch .action .ActionRequestValidationException ;
12
+ import org .elasticsearch .action .ResolvedIndices ;
11
13
import org .elasticsearch .common .util .Maps ;
14
+ import org .elasticsearch .features .NodeFeature ;
15
+ import org .elasticsearch .index .query .MatchNoneQueryBuilder ;
12
16
import org .elasticsearch .index .query .QueryBuilder ;
17
+ import org .elasticsearch .index .query .QueryRewriteContext ;
13
18
import org .elasticsearch .license .LicenseUtils ;
19
+ import org .elasticsearch .search .builder .SearchSourceBuilder ;
14
20
import org .elasticsearch .search .rank .RankBuilder ;
15
21
import org .elasticsearch .search .rank .RankDoc ;
16
22
import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
17
23
import org .elasticsearch .search .retriever .RetrieverBuilder ;
18
24
import org .elasticsearch .search .retriever .RetrieverParserContext ;
25
+ import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
19
26
import org .elasticsearch .xcontent .ConstructingObjectParser ;
20
27
import org .elasticsearch .xcontent .ParseField ;
21
28
import org .elasticsearch .xcontent .XContentBuilder ;
22
29
import org .elasticsearch .xcontent .XContentParser ;
23
30
import org .elasticsearch .xpack .core .XPackPlugin ;
31
+ import org .elasticsearch .xpack .rank .MultiFieldsInnerRetrieverUtils ;
24
32
25
33
import java .io .IOException ;
26
34
import java .util .ArrayList ;
29
37
import java .util .Map ;
30
38
import java .util .Objects ;
31
39
32
- import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
33
40
import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
34
41
35
42
/**
40
47
* formula.
41
48
*/
42
49
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder <RRFRetrieverBuilder > {
50
+ public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature ("rrf_retriever.multi_fields_query_format_support" );
43
51
44
52
public static final String NAME = "rrf" ;
45
53
46
54
public static final ParseField RETRIEVERS_FIELD = new ParseField ("retrievers" );
47
55
public static final ParseField RANK_CONSTANT_FIELD = new ParseField ("rank_constant" );
56
+ public static final ParseField FIELDS_FIELD = new ParseField ("fields" );
57
+ public static final ParseField QUERY_FIELD = new ParseField ("query" );
48
58
49
59
public static final int DEFAULT_RANK_CONSTANT = 60 ;
50
60
@ SuppressWarnings ("unchecked" )
@@ -53,22 +63,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
53
63
false ,
54
64
args -> {
55
65
List <RetrieverBuilder > childRetrievers = (List <RetrieverBuilder >) args [0 ];
56
- List <RetrieverSource > innerRetrievers = childRetrievers .stream ().map (RetrieverSource ::from ).toList ();
57
- int rankWindowSize = args [1 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [1 ];
58
- int rankConstant = args [2 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [2 ];
59
- return new RRFRetrieverBuilder (innerRetrievers , rankWindowSize , rankConstant );
66
+ List <String > fields = (List <String >) args [1 ];
67
+ String query = (String ) args [2 ];
68
+ int rankWindowSize = args [3 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
69
+ int rankConstant = args [4 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [4 ];
70
+
71
+ List <RetrieverSource > innerRetrievers = childRetrievers != null
72
+ ? childRetrievers .stream ().map (RetrieverSource ::from ).toList ()
73
+ : List .of ();
74
+ return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
60
75
}
61
76
);
62
77
63
78
static {
64
- PARSER .declareObjectArray (constructorArg (), (p , c ) -> {
79
+ PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
65
80
p .nextToken ();
66
81
String name = p .currentName ();
67
82
RetrieverBuilder retrieverBuilder = p .namedObject (RetrieverBuilder .class , name , c );
68
83
c .trackRetrieverUsage (retrieverBuilder .getName ());
69
84
p .nextToken ();
70
85
return retrieverBuilder ;
71
86
}, RETRIEVERS_FIELD );
87
+ PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
88
+ PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
72
89
PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
73
90
PARSER .declareInt (optionalConstructorArg (), RANK_CONSTANT_FIELD );
74
91
RetrieverBuilder .declareBaseParserFields (PARSER );
@@ -81,25 +98,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
81
98
return PARSER .apply (parser , context );
82
99
}
83
100
101
+ private final List <String > fields ;
102
+ private final String query ;
84
103
private final int rankConstant ;
85
104
86
/**
 * Creates an RRF retriever from explicit child retrievers only.
 * <p>
 * Delegates to the five-argument constructor with {@code null} for both the
 * {@code fields} list and the {@code query} string.
 *
 * @param childRetrievers the inner retrievers whose results are combined
 * @param rankWindowSize  the rank window size handed to the superclass
 * @param rankConstant    the rank constant stored on this builder
 */
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
    this(childRetrievers, null, null, rankWindowSize, rankConstant);
}
89
108
90
/**
 * Creates an RRF retriever from child retrievers and/or the {@code fields} + {@code query} parameters.
 *
 * @param childRetrievers the inner retrievers; {@code null} is treated as an empty list
 * @param fields          the field names; {@code null} is stored as an empty list, anything else as an immutable copy
 * @param query           the query string, stored as given (may be {@code null})
 * @param rankWindowSize  the rank window size handed to the superclass
 * @param rankConstant    the rank constant stored on this builder
 */
public RRFRetrieverBuilder(
    List<RetrieverSource> childRetrievers,
    List<String> fields,
    String query,
    int rankWindowSize,
    int rankConstant
) {
    // The superclass always receives a fresh, mutable list so that addChild keeps working
    super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
    if (fields == null) {
        this.fields = List.of();
    } else {
        this.fields = List.copyOf(fields);
    }
    this.query = query;
    this.rankConstant = rankConstant;
}
94
122
123
+ public int rankConstant () {
124
+ return rankConstant ;
125
+ }
126
+
95
127
@ Override
96
128
public String getName () {
97
129
return NAME ;
98
130
}
99
131
132
+ @ Override
133
+ public ActionRequestValidationException validate (
134
+ SearchSourceBuilder source ,
135
+ ActionRequestValidationException validationException ,
136
+ boolean isScroll ,
137
+ boolean allowPartialSearchResults
138
+ ) {
139
+ validationException = super .validate (source , validationException , isScroll , allowPartialSearchResults );
140
+ return MultiFieldsInnerRetrieverUtils .validateParams (
141
+ innerRetrievers ,
142
+ fields ,
143
+ query ,
144
+ getName (),
145
+ RETRIEVERS_FIELD .getPreferredName (),
146
+ FIELDS_FIELD .getPreferredName (),
147
+ QUERY_FIELD .getPreferredName (),
148
+ validationException
149
+ );
150
+ }
151
+
100
152
@ Override
101
153
protected RRFRetrieverBuilder clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
102
- RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .rankWindowSize , this .rankConstant );
154
+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .fields , this . query , this . rankWindowSize , this .rankConstant );
103
155
clone .preFilterQueryBuilders = newPreFilterQueryBuilders ;
104
156
clone .retrieverName = retrieverName ;
105
157
return clone ;
@@ -162,17 +214,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
162
214
return topResults ;
163
215
}
164
216
217
+ @ Override
218
+ protected RetrieverBuilder doRewrite (QueryRewriteContext ctx ) {
219
+ RetrieverBuilder rewritten = this ;
220
+
221
+ ResolvedIndices resolvedIndices = ctx .getResolvedIndices ();
222
+ if (resolvedIndices != null && query != null ) {
223
+ // TODO: Refactor duplicate code
224
+ // Using the multi-fields query format
225
+ var localIndicesMetadata = resolvedIndices .getConcreteLocalIndicesMetadata ();
226
+ if (localIndicesMetadata .size () > 1 ) {
227
+ throw new IllegalArgumentException (
228
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying multiple indices"
229
+ );
230
+ } else if (resolvedIndices .getRemoteClusterIndices ().isEmpty () == false ) {
231
+ throw new IllegalArgumentException (
232
+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying remote indices"
233
+ );
234
+ }
235
+
236
+ List <RetrieverSource > fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils .generateInnerRetrievers (
237
+ fields ,
238
+ query ,
239
+ localIndicesMetadata .values (),
240
+ r -> {
241
+ List <RetrieverSource > retrievers = r .stream ()
242
+ .map (MultiFieldsInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
243
+ .toList ();
244
+ return new RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
245
+ },
246
+ w -> {
247
+ if (w != 1.0f ) {
248
+ throw new IllegalArgumentException (
249
+ "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD .getPreferredName () + "]"
250
+ );
251
+ }
252
+ }
253
+ ).stream ().map (RetrieverSource ::from ).toList ();
254
+
255
+ if (fieldsInnerRetrievers .isEmpty () == false ) {
256
+ // TODO: This is a incomplete solution as it does not address other incomplete copy issues
257
+ // (such as dropping the retriever name and min score)
258
+ rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
259
+ rewritten .getPreFilterQueryBuilders ().addAll (preFilterQueryBuilders );
260
+ } else {
261
+ // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
262
+ rewritten = new StandardRetrieverBuilder (new MatchNoneQueryBuilder ());
263
+ }
264
+ }
265
+
266
+ return rewritten ;
267
+ }
268
+
165
269
// ---- FOR TESTING XCONTENT PARSING ----
166
270
167
271
@ Override
168
272
public boolean doEquals (Object o ) {
169
273
RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
170
- return super .doEquals (o ) && rankConstant == that .rankConstant ;
274
+ return super .doEquals (o )
275
+ && Objects .equals (fields , that .fields )
276
+ && Objects .equals (query , that .query )
277
+ && rankConstant == that .rankConstant ;
171
278
}
172
279
173
280
@ Override
174
281
public int doHashCode () {
175
- return Objects .hash (super .doHashCode (), rankConstant );
282
+ return Objects .hash (super .doHashCode (), fields , query , rankConstant );
176
283
}
177
284
178
285
@ Override
@@ -186,6 +293,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
186
293
builder .endArray ();
187
294
}
188
295
296
+ if (fields .isEmpty () == false ) {
297
+ builder .startArray (FIELDS_FIELD .getPreferredName ());
298
+ for (String field : fields ) {
299
+ builder .value (field );
300
+ }
301
+ builder .endArray ();
302
+ }
303
+ if (query != null ) {
304
+ builder .field (QUERY_FIELD .getPreferredName (), query );
305
+ }
306
+
189
307
builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
190
308
builder .field (RANK_CONSTANT_FIELD .getPreferredName (), rankConstant );
191
309
}
0 commit comments