88package org .elasticsearch .xpack .rank .rrf ;
99
1010import org .apache .lucene .search .ScoreDoc ;
11+ import org .elasticsearch .action .ActionRequestValidationException ;
12+ import org .elasticsearch .action .ResolvedIndices ;
1113import org .elasticsearch .common .ParsingException ;
1214import org .elasticsearch .common .util .Maps ;
1315import org .elasticsearch .features .NodeFeature ;
16+ import org .elasticsearch .index .query .MatchNoneQueryBuilder ;
1417import org .elasticsearch .index .query .QueryBuilder ;
18+ import org .elasticsearch .index .query .QueryRewriteContext ;
1519import org .elasticsearch .license .LicenseUtils ;
20+ import org .elasticsearch .search .builder .SearchSourceBuilder ;
1621import org .elasticsearch .search .rank .RankBuilder ;
1722import org .elasticsearch .search .rank .RankDoc ;
1823import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
1924import org .elasticsearch .search .retriever .RetrieverBuilder ;
2025import org .elasticsearch .search .retriever .RetrieverParserContext ;
26+ import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
2127import org .elasticsearch .xcontent .ConstructingObjectParser ;
2228import org .elasticsearch .xcontent .ParseField ;
2329import org .elasticsearch .xcontent .XContentBuilder ;
2430import org .elasticsearch .xcontent .XContentParser ;
2531import org .elasticsearch .xpack .core .XPackPlugin ;
32+ import org .elasticsearch .xpack .rank .MultiFieldsInnerRetrieverUtils ;
2633
2734import java .io .IOException ;
2835import java .util .ArrayList ;
3138import java .util .Map ;
3239import java .util .Objects ;
3340
34- import static org .elasticsearch .xcontent .ConstructingObjectParser .constructorArg ;
3541import static org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
3642
3743/**
4248 * formula.
4349 */
4450public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder <RRFRetrieverBuilder > {
51+ public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature ("rrf_retriever.multi_fields_query_format_support" );
4552
4653 public static final String NAME = "rrf" ;
4754 public static final NodeFeature RRF_RETRIEVER_SUPPORTED = new NodeFeature ("rrf_retriever_supported" , true );
4855 public static final NodeFeature RRF_RETRIEVER_COMPOSITION_SUPPORTED = new NodeFeature ("rrf_retriever_composition_supported" , true );
4956
5057 public static final ParseField RETRIEVERS_FIELD = new ParseField ("retrievers" );
5158 public static final ParseField RANK_CONSTANT_FIELD = new ParseField ("rank_constant" );
59+ public static final ParseField FIELDS_FIELD = new ParseField ("fields" );
60+ public static final ParseField QUERY_FIELD = new ParseField ("query" );
5261
5362 public static final int DEFAULT_RANK_CONSTANT = 60 ;
5463 @ SuppressWarnings ("unchecked" )
@@ -57,22 +66,29 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5766 false ,
5867 args -> {
5968 List <RetrieverBuilder > childRetrievers = (List <RetrieverBuilder >) args [0 ];
60- List <RetrieverSource > innerRetrievers = childRetrievers .stream ().map (RetrieverSource ::from ).toList ();
61- int rankWindowSize = args [1 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [1 ];
62- int rankConstant = args [2 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [2 ];
63- return new RRFRetrieverBuilder (innerRetrievers , rankWindowSize , rankConstant );
69+ List <String > fields = (List <String >) args [1 ];
70+ String query = (String ) args [2 ];
71+ int rankWindowSize = args [3 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
72+ int rankConstant = args [4 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [4 ];
73+
74+ List <RetrieverSource > innerRetrievers = childRetrievers != null
75+ ? childRetrievers .stream ().map (RetrieverSource ::from ).toList ()
76+ : List .of ();
77+ return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
6478 }
6579 );
6680
6781 static {
68- PARSER .declareObjectArray (constructorArg (), (p , c ) -> {
82+ PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
6983 p .nextToken ();
7084 String name = p .currentName ();
7185 RetrieverBuilder retrieverBuilder = p .namedObject (RetrieverBuilder .class , name , c );
7286 c .trackRetrieverUsage (retrieverBuilder .getName ());
7387 p .nextToken ();
7488 return retrieverBuilder ;
7589 }, RETRIEVERS_FIELD );
90+ PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
91+ PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
7692 PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
7793 PARSER .declareInt (optionalConstructorArg (), RANK_CONSTANT_FIELD );
7894 RetrieverBuilder .declareBaseParserFields (NAME , PARSER );
@@ -91,25 +107,60 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
91107 return PARSER .apply (parser , context );
92108 }
93109
110+ private final List <String > fields ;
111+ private final String query ;
94112 private final int rankConstant ;
95113
96- public RRFRetrieverBuilder (int rankWindowSize , int rankConstant ) {
97- this (new ArrayList <>() , rankWindowSize , rankConstant );
114+ public RRFRetrieverBuilder (List < RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
115+ this (childRetrievers , null , null , rankWindowSize , rankConstant );
98116 }
99117
100- RRFRetrieverBuilder (List <RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
101- super (childRetrievers , rankWindowSize );
118+ public RRFRetrieverBuilder (
119+ List <RetrieverSource > childRetrievers ,
120+ List <String > fields ,
121+ String query ,
122+ int rankWindowSize ,
123+ int rankConstant
124+ ) {
125+ // Use a mutable list for childRetrievers so that we can use addChild
126+ super (childRetrievers == null ? new ArrayList <>() : new ArrayList <>(childRetrievers ), rankWindowSize );
127+ this .fields = fields == null ? List .of () : List .copyOf (fields );
128+ this .query = query ;
102129 this .rankConstant = rankConstant ;
103130 }
104131
132+ public int rankConstant () {
133+ return rankConstant ;
134+ }
135+
105136 @ Override
106137 public String getName () {
107138 return NAME ;
108139 }
109140
141+ @ Override
142+ public ActionRequestValidationException validate (
143+ SearchSourceBuilder source ,
144+ ActionRequestValidationException validationException ,
145+ boolean isScroll ,
146+ boolean allowPartialSearchResults
147+ ) {
148+ validationException = super .validate (source , validationException , isScroll , allowPartialSearchResults );
149+ return MultiFieldsInnerRetrieverUtils .validateParams (
150+ innerRetrievers ,
151+ fields ,
152+ query ,
153+ getName (),
154+ RETRIEVERS_FIELD .getPreferredName (),
155+ FIELDS_FIELD .getPreferredName (),
156+ QUERY_FIELD .getPreferredName (),
157+ validationException
158+ );
159+ }
160+
110161 @ Override
111162 protected RRFRetrieverBuilder clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
112- RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .rankWindowSize , this .rankConstant );
163+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .fields , this . query , this . rankWindowSize , this .rankConstant );
113164 clone .preFilterQueryBuilders = newPreFilterQueryBuilders ;
114165 clone .retrieverName = retrieverName ;
115166 return clone ;
@@ -172,17 +223,72 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
172223 return topResults ;
173224 }
174225
226+ @ Override
227+ protected RetrieverBuilder doRewrite (QueryRewriteContext ctx ) {
228+ RetrieverBuilder rewritten = this ;
229+
230+ ResolvedIndices resolvedIndices = ctx .getResolvedIndices ();
231+ if (resolvedIndices != null && query != null ) {
232+ // TODO: Refactor duplicate code
233+ // Using the multi-fields query format
234+ var localIndicesMetadata = resolvedIndices .getConcreteLocalIndicesMetadata ();
235+ if (localIndicesMetadata .size () > 1 ) {
236+ throw new IllegalArgumentException (
237+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying multiple indices"
238+ );
239+ } else if (resolvedIndices .getRemoteClusterIndices ().isEmpty () == false ) {
240+ throw new IllegalArgumentException (
241+ "[" + NAME + "] cannot specify [" + QUERY_FIELD .getPreferredName () + "] when querying remote indices"
242+ );
243+ }
244+
245+ List <RetrieverSource > fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils .generateInnerRetrievers (
246+ fields ,
247+ query ,
248+ localIndicesMetadata .values (),
249+ r -> {
250+ List <RetrieverSource > retrievers = r .stream ()
251+ .map (MultiFieldsInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
252+ .toList ();
253+ return new RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
254+ },
255+ w -> {
256+ if (w != 1.0f ) {
257+ throw new IllegalArgumentException (
258+ "[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD .getPreferredName () + "]"
259+ );
260+ }
261+ }
262+ ).stream ().map (RetrieverSource ::from ).toList ();
263+
264+ if (fieldsInnerRetrievers .isEmpty () == false ) {
265+ // TODO: This is a incomplete solution as it does not address other incomplete copy issues
266+ // (such as dropping the retriever name and min score)
267+ rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
268+ rewritten .getPreFilterQueryBuilders ().addAll (preFilterQueryBuilders );
269+ } else {
270+ // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
271+ rewritten = new StandardRetrieverBuilder (new MatchNoneQueryBuilder ());
272+ }
273+ }
274+
275+ return rewritten ;
276+ }
277+
175278 // ---- FOR TESTING XCONTENT PARSING ----
176279
177280 @ Override
178281 public boolean doEquals (Object o ) {
179282 RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
180- return super .doEquals (o ) && rankConstant == that .rankConstant ;
283+ return super .doEquals (o )
284+ && Objects .equals (fields , that .fields )
285+ && Objects .equals (query , that .query )
286+ && rankConstant == that .rankConstant ;
181287 }
182288
183289 @ Override
184290 public int doHashCode () {
185- return Objects .hash (super .doHashCode (), rankConstant );
291+ return Objects .hash (super .doHashCode (), fields , query , rankConstant );
186292 }
187293
188294 @ Override
@@ -196,6 +302,17 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
196302 builder .endArray ();
197303 }
198304
305+ if (fields .isEmpty () == false ) {
306+ builder .startArray (FIELDS_FIELD .getPreferredName ());
307+ for (String field : fields ) {
308+ builder .value (field );
309+ }
310+ builder .endArray ();
311+ }
312+ if (query != null ) {
313+ builder .field (QUERY_FIELD .getPreferredName (), query );
314+ }
315+
199316 builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
200317 builder .field (RANK_CONSTANT_FIELD .getPreferredName (), rankConstant );
201318 }
0 commit comments