Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.compute.data.Page;

import java.util.HashMap;
import java.util.Map;

/**
* Updates the score column with new scores using the RRF formula.
Expand All @@ -23,10 +24,12 @@
*/
public class RrfScoreEvalOperator extends AbstractPageMappingOperator {

public record Factory(int forkPosition, int scorePosition) implements OperatorFactory {
public record Factory(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights)
implements
OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
return new RrfScoreEvalOperator(forkPosition, scorePosition);
return new RrfScoreEvalOperator(forkPosition, scorePosition, rankConstant, weights);
}

@Override
Expand All @@ -38,26 +41,34 @@ public String describe() {

private final int scorePosition;
private final int forkPosition;
private final double rankConstant;
private final Map<String, Double> weights;

private HashMap<String, Integer> counters = new HashMap<>();

public RrfScoreEvalOperator(int forkPosition, int scorePosition) {
public RrfScoreEvalOperator(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights) {
this.scorePosition = scorePosition;
this.forkPosition = forkPosition;
this.rankConstant = rankConstant;
this.weights = weights;
}

@Override
protected Page process(Page page) {
BytesRefBlock forkBlock = (BytesRefBlock) page.getBlock(forkPosition);
BytesRefBlock discriminatorBlock = (BytesRefBlock) page.getBlock(forkPosition);

DoubleVector.Builder scores = forkBlock.blockFactory().newDoubleVectorBuilder(forkBlock.getPositionCount());
DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());

for (int i = 0; i < page.getPositionCount(); i++) {
String fork = forkBlock.getBytesRef(i, new BytesRef()).utf8ToString();
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();

int rank = counters.getOrDefault(fork, 1);
counters.put(fork, rank + 1);
scores.appendDouble(1.0 / (60 + rank));
int rank = counters.getOrDefault(discriminator, 1);
counters.put(discriminator, rank + 1);

var weight = weights.get(discriminator);
weight = weight == null ? 1 : weight;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ Should we use getOrDefault instead?


scores.appendDouble(1.0 / (this.rankConstant + rank) * weight);
}

Block scoreBlock = scores.build().asBlock();
Expand Down
108 changes: 102 additions & 6 deletions x-pack/plugin/esql/qa/testFixtures/src/main/resources/fuse.csv-spec
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

simpleFuse
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon

FROM employees METADATA _id, _index, _score
Expand All @@ -23,7 +23,7 @@ _score:double | _fork:keyword | emp_no:integer

fuseWithMatchAndScore
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon

FROM books METADATA _id, _index, _score
Expand All @@ -46,7 +46,7 @@ _score:double | _fork:keyword | _id:keyword

fuseWithDisjunctionAndPostFilter
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon

FROM books METADATA _id, _index, _score
Expand All @@ -69,7 +69,7 @@ _score:double | _fork:keyword | _id:keyword

fuseWithStats
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon

FROM books METADATA _id, _index, _score
Expand All @@ -89,7 +89,7 @@ count_fork:long | _fork:keyword

fuseWithMultipleForkBranches
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon

FROM books METADATA _id, _index, _score
Expand All @@ -116,7 +116,7 @@ _score:double | author:keyword | title:keyword | _fork

fuseWithSemanticSearch
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: semantic_text_field_caps
required_capability: metadata_score

Expand All @@ -134,3 +134,99 @@ _fork:keyword | _score:double | _id:keyword | semantic_text_field:keyword
[fork1, fork2] | 0.0328 | 2 | all we have to decide is what to do with the time that is given to us
[fork1, fork2] | 0.0323 | 3 | be excellent to each other
;

fuseWithSimpleRrf
required_capability: fork_v9
required_capability: fuse_v2
required_capability: semantic_text_field_caps
required_capability: metadata_score

FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| FUSE rrf
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 5)
| KEEP _score, _fork, _id
;

_score:double | _fork:keyword | _id:keyword
0.03279 | [fork1, fork2] | 4
0.01613 | fork1 | 56
0.01613 | fork2 | 60
0.01587 | fork2 | 1
0.01587 | fork1 | 26
;

fuseWithRrfAndRankConstant
required_capability: fork_v9
required_capability: fuse_v2
required_capability: semantic_text_field_caps
required_capability: metadata_score

FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| FUSE rrf WITH {"rank_constant": 50 }
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 5)
| KEEP _score, _fork, _id
;

_score:double | _fork:keyword | _id:keyword
0.03922 | [fork1, fork2] | 4
0.01923 | fork1 | 56
0.01923 | fork2 | 60
0.01887 | fork2 | 1
0.01887 | fork1 | 26
;

fuseWithRrfAndWeights
required_capability: fork_v9
required_capability: fuse_v2
required_capability: semantic_text_field_caps
required_capability: metadata_score

FROM books METADATA _id, _index, _score
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
| FUSE rrf WITH {"weights": { "fork1": 0.3, "fork2": 0.7 } }
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 5)
| KEEP _score, _fork, _id
;

_score:double | _fork:keyword | _id:keyword
0.01639 | [fork1, fork2] | 4
0.01129 | fork2 | 60
0.01111 | fork2 | 1
0.00484 | fork1 | 56
0.00476 | fork1 | 26
;

fuseWithRrfRankConstantAndWeights
required_capability: fork_v9
required_capability: fuse_v2
required_capability: semantic_text_field_caps
required_capability: metadata_score

FROM books METADATA _id, _score, _index
| FORK ( WHERE title:"Tolkien" | SORT _score, _id DESC | LIMIT 3 )
( WHERE author:"Tolkien" | SORT _score, _id DESC | LIMIT 3)
| FUSE rrf WITH {"rank_constant": 60, "weights": { "fork1": 0.3, "fork2": 0.7 } }
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 5)
| KEEP _score, _fork, _id
;

_score:double | _fork:keyword | _id:keyword
0.01639 | [fork1, fork2] | 4
0.01129 | fork2 | 60
0.01111 | fork2 | 1
0.00484 | fork1 | 56
0.00476 | fork1 | 26
;
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ book_no:keyword | title:text | author

reranker after FUSE
required_capability: fork_v9
required_capability: fuse
required_capability: fuse_v2
required_capability: match_operator_colon
required_capability: rerank

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {

@Before
public void setupIndex() {
assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE.isEnabled());
assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V2.isEnabled());
createAndPopulateIndex();
}

Expand Down Expand Up @@ -58,6 +58,58 @@ public void testFuseWithRrf() throws Exception {
}
}

public void testFuseRrfWithWeights() {
var query = """
FROM test METADATA _score, _id, _index
| WHERE id > 2
| FORK
( WHERE content:"fox" | SORT _score, _id DESC )
( WHERE content:"dog" | SORT _score, _id DESC )
| FUSE RRF WITH {"weights": { "fork1": 0.4, "fork2": 0.6}}
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 4)
| KEEP id, content, _score, _fork
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "content", "_score", "_fork"));
assertColumnTypes(resp.columns(), List.of("integer", "keyword", "double", "keyword"));
assertThat(getValuesList(resp.values()).size(), equalTo(3));
Iterable<Iterable<Object>> expectedValues = List.of(
List.of(6, "The quick brown fox jumps over the lazy dog", 0.0162, List.of("fork1", "fork2")),
List.of(4, "The dog is brown but this document is very very long", 0.0098, "fork2"),
List.of(3, "This dog is really brown", 0.0095, "fork2")
);
assertValues(resp.values(), expectedValues);
}
}

public void testFuseRrfWithWeightsAndRankConstant() {
var query = """
FROM test METADATA _score, _id, _index
| WHERE id > 2
| FORK
( WHERE content:"fox" | SORT _score, _id DESC )
( WHERE content:"dog" | SORT _score, _id DESC )
| FUSE RRF WITH {"weights": { "fork1": 0.4, "fork2": 0.6}, "rank_constant": 55 }
| SORT _score DESC, _id, _index
| EVAL _fork = mv_sort(_fork)
| EVAL _score = round(_score, 4)
| KEEP id, content, _score, _fork
""";
try (var resp = run(query)) {
assertColumnNames(resp.columns(), List.of("id", "content", "_score", "_fork"));
assertColumnTypes(resp.columns(), List.of("integer", "keyword", "double", "keyword"));
assertThat(getValuesList(resp.values()).size(), equalTo(3));
Iterable<Iterable<Object>> expectedValues = List.of(
List.of(6, "The quick brown fox jumps over the lazy dog", 0.0177, List.of("fork1", "fork2")),
List.of(4, "The dog is brown but this document is very very long", 0.0107, "fork2"),
List.of(3, "This dog is really brown", 0.0103, "fork2")
);
assertValues(resp.values(), expectedValues);
}
}

private void createAndPopulateIndex() {
var indexName = "test";
var client = client().admin().indices();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ protected Collection<Class<? extends Plugin>> nodePlugins() {

@Before
public void setupIndex() {
assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE.isEnabled());
assumeTrue("requires FUSE capability", EsqlCapabilities.Cap.FUSE_V2.isEnabled());
var indexName = "test";
var client = client().admin().indices();
var CreateRequest = client.prepareCreate(indexName)
Expand Down
67 changes: 37 additions & 30 deletions x-pack/plugin/esql/src/main/antlr/EsqlBaseLexer.tokens

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion x-pack/plugin/esql/src/main/antlr/EsqlBaseParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -331,5 +331,10 @@ insistCommand
;

fuseCommand
: DEV_FUSE
: DEV_FUSE (fuseType=fuseMethod)? fuseOptions=commandNamedParameters
;

fuseMethod
: RRF
| LINEAR
;
Loading