88package org .elasticsearch .xpack .rank .rrf ;
99
1010import org .apache .lucene .search .ScoreDoc ;
11+ import org .elasticsearch .action .ActionRequestValidationException ;
12+ import org .elasticsearch .action .ResolvedIndices ;
1113import org .elasticsearch .common .util .Maps ;
14+ import org .elasticsearch .index .query .MatchNoneQueryBuilder ;
1215import org .elasticsearch .index .query .QueryBuilder ;
16+ import org .elasticsearch .index .query .QueryRewriteContext ;
1317import org .elasticsearch .license .LicenseUtils ;
18+ import org .elasticsearch .search .builder .SearchSourceBuilder ;
1419import org .elasticsearch .search .rank .RankBuilder ;
1520import org .elasticsearch .search .rank .RankDoc ;
1621import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
1722import org .elasticsearch .search .retriever .RetrieverBuilder ;
1823import org .elasticsearch .search .retriever .RetrieverParserContext ;
24+ import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
1925import org .elasticsearch .xcontent .ConstructingObjectParser ;
2026import org .elasticsearch .xcontent .ParseField ;
2127import org .elasticsearch .xcontent .XContentBuilder ;
2228import org .elasticsearch .xcontent .XContentParser ;
2329import org .elasticsearch .xpack .core .XPackPlugin ;
30+ import org .elasticsearch .xpack .rank .simplified .SimplifiedInnerRetrieverUtils ;
2431
2532import java .io .IOException ;
2633import java .util .ArrayList ;
2936import java .util .Map ;
3037import java .util .Objects ;
3138
32- import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
3339import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
3440
3541/**
@@ -45,6 +51,8 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
4551
4652 public static final ParseField RETRIEVERS_FIELD = new ParseField ("retrievers" );
4753 public static final ParseField RANK_CONSTANT_FIELD = new ParseField ("rank_constant" );
54+ public static final ParseField FIELDS_FIELD = new ParseField ("fields" );
55+ public static final ParseField QUERY_FIELD = new ParseField ("query" );
4856
4957 public static final int DEFAULT_RANK_CONSTANT = 60 ;
5058 @ SuppressWarnings ("unchecked" )
@@ -53,22 +61,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5361 false ,
5462 args -> {
5563 List <RetrieverBuilder > childRetrievers = (List <RetrieverBuilder >) args [0 ];
56- List <RetrieverSource > innerRetrievers = childRetrievers .stream ().map (RetrieverSource ::from ).toList ();
57- int rankWindowSize = args [1 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [1 ];
58- int rankConstant = args [2 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [2 ];
59- return new RRFRetrieverBuilder (innerRetrievers , rankWindowSize , rankConstant );
64+ List <String > fields = (List <String >) args [1 ];
65+ String query = (String ) args [2 ];
66+ int rankWindowSize = args [3 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
67+ int rankConstant = args [4 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [4 ];
68+
69+ List <RetrieverSource > innerRetrievers = childRetrievers != null
70+ ? childRetrievers .stream ().map (r -> new RetrieverSource (r , null )).toList ()
71+ : List .of ();
72+ return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
6073 }
6174 );
6275
6376 static {
64- PARSER .declareObjectArray (constructorArg (), (p , c ) -> {
77+ PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
6578 p .nextToken ();
6679 String name = p .currentName ();
6780 RetrieverBuilder retrieverBuilder = p .namedObject (RetrieverBuilder .class , name , c );
6881 c .trackRetrieverUsage (retrieverBuilder .getName ());
6982 p .nextToken ();
7083 return retrieverBuilder ;
7184 }, RETRIEVERS_FIELD );
85+ PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
86+ PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
7287 PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
7388 PARSER .declareInt (optionalConstructorArg (), RANK_CONSTANT_FIELD );
7489 RetrieverBuilder .declareBaseParserFields (PARSER );
@@ -81,25 +96,63 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
8196 return PARSER .apply (parser , context );
8297 }
8398
99+ private final List <String > fields ;
100+ private final String query ;
84101 private final int rankConstant ;
85102
86- public RRFRetrieverBuilder (int rankWindowSize , int rankConstant ) {
87- this (new ArrayList <>() , rankWindowSize , rankConstant );
103+ public RRFRetrieverBuilder (List < RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
104+ this (childRetrievers , null , null , rankWindowSize , rankConstant );
88105 }
89106
90- RRFRetrieverBuilder (List <RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
91- super (childRetrievers , rankWindowSize );
107+ public RRFRetrieverBuilder (
108+ List <RetrieverSource > childRetrievers ,
109+ List <String > fields ,
110+ String query ,
111+ int rankWindowSize ,
112+ int rankConstant
113+ ) {
114+ // Use a mutable list for childRetrievers so that we can use addChild
115+ super (childRetrievers == null ? new ArrayList <>() : new ArrayList <>(childRetrievers ), rankWindowSize );
116+ this .fields = fields == null ? List .of () : List .copyOf (fields );
117+ this .query = query ;
92118 this .rankConstant = rankConstant ;
119+
120+ // TODO: Validate simplified query format args here?
121+ // Otherwise some of the validation is skipped when creating the retriever programmatically.
122+ }
123+
124+ public int rankConstant () {
125+ return rankConstant ;
93126 }
94127
95128 @ Override
96129 public String getName () {
97130 return NAME ;
98131 }
99132
133+ @ Override
134+ public ActionRequestValidationException validate (
135+ SearchSourceBuilder source ,
136+ ActionRequestValidationException validationException ,
137+ boolean isScroll ,
138+ boolean allowPartialSearchResults
139+ ) {
140+ validationException = super .validate (source , validationException , isScroll , allowPartialSearchResults );
141+ return SimplifiedInnerRetrieverUtils .validateSimplifiedFormatParams (
142+ innerRetrievers ,
143+ fields ,
144+ query ,
145+ getName (),
146+ RETRIEVERS_FIELD .getPreferredName (),
147+ FIELDS_FIELD .getPreferredName (),
148+ QUERY_FIELD .getPreferredName (),
149+ validationException
150+ );
151+ }
152+
100153 @ Override
101154 protected RRFRetrieverBuilder clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
102- RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .rankWindowSize , this .rankConstant );
155+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .fields , this . query , this . rankWindowSize , this .rankConstant );
103156 clone .preFilterQueryBuilders = newPreFilterQueryBuilders ;
104157 clone .retrieverName = retrieverName ;
105158 return clone ;
@@ -162,17 +215,68 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
162215 return topResults ;
163216 }
164217
218+ @ Override
219+ protected RetrieverBuilder doRewrite (QueryRewriteContext ctx ) {
220+ RetrieverBuilder rewritten = this ;
221+
222+ ResolvedIndices resolvedIndices = ctx .getResolvedIndices ();
223+ if (resolvedIndices != null && query != null ) {
224+ // Using the simplified query format
225+ var localIndicesMetadata = resolvedIndices .getConcreteLocalIndicesMetadata ();
226+ if (localIndicesMetadata .size () > 1 ) {
227+ throw new IllegalArgumentException (
228+ "[" + NAME + "] does not support the simplified query format when querying multiple indices"
229+ );
230+ } else if (resolvedIndices .getRemoteClusterIndices ().isEmpty () == false ) {
231+ throw new IllegalArgumentException (
232+ "[" + NAME + "] does not support the simplified query format when querying remote indices"
233+ );
234+ }
235+
236+ List <RetrieverSource > fieldsInnerRetrievers = SimplifiedInnerRetrieverUtils .generateInnerRetrievers (
237+ fields ,
238+ query ,
239+ localIndicesMetadata .values (),
240+ r -> {
241+ List <RetrieverSource > retrievers = r .stream ()
242+ .map (SimplifiedInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
243+ .toList ();
244+ return new RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
245+ },
246+ w -> {
247+ if (w != 1.0f ) {
248+ throw new IllegalArgumentException (
249+ "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD .getPreferredName () + "]"
250+ );
251+ }
252+ }
253+ ).stream ().map (CompoundRetrieverBuilder ::convertToRetrieverSource ).toList ();
254+
255+ if (fieldsInnerRetrievers .isEmpty () == false ) {
256+ rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
257+ } else {
258+ // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
259+ rewritten = new StandardRetrieverBuilder (new MatchNoneQueryBuilder ());
260+ }
261+ }
262+
263+ return rewritten ;
264+ }
265+
165266 // ---- FOR TESTING XCONTENT PARSING ----
166267
167268 @ Override
168269 public boolean doEquals (Object o ) {
169270 RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
170- return super .doEquals (o ) && rankConstant == that .rankConstant ;
271+ return super .doEquals (o )
272+ && Objects .equals (fields , that .fields )
273+ && Objects .equals (query , that .query )
274+ && rankConstant == that .rankConstant ;
171275 }
172276
173277 @ Override
174278 public int doHashCode () {
175- return Objects .hash (super .doHashCode (), rankConstant );
279+ return Objects .hash (super .doHashCode (), fields , query , rankConstant );
176280 }
177281
178282 @ Override
@@ -186,6 +290,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
186290 builder .endArray ();
187291 }
188292
293+ if (fields .isEmpty () == false ) {
294+ builder .startArray (FIELDS_FIELD .getPreferredName ());
295+ for (String field : fields ) {
296+ builder .value (field );
297+ }
298+ builder .endArray ();
299+ }
300+ if (query != null ) {
301+ builder .field (QUERY_FIELD .getPreferredName (), query );
302+ }
303+
189304 builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
190305 builder .field (RANK_CONSTANT_FIELD .getPreferredName (), rankConstant );
191306 }
0 commit comments