Skip to content

Commit a8f6487

Browse files
committed
Modified parser, toXcontent and included component in the RetrieverBuilder
1 parent 8be20f0 commit a8f6487

File tree

2 files changed

+88
-33
lines changed

2 files changed

+88
-33
lines changed

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 86 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
import java.util.Map;
3838
import java.util.Objects;
3939

40+
import static org.elasticsearch.action.ValidateActions.addValidationError;
4041
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
42+
import static org.elasticsearch.xpack.rank.rrf.RRFRetrieverComponent.DEFAULT_WEIGHT;
4143

4244
/**
4345
* An rrf retriever is used to represent an rrf rank element, but
@@ -57,33 +59,63 @@ public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetri
5759
public static final ParseField QUERY_FIELD = new ParseField("query");
5860

5961
public static final int DEFAULT_RANK_CONSTANT = 60;
62+
63+
private final float[] weights;
64+
6065
@SuppressWarnings("unchecked")
6166
static final ConstructingObjectParser<RRFRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
6267
NAME,
6368
false,
6469
args -> {
65-
List<RetrieverBuilder> childRetrievers = (List<RetrieverBuilder>) args[0];
70+
List<Object> rawRetrievers = args[0] == null ? List.of() : (List<Object>) args[0];
6671
List<String> fields = (List<String>) args[1];
6772
String query = (String) args[2];
6873
int rankWindowSize = args[3] == null ? RankBuilder.DEFAULT_RANK_WINDOW_SIZE : (int) args[3];
6974
int rankConstant = args[4] == null ? DEFAULT_RANK_CONSTANT : (int) args[4];
7075

71-
List<RetrieverSource> innerRetrievers = childRetrievers != null
72-
? childRetrievers.stream().map(RetrieverSource::from).toList()
73-
: List.of();
74-
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant);
76+
List<RetrieverSource> innerRetrievers = new ArrayList<>(rawRetrievers.size());
77+
float[] weights = new float[rawRetrievers.size()];
78+
79+
int weightIndex = 0;
80+
for (Object retrieverOrComponent : rawRetrievers) {
81+
if (retrieverOrComponent instanceof RRFRetrieverComponent component) {
82+
innerRetrievers.add(RetrieverSource.from(component.retriever));
83+
weights[weightIndex++] = component.weight;
84+
} else {
85+
RetrieverBuilder bareRetriever = (RetrieverBuilder) retrieverOrComponent;
86+
innerRetrievers.add(RetrieverSource.from(bareRetriever));
87+
weights[weightIndex++] = RRFRetrieverComponent.DEFAULT_WEIGHT;
88+
}
89+
}
90+
91+
return new RRFRetrieverBuilder(innerRetrievers, fields, query, rankWindowSize, rankConstant, weights);
7592
}
7693
);
7794

7895
static {
79-
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> {
80-
p.nextToken();
81-
String name = p.currentName();
82-
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
83-
c.trackRetrieverUsage(retrieverBuilder.getName());
84-
p.nextToken();
85-
return retrieverBuilder;
86-
}, RETRIEVERS_FIELD);
96+
PARSER.declareObjectArray(optionalConstructorArg(),
97+
(p, c) -> {
98+
List<Object> list = new ArrayList<>();
99+
while (p.nextToken() != XContentParser.Token.END_ARRAY) {
100+
if (p.currentToken() == XContentParser.Token.START_OBJECT &&
101+
p.nextToken() == XContentParser.Token.FIELD_NAME &&
102+
RRFRetrieverComponent.RETRIEVER_FIELD.match(p.currentName(), p.getDeprecationHandler())) {
103+
// Handle wrapped retriever with weight
104+
list.add(RRFRetrieverComponent.fromXContent(p, c));
105+
} else {
106+
// Handle bare retriever (legacy format)
107+
String name = p.currentName();
108+
RetrieverBuilder retrieverBuilder = p.namedObject(RetrieverBuilder.class, name, c);
109+
c.trackRetrieverUsage(retrieverBuilder.getName());
110+
p.nextToken();
111+
list.add(retrieverBuilder);
112+
}
113+
}
114+
return list;
115+
},
116+
RETRIEVERS_FIELD
117+
);
118+
87119
PARSER.declareStringArray(optionalConstructorArg(), FIELDS_FIELD);
88120
PARSER.declareString(optionalConstructorArg(), QUERY_FIELD);
89121
PARSER.declareInt(optionalConstructorArg(), RANK_WINDOW_SIZE_FIELD);
@@ -103,21 +135,30 @@ public static RRFRetrieverBuilder fromXContent(XContentParser parser, RetrieverP
103135
private final int rankConstant;
104136

105137
public RRFRetrieverBuilder(List<RetrieverSource> childRetrievers, int rankWindowSize, int rankConstant) {
106-
this(childRetrievers, null, null, rankWindowSize, rankConstant);
138+
this(childRetrievers, null, null, rankWindowSize, rankConstant, createDefaultWeights(childRetrievers));
139+
}
140+
141+
private static float[] createDefaultWeights(List<RetrieverSource> retrievers) {
142+
int size = retrievers == null ? 0 : retrievers.size();
143+
float[] defaultWeights = new float[size];
144+
Arrays.fill(defaultWeights, DEFAULT_WEIGHT);
145+
return defaultWeights;
107146
}
108147

109148
public RRFRetrieverBuilder(
110149
List<RetrieverSource> childRetrievers,
111150
List<String> fields,
112151
String query,
113152
int rankWindowSize,
114-
int rankConstant
153+
int rankConstant,
154+
float[] weights
115155
) {
116156
// Use a mutable list for childRetrievers so that we can use addChild
117157
super(childRetrievers == null ? new ArrayList<>() : new ArrayList<>(childRetrievers), rankWindowSize);
118158
this.fields = fields == null ? null : List.copyOf(fields);
119159
this.query = query;
120160
this.rankConstant = rankConstant;
161+
this.weights = weights;
121162
}
122163

123164
public int rankConstant() {
@@ -137,6 +178,14 @@ public ActionRequestValidationException validate(
137178
boolean allowPartialSearchResults
138179
) {
139180
validationException = super.validate(source, validationException, isScroll, allowPartialSearchResults);
181+
182+
if (this.weights != null) {
183+
for (float weight : this.weights) {
184+
if (weight < 0) {
185+
validationException = addValidationError("[weight] must be non-negative, found [" + weight + "]", validationException);
186+
}
187+
}
188+
}
140189
return MultiFieldsInnerRetrieverUtils.validateParams(
141190
innerRetrievers,
142191
fields,
@@ -151,7 +200,7 @@ public ActionRequestValidationException validate(
151200

152201
@Override
153202
protected RRFRetrieverBuilder clone(List<RetrieverSource> newRetrievers, List<QueryBuilder> newPreFilterQueryBuilders) {
154-
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant);
203+
RRFRetrieverBuilder clone = new RRFRetrieverBuilder(newRetrievers, this.fields, this.query, this.rankWindowSize, this.rankConstant, this.weights);
155204
clone.preFilterQueryBuilders = newPreFilterQueryBuilders;
156205
clone.retrieverName = retrieverName;
157206
return clone;
@@ -183,7 +232,7 @@ protected RRFRankDoc[] combineInnerRetrieverResults(List<ScoreDoc[]> rankResults
183232

184233
// calculate the current rrf score for this document
185234
// later used to sort and covert to a rank
186-
value.score += 1.0f / (rankConstant + frank);
235+
value.score += this.weights[findex] * (1.0f / (rankConstant + frank));
187236

188237
if (explain && value.positions != null && value.scores != null) {
189238
// record the position for each query
@@ -233,29 +282,34 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
233282
);
234283
}
235284

236-
List<RetrieverSource> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
285+
List<RetrieverBuilder> fieldsInnerRetrievers = MultiFieldsInnerRetrieverUtils.generateInnerRetrievers(
237286
fields,
238287
query,
239288
localIndicesMetadata.values(),
240289
r -> {
241-
List<RetrieverSource> retrievers = r.stream()
242-
.map(MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource::retrieverSource)
243-
.toList();
244-
return new RRFRetrieverBuilder(retrievers, rankWindowSize, rankConstant);
290+
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
291+
float[] weights = new float[r.size()];
292+
int i = 0;
293+
for(var retriever: r) {
294+
retrievers.add(retriever.retrieverSource());
295+
weights[i++] = retriever.weight();
296+
}
297+
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
245298
},
246299
w -> {
247-
if (w != 1.0f) {
300+
if (w < 0) {
248301
throw new IllegalArgumentException(
249-
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
302+
"[" + NAME + "] per-field weights must be non-negative"
250303
);
251304
}
252305
}
253-
).stream().map(RetrieverSource::from).toList();
306+
);
254307

255308
if (fieldsInnerRetrievers.isEmpty() == false) {
256309
// TODO: This is a incomplete solution as it does not address other incomplete copy issues
257310
// (such as dropping the retriever name and min score)
258-
rewritten = new RRFRetrieverBuilder(fieldsInnerRetrievers, rankWindowSize, rankConstant);
311+
RRFRetrieverBuilder g = (RRFRetrieverBuilder) fieldsInnerRetrievers.get(0);
312+
rewritten = new RRFRetrieverBuilder(g.innerRetrievers, null, null, rankWindowSize, rankConstant, g.weights);
259313
rewritten.getPreFilterQueryBuilders().addAll(preFilterQueryBuilders);
260314
} else {
261315
// Inner retriever list can be empty when using an index wildcard pattern that doesn't match any indices
@@ -274,21 +328,22 @@ public boolean doEquals(Object o) {
274328
return super.doEquals(o)
275329
&& Objects.equals(fields, that.fields)
276330
&& Objects.equals(query, that.query)
277-
&& rankConstant == that.rankConstant;
331+
&& rankConstant == that.rankConstant
332+
&& Arrays.equals(weights, that.weights);
278333
}
279334

280335
@Override
281336
public int doHashCode() {
282-
return Objects.hash(super.doHashCode(), fields, query, rankConstant);
337+
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
283338
}
284339

285340
@Override
286341
public void doToXContent(XContentBuilder builder, Params params) throws IOException {
287342
if (innerRetrievers.isEmpty() == false) {
288343
builder.startArray(RETRIEVERS_FIELD.getPreferredName());
289-
290-
for (var entry : innerRetrievers) {
291-
entry.retriever().toXContent(builder, params);
344+
345+
for (int i = 0; i < innerRetrievers.size(); i++) {
346+
new RRFRetrieverComponent(innerRetrievers.get(i).retriever(), this.weights[i]).toXContent(builder, params);
292347
}
293348
builder.endArray();
294349
}

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverComponent.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public class RRFRetrieverComponent implements ToXContentObject {
2020

2121
static final float DEFAULT_WEIGHT = 1f;
2222

23-
RetrieverBuilder retriever;
24-
float weight;
23+
final RetrieverBuilder retriever;
24+
final float weight;
2525

2626
public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) {
2727
assert retrieverBuilder != null;

0 commit comments

Comments
 (0)