3737import  java .util .Map ;
3838import  java .util .Objects ;
3939
40+ import  static  org .elasticsearch .action .ValidateActions .addValidationError ;
4041import  static  org .elasticsearch .xcontent .ConstructingObjectParser .optionalConstructorArg ;
42+ import  static  org .elasticsearch .xpack .rank .rrf .RRFRetrieverComponent .DEFAULT_WEIGHT ;
4143
4244/** 
4345 * An rrf retriever is used to represent an rrf rank element, but 
@@ -57,33 +59,63 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5759    public  static  final  ParseField  QUERY_FIELD  = new  ParseField ("query" );
5860
5961    public  static  final  int  DEFAULT_RANK_CONSTANT  = 60 ;
62+ 
63+     private  final  float [] weights ;
64+ 
6065    @ SuppressWarnings ("unchecked" )
6166    static  final  ConstructingObjectParser <RRFRetrieverBuilder , RetrieverParserContext > PARSER  = new  ConstructingObjectParser <>(
6267        NAME ,
6368        false ,
6469        args  -> {
65-             List <RetrieverBuilder >  childRetrievers  = ( List <RetrieverBuilder >) args [0 ];
70+             List <Object >  rawRetrievers  = args [ 0 ] ==  null  ?  List . of () : ( List <Object >) args [0 ];
6671            List <String > fields  = (List <String >) args [1 ];
6772            String  query  = (String ) args [2 ];
6873            int  rankWindowSize  = args [3 ] == null  ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE  : (int ) args [3 ];
6974            int  rankConstant  = args [4 ] == null  ? DEFAULT_RANK_CONSTANT  : (int ) args [4 ];
7075
71-             List <RetrieverSource > innerRetrievers  = childRetrievers  != null 
72-                 ? childRetrievers .stream ().map (RetrieverSource ::from ).toList ()
73-                 : List .of ();
74-             return  new  RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
76+             List <RetrieverSource > innerRetrievers  = new  ArrayList <>(rawRetrievers .size ());
77+             float [] weights  = new  float [rawRetrievers .size ()];
78+             
79+             int  weightIndex  = 0 ;
80+             for  (Object  retrieverOrComponent  : rawRetrievers ) {
81+                 if  (retrieverOrComponent  instanceof  RRFRetrieverComponent  component ) {
82+                     innerRetrievers .add (RetrieverSource .from (component .retriever ));
83+                     weights [weightIndex ++] = component .weight ;
84+                 } else  {
85+                     RetrieverBuilder  bareRetriever  = (RetrieverBuilder ) retrieverOrComponent ;
86+                     innerRetrievers .add (RetrieverSource .from (bareRetriever ));
87+                     weights [weightIndex ++] = RRFRetrieverComponent .DEFAULT_WEIGHT ;
88+                 }
89+             }
90+ 
91+             return  new  RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant , weights );
7592        }
7693    );
7794
7895    static  {
79-         PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
80-             p .nextToken ();
81-             String  name  = p .currentName ();
82-             RetrieverBuilder  retrieverBuilder  = p .namedObject (RetrieverBuilder .class , name , c );
83-             c .trackRetrieverUsage (retrieverBuilder .getName ());
84-             p .nextToken ();
85-             return  retrieverBuilder ;
86-         }, RETRIEVERS_FIELD );
96+         PARSER .declareObjectArray (optionalConstructorArg (),
97+             (p , c ) -> {
98+                 List <Object > list  = new  ArrayList <>();
99+                 while  (p .nextToken () != XContentParser .Token .END_ARRAY ) {
100+                     if  (p .currentToken () == XContentParser .Token .START_OBJECT  &&
101+                         p .nextToken () == XContentParser .Token .FIELD_NAME  &&
102+                         RRFRetrieverComponent .RETRIEVER_FIELD .match (p .currentName (), p .getDeprecationHandler ())) {
103+                         // Handle wrapped retriever with weight 
104+                         list .add (RRFRetrieverComponent .fromXContent (p , c ));
105+                     } else  {
106+                         // Handle bare retriever (legacy format) 
107+                         String  name  = p .currentName ();
108+                         RetrieverBuilder  retrieverBuilder  = p .namedObject (RetrieverBuilder .class , name , c );
109+                         c .trackRetrieverUsage (retrieverBuilder .getName ());
110+                         p .nextToken ();
111+                         list .add (retrieverBuilder );
112+                     }
113+                 }
114+                 return  list ;
115+             },
116+             RETRIEVERS_FIELD 
117+         );
118+         
87119        PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
88120        PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
89121        PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
@@ -103,21 +135,30 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
103135    private  final  int  rankConstant ;
104136
105137    public  RRFRetrieverBuilder (List <RetrieverSource > childRetrievers , int  rankWindowSize , int  rankConstant ) {
106-         this (childRetrievers , null , null , rankWindowSize , rankConstant );
138+         this (childRetrievers , null , null , rankWindowSize , rankConstant , createDefaultWeights (childRetrievers ));
139+     }
140+ 
141+     private  static  float [] createDefaultWeights (List <RetrieverSource > retrievers ) {
142+         int  size  = retrievers  == null  ? 0  : retrievers .size ();
143+         float [] defaultWeights  = new  float [size ];
144+         Arrays .fill (defaultWeights , DEFAULT_WEIGHT );
145+         return  defaultWeights ;
107146    }
108147
109148    public  RRFRetrieverBuilder (
110149        List <RetrieverSource > childRetrievers ,
111150        List <String > fields ,
112151        String  query ,
113152        int  rankWindowSize ,
114-         int  rankConstant 
153+         int  rankConstant ,
154+         float [] weights 
115155    ) {
116156        // Use a mutable list for childRetrievers so that we can use addChild 
117157        super (childRetrievers  == null  ? new  ArrayList <>() : new  ArrayList <>(childRetrievers ), rankWindowSize );
118158        this .fields  = fields  == null  ? null  : List .copyOf (fields );
119159        this .query  = query ;
120160        this .rankConstant  = rankConstant ;
161+         this .weights  = weights ;
121162    }
122163
123164    public  int  rankConstant () {
@@ -137,6 +178,14 @@ public ActionRequestValidationException validate(
137178        boolean  allowPartialSearchResults 
138179    ) {
139180        validationException  = super .validate (source , validationException , isScroll , allowPartialSearchResults );
181+ 
182+         if  (this .weights  != null ) {
183+             for  (float  weight  : this .weights ) {
184+                 if  (weight  < 0 ) {
185+                     validationException  = addValidationError ("[weight] must be non-negative, found ["  + weight  + "]" , validationException );
186+                 }
187+             }
188+         }
140189        return  MultiFieldsInnerRetrieverUtils .validateParams (
141190            innerRetrievers ,
142191            fields ,
@@ -151,7 +200,7 @@ public ActionRequestValidationException validate(
151200
152201    @ Override 
153202    protected  RRFRetrieverBuilder  clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
154-         RRFRetrieverBuilder  clone  = new  RRFRetrieverBuilder (newRetrievers , this .fields , this .query , this .rankWindowSize , this .rankConstant );
203+         RRFRetrieverBuilder  clone  = new  RRFRetrieverBuilder (newRetrievers , this .fields , this .query , this .rankWindowSize , this .rankConstant ,  this . weights );
155204        clone .preFilterQueryBuilders  = newPreFilterQueryBuilders ;
156205        clone .retrieverName  = retrieverName ;
157206        return  clone ;
@@ -183,7 +232,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
183232
184233                    // calculate the current rrf score for this document 
185234                    // later used to sort and covert to a rank 
186-                     value .score  += 1.0f  / (rankConstant  + frank );
235+                     value .score  += this . weights [ findex ] * ( 1.0f  / (rankConstant  + frank ) );
187236
188237                    if  (explain  && value .positions  != null  && value .scores  != null ) {
189238                        // record the position for each query 
@@ -233,29 +282,34 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
233282                );
234283            }
235284
236-             List <RetrieverSource > fieldsInnerRetrievers  = MultiFieldsInnerRetrieverUtils .generateInnerRetrievers (
285+             List <RetrieverBuilder > fieldsInnerRetrievers  = MultiFieldsInnerRetrieverUtils .generateInnerRetrievers (
237286                fields ,
238287                query ,
239288                localIndicesMetadata .values (),
240289                r  -> {
241-                     List <RetrieverSource > retrievers  = r .stream ()
242-                         .map (MultiFieldsInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
243-                         .toList ();
244-                     return  new  RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
290+                     List <RetrieverSource > retrievers  = new  ArrayList <>(r .size ());
291+                     float [] weights  = new  float [r .size ()];
292+                     int  i  = 0 ;
293+                     for (var  retriever : r ) {
294+                         retrievers .add (retriever .retrieverSource ());
295+                         weights [i ++] = retriever .weight ();
296+                     }
297+                     return  new  RRFRetrieverBuilder (retrievers , null , null , rankWindowSize , rankConstant , weights );
245298                },
246299                w  -> {
247-                     if  (w  !=  1.0f ) {
300+                     if  (w  <  0 ) {
248301                        throw  new  IllegalArgumentException (
249-                             "["  + NAME  + "] does not support  per-field weights in ["   +  FIELDS_FIELD . getPreferredName () +  "] "
302+                             "["  + NAME  + "] per-field weights must be non-negative " 
250303                        );
251304                    }
252305                }
253-             ). stream (). map ( RetrieverSource :: from ). toList () ;
306+             );
254307
255308            if  (fieldsInnerRetrievers .isEmpty () == false ) {
256309                // TODO: This is a incomplete solution as it does not address other incomplete copy issues 
257310                // (such as dropping the retriever name and min score) 
258-                 rewritten  = new  RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
311+                 RRFRetrieverBuilder  g  = (RRFRetrieverBuilder ) fieldsInnerRetrievers .get (0 );
312+                 rewritten  = new  RRFRetrieverBuilder (g .innerRetrievers , null , null , rankWindowSize , rankConstant , g .weights );
259313                rewritten .getPreFilterQueryBuilders ().addAll (preFilterQueryBuilders );
260314            } else  {
261315                // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices 
@@ -274,21 +328,22 @@ public boolean doEquals(Object o) {
274328        return  super .doEquals (o )
275329            && Objects .equals (fields , that .fields )
276330            && Objects .equals (query , that .query )
277-             && rankConstant  == that .rankConstant ;
331+             && rankConstant  == that .rankConstant 
332+             && Arrays .equals (weights , that .weights );
278333    }
279334
280335    @ Override 
281336    public  int  doHashCode () {
282-         return  Objects .hash (super .doHashCode (), fields , query , rankConstant );
337+         return  Objects .hash (super .doHashCode (), fields , query , rankConstant ,  Arrays . hashCode ( weights ) );
283338    }
284339
285340    @ Override 
286341    public  void  doToXContent (XContentBuilder  builder , Params  params ) throws  IOException  {
287342        if  (innerRetrievers .isEmpty () == false ) {
288343            builder .startArray (RETRIEVERS_FIELD .getPreferredName ());
289- 
290-             for  (var   entry  :  innerRetrievers ) {
291-                 entry . retriever ().toXContent (builder , params );
344+              
345+             for  (int   i  =  0 ;  i  <  innerRetrievers . size ();  i ++ ) {
346+                 new   RRFRetrieverComponent ( innerRetrievers . get ( i ). retriever (),  this . weights [ i ] ).toXContent (builder , params );
292347            }
293348            builder .endArray ();
294349        }
0 commit comments