Skip to content

Commit 8be20f0

Browse files
committed
RRFRetrieverComponent added:
1 parent 20ef955 commit 8be20f0

File tree

1 file changed

+66
-0
lines changed

1 file changed

+66
-0
lines changed
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.elasticsearch.xpack.rank.rrf;
2+
3+
import org.elasticsearch.search.retriever.RetrieverBuilder;
4+
import org.elasticsearch.search.retriever.RetrieverParserContext;
5+
import org.elasticsearch.xcontent.ConstructingObjectParser;
6+
import org.elasticsearch.xcontent.ParseField;
7+
import org.elasticsearch.xcontent.ToXContentObject;
8+
import org.elasticsearch.xcontent.XContentBuilder;
9+
import org.elasticsearch.xcontent.XContentParser;
10+
11+
import java.io.IOException;
12+
13+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
14+
import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg;
15+
16+
public class RRFRetrieverComponent implements ToXContentObject {
17+
18+
public static final ParseField RETRIEVER_FIELD = new ParseField("retriever");
19+
public static final ParseField WEIGHT_FIELD = new ParseField("weight");
20+
21+
static final float DEFAULT_WEIGHT = 1f;
22+
23+
RetrieverBuilder retriever;
24+
float weight;
25+
26+
public RRFRetrieverComponent(RetrieverBuilder retrieverBuilder, Float weight) {
27+
assert retrieverBuilder != null;
28+
this.retriever = retrieverBuilder;
29+
this.weight = weight == null ? DEFAULT_WEIGHT : weight;
30+
if (this.weight < 0) {
31+
throw new IllegalArgumentException("[weight] must be non-negative");
32+
}
33+
}
34+
35+
@Override
36+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
37+
builder.field(RETRIEVER_FIELD.getPreferredName(), retriever);
38+
builder.field(WEIGHT_FIELD.getPreferredName(), weight);
39+
return builder;
40+
}
41+
42+
43+
@SuppressWarnings("unchecked")
44+
static final ConstructingObjectParser<RRFRetrieverComponent, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
45+
"rrf_component",
46+
false,
47+
args -> {
48+
RetrieverBuilder retrieverBuilder = (RetrieverBuilder) args[0];
49+
Float weight = (Float) args[1];
50+
return new RRFRetrieverComponent(retrieverBuilder, weight);
51+
}
52+
);
53+
54+
static {
55+
PARSER.declareNamedObject(constructorArg(), (p, c, n) -> {
56+
RetrieverBuilder innerRetriever = p.namedObject(RetrieverBuilder.class, n, c);
57+
c.trackRetrieverUsage(innerRetriever.getName());
58+
return innerRetriever;
59+
}, RETRIEVER_FIELD);
60+
PARSER.declareFloat(optionalConstructorArg(), WEIGHT_FIELD);
61+
}
62+
63+
public static RRFRetrieverComponent fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
64+
return PARSER.apply(parser, context);
65+
}
66+
}

0 commit comments

Comments
 (0)