2020import org .elasticsearch .search .rank .RankBuilder ;
2121import org .elasticsearch .search .rank .RankDoc ;
2222import org .elasticsearch .search .retriever .CompoundRetrieverBuilder ;
23+ import org .elasticsearch .search .retriever .CompoundRetrieverBuilder .RetrieverSource ;
2324import org .elasticsearch .search .retriever .RetrieverBuilder ;
2425import org .elasticsearch .search .retriever .RetrieverParserContext ;
2526import org .elasticsearch .search .retriever .StandardRetrieverBuilder ;
3738import java .util .Map ;
3839import java .util .Objects ;
3940
40- import static org .elasticsearch .xcontent . ConstructingObjectParser . optionalConstructorArg ;
41+ import static org .elasticsearch .xpack . rank . rrf . RRFRetrieverComponent . DEFAULT_WEIGHT ;
4142
4243/**
4344 * An rrf retriever is used to represent an rrf rank element, but
4849 */
4950public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder <RRFRetrieverBuilder > {
5051 public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature ("rrf_retriever.multi_fields_query_format_support" );
52+ public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature ("rrf_retriever.weighted_support" );
5153
5254 public static final String NAME = "rrf" ;
5355
@@ -57,37 +59,38 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5759 public static final ParseField QUERY_FIELD = new ParseField ("query" );
5860
5961 public static final int DEFAULT_RANK_CONSTANT = 60 ;
62+
63+ private final float [] weights ;
64+
6065 @ SuppressWarnings ("unchecked" )
6166 static final ConstructingObjectParser <RRFRetrieverBuilder , RetrieverParserContext > PARSER = new ConstructingObjectParser <>(
6267 NAME ,
6368 false ,
6469 args -> {
65- List <RetrieverBuilder > childRetrievers = ( List <RetrieverBuilder >) args [0 ];
70+ List <RRFRetrieverComponent > retrieverComponents = args [ 0 ] == null ? List . of () : ( List <RRFRetrieverComponent >) args [0 ];
6671 List <String > fields = (List <String >) args [1 ];
6772 String query = (String ) args [2 ];
6873 int rankWindowSize = args [3 ] == null ? RankBuilder .DEFAULT_RANK_WINDOW_SIZE : (int ) args [3 ];
6974 int rankConstant = args [4 ] == null ? DEFAULT_RANK_CONSTANT : (int ) args [4 ];
7075
71- List <RetrieverSource > innerRetrievers = childRetrievers != null
72- ? childRetrievers .stream ().map (RetrieverSource ::from ).toList ()
73- : List .of ();
74- return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant );
76+ int n = retrieverComponents .size ();
77+ List <RetrieverSource > innerRetrievers = new ArrayList <>(n );
78+ float [] weights = new float [n ];
79+ for (int i = 0 ; i < n ; i ++) {
80+ RRFRetrieverComponent component = retrieverComponents .get (i );
81+ innerRetrievers .add (RetrieverSource .from (component .retriever ()));
82+ weights [i ] = component .weight ();
83+ }
84+ return new RRFRetrieverBuilder (innerRetrievers , fields , query , rankWindowSize , rankConstant , weights );
7585 }
7686 );
7787
7888 static {
79- PARSER .declareObjectArray (optionalConstructorArg (), (p , c ) -> {
80- p .nextToken ();
81- String name = p .currentName ();
82- RetrieverBuilder retrieverBuilder = p .namedObject (RetrieverBuilder .class , name , c );
83- c .trackRetrieverUsage (retrieverBuilder .getName ());
84- p .nextToken ();
85- return retrieverBuilder ;
86- }, RETRIEVERS_FIELD );
87- PARSER .declareStringArray (optionalConstructorArg (), FIELDS_FIELD );
88- PARSER .declareString (optionalConstructorArg (), QUERY_FIELD );
89- PARSER .declareInt (optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
90- PARSER .declareInt (optionalConstructorArg (), RANK_CONSTANT_FIELD );
89+ PARSER .declareObjectArray (ConstructingObjectParser .optionalConstructorArg (), RRFRetrieverComponent ::fromXContent , RETRIEVERS_FIELD );
90+ PARSER .declareStringArray (ConstructingObjectParser .optionalConstructorArg (), FIELDS_FIELD );
91+ PARSER .declareString (ConstructingObjectParser .optionalConstructorArg (), QUERY_FIELD );
92+ PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), RANK_WINDOW_SIZE_FIELD );
93+ PARSER .declareInt (ConstructingObjectParser .optionalConstructorArg (), RANK_CONSTANT_FIELD );
9194 RetrieverBuilder .declareBaseParserFields (PARSER );
9295 }
9396
@@ -103,27 +106,46 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
103106 private final int rankConstant ;
104107
105108 public RRFRetrieverBuilder (List <RetrieverSource > childRetrievers , int rankWindowSize , int rankConstant ) {
106- this (childRetrievers , null , null , rankWindowSize , rankConstant );
109+ this (childRetrievers , null , null , rankWindowSize , rankConstant , createDefaultWeights (childRetrievers ));
110+ }
111+
112+ private static float [] createDefaultWeights (List <?> retrievers ) {
113+ int size = retrievers == null ? 0 : retrievers .size ();
114+ float [] defaultWeights = new float [size ];
115+ Arrays .fill (defaultWeights , DEFAULT_WEIGHT );
116+ return defaultWeights ;
107117 }
108118
109119 public RRFRetrieverBuilder (
110120 List <RetrieverSource > childRetrievers ,
111121 List <String > fields ,
112122 String query ,
113123 int rankWindowSize ,
114- int rankConstant
124+ int rankConstant ,
125+ float [] weights
115126 ) {
116127 // Use a mutable list for childRetrievers so that we can use addChild
117128 super (childRetrievers == null ? new ArrayList <>() : new ArrayList <>(childRetrievers ), rankWindowSize );
118129 this .fields = fields == null ? null : List .copyOf (fields );
119130 this .query = query ;
120131 this .rankConstant = rankConstant ;
132+ Objects .requireNonNull (weights , "weights must not be null" );
133+ if (weights .length != innerRetrievers .size ()) {
134+ throw new IllegalArgumentException (
135+ "weights array length [" + weights .length + "] must match retrievers count [" + innerRetrievers .size () + "]"
136+ );
137+ }
138+ this .weights = weights ;
121139 }
122140
123141 public int rankConstant () {
124142 return rankConstant ;
125143 }
126144
145+ public float [] weights () {
146+ return weights ;
147+ }
148+
127149 @ Override
128150 public String getName () {
129151 return NAME ;
@@ -137,6 +159,7 @@ public ActionRequestValidationException validate(
137159 boolean allowPartialSearchResults
138160 ) {
139161 validationException = super .validate (source , validationException , isScroll , allowPartialSearchResults );
162+
140163 return MultiFieldsInnerRetrieverUtils .validateParams (
141164 innerRetrievers ,
142165 fields ,
@@ -151,7 +174,14 @@ public ActionRequestValidationException validate(
151174
152175 @ Override
153176 protected RRFRetrieverBuilder clone (List <RetrieverSource > newRetrievers , List <QueryBuilder > newPreFilterQueryBuilders ) {
154- RRFRetrieverBuilder clone = new RRFRetrieverBuilder (newRetrievers , this .fields , this .query , this .rankWindowSize , this .rankConstant );
177+ RRFRetrieverBuilder clone = new RRFRetrieverBuilder (
178+ newRetrievers ,
179+ this .fields ,
180+ this .query ,
181+ this .rankWindowSize ,
182+ this .rankConstant ,
183+ this .weights
184+ );
155185 clone .preFilterQueryBuilders = newPreFilterQueryBuilders ;
156186 clone .retrieverName = retrieverName ;
157187 return clone ;
@@ -183,7 +213,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
183213
184214 // calculate the current rrf score for this document
185215 // later used to sort and covert to a rank
186- value .score += 1.0f / (rankConstant + frank );
216+ value .score += this . weights [ findex ] * ( 1.0f / (rankConstant + frank ) );
187217
188218 if (explain && value .positions != null && value .scores != null ) {
189219 // record the position for each query
@@ -238,10 +268,14 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
238268 query ,
239269 localIndicesMetadata .values (),
240270 r -> {
241- List <RetrieverSource > retrievers = r .stream ()
242- .map (MultiFieldsInnerRetrieverUtils .WeightedRetrieverSource ::retrieverSource )
243- .toList ();
244- return new RRFRetrieverBuilder (retrievers , rankWindowSize , rankConstant );
271+ List <RetrieverSource > retrievers = new ArrayList <>(r .size ());
272+ float [] weights = new float [r .size ()];
273+ for (int i = 0 ; i < r .size (); i ++) {
274+ var retriever = r .get (i );
275+ retrievers .add (retriever .retrieverSource ());
276+ weights [i ] = retriever .weight ();
277+ }
278+ return new RRFRetrieverBuilder (retrievers , null , null , rankWindowSize , rankConstant , weights );
245279 },
246280 w -> {
247281 if (w != 1.0f ) {
@@ -255,7 +289,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
255289 if (fieldsInnerRetrievers .isEmpty () == false ) {
256290 // TODO: This is a incomplete solution as it does not address other incomplete copy issues
257291 // (such as dropping the retriever name and min score)
258- rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , rankWindowSize , rankConstant );
292+ float [] weights = createDefaultWeights (fieldsInnerRetrievers );
293+ rewritten = new RRFRetrieverBuilder (fieldsInnerRetrievers , null , null , rankWindowSize , rankConstant , weights );
259294 rewritten .getPreFilterQueryBuilders ().addAll (preFilterQueryBuilders );
260295 } else {
261296 // Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
@@ -266,29 +301,13 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
266301 return rewritten ;
267302 }
268303
269- // ---- FOR TESTING XCONTENT PARSING ----
270-
271- @ Override
272- public boolean doEquals (Object o ) {
273- RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
274- return super .doEquals (o )
275- && Objects .equals (fields , that .fields )
276- && Objects .equals (query , that .query )
277- && rankConstant == that .rankConstant ;
278- }
279-
280- @ Override
281- public int doHashCode () {
282- return Objects .hash (super .doHashCode (), fields , query , rankConstant );
283- }
284-
285304 @ Override
286305 public void doToXContent (XContentBuilder builder , Params params ) throws IOException {
287306 if (innerRetrievers .isEmpty () == false ) {
288307 builder .startArray (RETRIEVERS_FIELD .getPreferredName ());
289-
290- for ( var entry : innerRetrievers ) {
291- entry . retriever () .toXContent (builder , params );
308+ for ( int i = 0 ; i < innerRetrievers . size (); i ++) {
309+ RRFRetrieverComponent component = new RRFRetrieverComponent ( innerRetrievers . get ( i ). retriever (), weights [ i ]);
310+ component .toXContent (builder , params );
292311 }
293312 builder .endArray ();
294313 }
@@ -307,4 +326,20 @@ public void doToXContent(XContentBuilder builder, Params params) throws IOExcept
307326 builder .field (RANK_WINDOW_SIZE_FIELD .getPreferredName (), rankWindowSize );
308327 builder .field (RANK_CONSTANT_FIELD .getPreferredName (), rankConstant );
309328 }
329+
330+ // ---- FOR TESTING XCONTENT PARSING ----
331+ @ Override
332+ public boolean doEquals (Object o ) {
333+ RRFRetrieverBuilder that = (RRFRetrieverBuilder ) o ;
334+ return super .doEquals (o )
335+ && Objects .equals (fields , that .fields )
336+ && Objects .equals (query , that .query )
337+ && rankConstant == that .rankConstant
338+ && Arrays .equals (weights , that .weights );
339+ }
340+
341+ @ Override
342+ public int doHashCode () {
343+ return Objects .hash (super .doHashCode (), fields , query , rankConstant , Arrays .hashCode (weights ));
344+ }
310345}
0 commit comments