Skip to content

Commit 1993e23

Browse files
mridula-s109elasticsearchmachine
authored andcommitted
Integrate weights into simplified RRF retriever syntax (elastic#132680)
* Made changes to include simplified weights to PR: * Add basic changes to include the feature * implemented changes * [CI] Auto commit changes from spotless * Work in progress * WIP * WIP * Modified changes to include the simplified rrf * Clean the mess * Fixed the failing tests * Removed the it nit* * refactored component * Fixed issues * Improved the parsing tests * Modified code * Refactored code' * Update and rename 133400.yaml to 132680.yaml * cleaned up code * Merge conflicts * Update LinearRetrieverBuilderTests.java * RRFBuilder checkstyle issue * fix failing test * [CI] Auto commit changes from spotless * Resolved merge conflict * [CI] Auto commit changes from spotless * Parsing and yaml changes * Cleaned uo the builder * cleanedup * Unnecessary comments * cleaned up * Cleanup * Cleanup * removed duplicate --------- Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 84794b4 commit 1993e23

File tree

6 files changed

+369
-34
lines changed

6 files changed

+369
-34
lines changed

docs/changelog/132680.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 132680
2+
summary: Add support for per-field weights in simplified RRF retriever syntax
3+
area: Search
4+
type: enhancement
5+
issues: []

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/RankRRFFeatures.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public Set<NodeFeature> getTestFeatures() {
3939
LinearRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
4040
RRFRetrieverBuilder.MULTI_FIELDS_QUERY_FORMAT_SUPPORT,
4141
RRFRetrieverBuilder.WEIGHTED_SUPPORT,
42+
RRFRetrieverBuilder.SIMPLIFIED_WEIGHTED_SUPPORT,
4243
LINEAR_RETRIEVER_TOP_LEVEL_NORMALIZER,
4344
LinearRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT,
4445
RRFRetrieverBuilder.MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT

x-pack/plugin/rank-rrf/src/main/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilder.java

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.xcontent.XContentParser;
3131
import org.elasticsearch.xpack.core.XPackPlugin;
3232
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils;
33+
import org.elasticsearch.xpack.rank.MultiFieldsInnerRetrieverUtils.WeightedRetrieverSource;
3334

3435
import java.io.IOException;
3536
import java.util.ArrayList;
@@ -46,10 +47,14 @@
4647
* meaning it has a set of child retrievers that each return a set of
4748
* top docs that will then be combined and ranked according to the rrf
4849
* formula.
50+
*
51+
* Supports both explicit retriever configuration and simplified field-based
52+
* syntax with optional per-field weights (e.g., "field^2.0").
4953
*/
5054
public final class RRFRetrieverBuilder extends CompoundRetrieverBuilder<RRFRetrieverBuilder> {
5155
public static final NodeFeature MULTI_FIELDS_QUERY_FORMAT_SUPPORT = new NodeFeature("rrf_retriever.multi_fields_query_format_support");
5256
public static final NodeFeature WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.weighted_support");
57+
public static final NodeFeature SIMPLIFIED_WEIGHTED_SUPPORT = new NodeFeature("rrf_retriever.simplified_weighted_support");
5358
public static final NodeFeature MULTI_INDEX_SIMPLIFIED_FORMAT_SUPPORT = new NodeFeature(
5459
"rrf_retriever.multi_index_simplified_format_support"
5560
);
@@ -265,23 +270,8 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
265270
fields,
266271
query,
267272
localIndicesMetadata.values(),
268-
r -> {
269-
List<RetrieverSource> retrievers = new ArrayList<>(r.size());
270-
float[] weights = new float[r.size()];
271-
for (int i = 0; i < r.size(); i++) {
272-
var retriever = r.get(i);
273-
retrievers.add(retriever.retrieverSource());
274-
weights[i] = retriever.weight();
275-
}
276-
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
277-
},
278-
w -> {
279-
if (w != 1.0f) {
280-
throw new IllegalArgumentException(
281-
"[" + NAME + "] does not support per-field weights in [" + FIELDS_FIELD.getPreferredName() + "]"
282-
);
283-
}
284-
}
273+
r -> createRRFFromWeightedRetrievers(r, rankWindowSize, rankConstant),
274+
w -> validateNonNegativeWeight(w)
285275
).stream().map(RetrieverSource::from).toList();
286276

287277
if (fieldsInnerRetrievers.isEmpty() == false) {
@@ -295,7 +285,6 @@ protected RetrieverBuilder doRewrite(QueryRewriteContext ctx) {
295285
rewritten = new StandardRetrieverBuilder(new MatchNoneQueryBuilder());
296286
}
297287
}
298-
299288
return rewritten;
300289
}
301290

@@ -340,4 +329,26 @@ public boolean doEquals(Object o) {
340329
public int doHashCode() {
341330
return Objects.hash(super.doHashCode(), fields, query, rankConstant, Arrays.hashCode(weights));
342331
}
332+
333+
private static RRFRetrieverBuilder createRRFFromWeightedRetrievers(
334+
List<WeightedRetrieverSource> r,
335+
int rankWindowSize,
336+
int rankConstant
337+
) {
338+
int size = r.size();
339+
List<RetrieverSource> retrievers = new ArrayList<>(size);
340+
float[] weights = new float[size];
341+
for (int i = 0; i < size; i++) {
342+
var retriever = r.get(i);
343+
retrievers.add(retriever.retrieverSource());
344+
weights[i] = retriever.weight();
345+
}
346+
return new RRFRetrieverBuilder(retrievers, null, null, rankWindowSize, rankConstant, weights);
347+
}
348+
349+
private static void validateNonNegativeWeight(float w) {
350+
if (w < 0) {
351+
throw new IllegalArgumentException("[" + NAME + "] per-field weights must be non-negative");
352+
}
353+
}
343354
}

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderParsingTests.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,14 @@ public static RRFRetrieverBuilder createRandomRRFRetrieverBuilder() {
5151
List<String> fields = null;
5252
String query = null;
5353
if (randomBoolean()) {
54-
fields = randomList(1, 10, () -> randomAlphaOfLengthBetween(1, 10));
54+
fields = randomList(1, 10, () -> {
55+
String field = randomAlphaOfLengthBetween(1, 10);
56+
if (randomBoolean()) {
57+
float weight = randomFloatBetween(0.0f, 10.1f, true);
58+
field = field + "^" + weight;
59+
}
60+
return field;
61+
});
5562
query = randomAlphaOfLengthBetween(1, 10);
5663
}
5764

@@ -359,6 +366,36 @@ public void testRRFRetrieverComponentErrorCases() throws IOException {
359366
expectParsingException(retrieverAsStringContent, "retriever must be an object");
360367
}
361368

369+
public void testSimplifiedWeightedFieldsParsing() throws IOException {
370+
String restContent = """
371+
{
372+
"retriever": {
373+
"rrf": {
374+
"retrievers": [
375+
{
376+
"test": {
377+
"value": "foo"
378+
}
379+
},
380+
{
381+
"test": {
382+
"value": "bar"
383+
}
384+
}
385+
],
386+
"fields": ["name^2.0", "description^0.5"],
387+
"query": "test",
388+
"rank_window_size": 100,
389+
"rank_constant": 10,
390+
"min_score": 20.0,
391+
"_name": "foo_rrf"
392+
}
393+
}
394+
}
395+
""";
396+
checkRRFRetrieverParsing(restContent);
397+
}
398+
362399
private void expectParsingException(String restContent, String expectedMessageFragment) throws IOException {
363400
SearchUsageHolder searchUsageHolder = new UsageService().getSearchUsageHolder();
364401
try (XContentParser jsonParser = createParser(JsonXContent.jsonXContent, restContent)) {

x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java

Lines changed: 137 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -200,11 +200,29 @@ public void testMultiFieldsParamsRewrite() {
200200
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
201201
"foo2"
202202
);
203+
}
203204

204-
// Glob matching on inference and non-inference fields
205-
rrfRetrieverBuilder = new RRFRetrieverBuilder(
205+
public void testMultiFieldsParamsRewriteWithWeights() {
206+
final String indexName = "test-index";
207+
final List<String> testInferenceFields = List.of("semantic_field_1", "semantic_field_2");
208+
final ResolvedIndices resolvedIndices = createMockResolvedIndices(Map.of(indexName, testInferenceFields), null, Map.of());
209+
final QueryRewriteContext queryRewriteContext = new QueryRewriteContext(
210+
parserConfig(),
211+
null,
206212
null,
207-
List.of("field_*", "*_field_1"),
213+
TransportVersion.current(),
214+
RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY,
215+
resolvedIndices,
216+
new PointInTimeBuilder(new BytesArray("pitid")),
217+
null,
218+
null,
219+
false
220+
);
221+
222+
// Simple per-field boosting
223+
RRFRetrieverBuilder rrfRetrieverBuilder = new RRFRetrieverBuilder(
224+
null,
225+
List.of("field_1", "field_2^1.5", "semantic_field_1", "semantic_field_2^2"),
208226
"bar",
209227
DEFAULT_RANK_WINDOW_SIZE,
210228
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
@@ -213,15 +231,16 @@ public void testMultiFieldsParamsRewrite() {
213231
assertMultiFieldsParamsRewrite(
214232
rrfRetrieverBuilder,
215233
queryRewriteContext,
216-
Map.of("field_*", 1.0f, "*_field_1", 1.0f),
217-
Map.of("semantic_field_1", 1.0f),
218-
"bar"
234+
Map.of("field_1", 1.0f, "field_2", 1.5f),
235+
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 2.0f),
236+
"bar",
237+
null
219238
);
220239

221-
// All-fields wildcard
240+
// Glob matching on inference and non-inference fields with per-field boosting
222241
rrfRetrieverBuilder = new RRFRetrieverBuilder(
223242
null,
224-
List.of("*"),
243+
List.of("field_*^1.5", "*_field_1^2.5"),
225244
"baz",
226245
DEFAULT_RANK_WINDOW_SIZE,
227246
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
@@ -230,10 +249,117 @@ public void testMultiFieldsParamsRewrite() {
230249
assertMultiFieldsParamsRewrite(
231250
rrfRetrieverBuilder,
232251
queryRewriteContext,
233-
Map.of("*", 1.0f),
234-
Map.of("semantic_field_1", 1.0f, "semantic_field_2", 1.0f),
235-
"baz"
252+
Map.of("field_*", 1.5f, "*_field_1", 2.5f),
253+
Map.of("semantic_field_1", 2.5f),
254+
"baz",
255+
null
256+
);
257+
258+
// Multiple boosts defined on the same field
259+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
260+
null,
261+
List.of("field_*^1.5", "field_1^3.0", "*_field_1^2.5", "semantic_*^1.5"),
262+
"baz2",
263+
DEFAULT_RANK_WINDOW_SIZE,
264+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
265+
new float[0]
266+
);
267+
assertMultiFieldsParamsRewrite(
268+
rrfRetrieverBuilder,
269+
queryRewriteContext,
270+
Map.of("field_*", 1.5f, "field_1", 3.0f, "*_field_1", 2.5f, "semantic_*", 1.5f),
271+
Map.of("semantic_field_1", 3.75f, "semantic_field_2", 1.5f),
272+
"baz2",
273+
null
274+
);
275+
276+
// All-fields wildcard with weights
277+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
278+
null,
279+
List.of("*^2.0"),
280+
"qux",
281+
DEFAULT_RANK_WINDOW_SIZE,
282+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
283+
new float[0]
284+
);
285+
assertMultiFieldsParamsRewrite(
286+
rrfRetrieverBuilder,
287+
queryRewriteContext,
288+
Map.of("*", 2.0f),
289+
Map.of("semantic_field_1", 2.0f, "semantic_field_2", 2.0f),
290+
"qux",
291+
null
292+
);
293+
294+
// Zero weights (testing that zero is allowed as non-negative)
295+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
296+
null,
297+
List.of("field_1^0", "field_2^1.0"),
298+
"zero_test",
299+
DEFAULT_RANK_WINDOW_SIZE,
300+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
301+
new float[0]
302+
);
303+
assertMultiFieldsParamsRewrite(
304+
rrfRetrieverBuilder,
305+
queryRewriteContext,
306+
Map.of("field_1", 0.0f, "field_2", 1.0f),
307+
Map.of(),
308+
"zero_test",
309+
null
310+
);
311+
312+
// Mixed weighted and unweighted fields in simplified syntax
313+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
314+
null,
315+
List.of("title^2.5", "content", "tags^1.5", "description"),
316+
"test query",
317+
DEFAULT_RANK_WINDOW_SIZE,
318+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
319+
new float[0]
320+
);
321+
assertMultiFieldsParamsRewrite(
322+
rrfRetrieverBuilder,
323+
queryRewriteContext,
324+
Map.of("title", 2.5f, "content", 1.0f, "tags", 1.5f, "description", 1.0f),
325+
Map.of(),
326+
"test query",
327+
null
328+
);
329+
330+
// Decimal weight precision handling
331+
rrfRetrieverBuilder = new RRFRetrieverBuilder(
332+
null,
333+
List.of("field1^0.1", "field2^2.75", "field3^10.999"),
334+
"test query",
335+
DEFAULT_RANK_WINDOW_SIZE,
336+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
337+
new float[0]
338+
);
339+
assertMultiFieldsParamsRewrite(
340+
rrfRetrieverBuilder,
341+
queryRewriteContext,
342+
Map.of("field1", 0.1f, "field2", 2.75f, "field3", 10.999f),
343+
Map.of(),
344+
"test query",
345+
null
346+
);
347+
348+
// Test negative weight validation
349+
RRFRetrieverBuilder negativeWeightBuilder = new RRFRetrieverBuilder(
350+
null,
351+
List.of("field_1^-1.0"),
352+
"negative_test",
353+
DEFAULT_RANK_WINDOW_SIZE,
354+
RRFRetrieverBuilder.DEFAULT_RANK_CONSTANT,
355+
new float[0]
356+
);
357+
358+
IllegalArgumentException iae = expectThrows(
359+
IllegalArgumentException.class,
360+
() -> negativeWeightBuilder.doRewrite(queryRewriteContext)
236361
);
362+
assertEquals("[rrf] per-field weights must be non-negative", iae.getMessage());
237363
}
238364

239365
public void testMultiIndexMultiFieldsParamsRewrite() {

0 commit comments

Comments
 (0)