Skip to content

Commit 5c2c24d

Browse files
committed
Add FUSE command
1 parent 2a9f894 commit 5c2c24d

File tree

34 files changed

+3966
-2698
lines changed

34 files changed

+3966
-2698
lines changed
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.compute.operator;
9+
10+
import org.apache.lucene.util.BytesRef;
11+
import org.elasticsearch.compute.data.Block;
12+
import org.elasticsearch.compute.data.BytesRefBlock;
13+
import org.elasticsearch.compute.data.DoubleVector;
14+
import org.elasticsearch.compute.data.DoubleVectorBlock;
15+
import org.elasticsearch.compute.data.Page;
16+
17+
import java.util.ArrayDeque;
18+
import java.util.Collection;
19+
import java.util.Deque;
20+
import java.util.HashMap;
21+
import java.util.Map;
22+
23+
public class LinearScoreEvalOperator implements Operator {
24+
public record Factory(int scorePosition, int discriminatorPosition, Map<String, Double> weights, String normalizer)
25+
implements
26+
OperatorFactory {
27+
28+
@Override
29+
public Operator get(DriverContext driverContext) {
30+
return new LinearScoreEvalOperator(scorePosition, discriminatorPosition, weights, normalizer);
31+
}
32+
33+
@Override
34+
public String describe() {
35+
return "LinearScoreEvalOperator";
36+
}
37+
}
38+
39+
private final int scorePosition;
40+
private final int discriminatorPosition;
41+
private final Map<String, Double> weights;
42+
private final Normalizer normalizer;
43+
44+
private final Deque<Page> inputPages;
45+
private final Deque<Page> outputPages;
46+
private boolean finished;
47+
48+
public LinearScoreEvalOperator(int scorePosition, int discriminatorPosition, Map<String, Double> weights, String normalizerType) {
49+
this.scorePosition = scorePosition;
50+
this.discriminatorPosition = discriminatorPosition;
51+
this.weights = weights;
52+
this.normalizer = createNormalizer(normalizerType);
53+
54+
finished = false;
55+
inputPages = new ArrayDeque<>();
56+
outputPages = new ArrayDeque<>();
57+
}
58+
59+
@Override
60+
public boolean needsInput() {
61+
return finished == false;
62+
}
63+
64+
@Override
65+
public void addInput(Page page) {
66+
inputPages.add(page);
67+
}
68+
69+
@Override
70+
public void finish() {
71+
if (finished == false) {
72+
finished = true;
73+
createOutputPages();
74+
}
75+
}
76+
77+
private void createOutputPages() {
78+
normalizer.preprocess(inputPages, scorePosition, discriminatorPosition);
79+
80+
while (inputPages.isEmpty() == false) {
81+
Page inputPage = inputPages.peek();
82+
83+
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
84+
DoubleVectorBlock initialScoreBlock = inputPage.getBlock(scorePosition);
85+
86+
DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());
87+
88+
for (int i = 0; i < inputPage.getPositionCount(); i++) {
89+
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
90+
91+
var weight = weights.get(discriminator) == null ? 1.0 : weights.get(discriminator);
92+
93+
Double score = initialScoreBlock.getDouble(i);
94+
scores.appendDouble(weight * normalizer.normalize(score, discriminator));
95+
}
96+
Block scoreBlock = scores.build().asBlock();
97+
inputPage = inputPage.appendBlock(scoreBlock);
98+
99+
int[] projections = new int[inputPage.getBlockCount() - 1];
100+
101+
for (int i = 0; i < inputPage.getBlockCount() - 1; i++) {
102+
projections[i] = i == scorePosition ? inputPage.getBlockCount() - 1 : i;
103+
}
104+
inputPages.removeFirst();
105+
outputPages.add(inputPage.projectBlocks(projections));
106+
inputPage.releaseBlocks();
107+
}
108+
}
109+
110+
@Override
111+
public boolean isFinished() {
112+
return finished && outputPages.isEmpty();
113+
}
114+
115+
@Override
116+
public Page getOutput() {
117+
if (finished == false || outputPages.isEmpty()) {
118+
return null;
119+
}
120+
return outputPages.removeFirst();
121+
}
122+
123+
@Override
124+
public void close() {
125+
for (Page page : inputPages) {
126+
page.releaseBlocks();
127+
}
128+
for (Page page : outputPages) {
129+
page.releaseBlocks();
130+
}
131+
}
132+
133+
@Override
134+
public String toString() {
135+
return "LinearScoreEvalOperator[" + scorePosition + ", " + discriminatorPosition + "]";
136+
}
137+
138+
private Normalizer createNormalizer(String normalizer) {
139+
return switch (normalizer.toLowerCase()) {
140+
case "minmax" -> new MinMaxNormalizer();
141+
case "l2_norm" -> new L2NormNormalizer();
142+
case "none" -> new NoneNormalizer();
143+
case null -> new NoneNormalizer();
144+
default -> throw new IllegalArgumentException("Unknown normalizer: " + normalizer);
145+
};
146+
}
147+
148+
private abstract class Normalizer {
149+
public abstract double normalize(double score, String discriminator);
150+
151+
public abstract void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition);
152+
}
153+
154+
private class NoneNormalizer extends Normalizer {
155+
@Override
156+
public double normalize(double score, String discriminator) {
157+
return score;
158+
}
159+
160+
@Override
161+
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {}
162+
}
163+
164+
private class L2NormNormalizer extends Normalizer {
165+
private Map<String, Double> l2Norms = new HashMap<>();
166+
167+
@Override
168+
public double normalize(double score, String discriminator) {
169+
var l2Norm = l2Norms.get(discriminator);
170+
assert l2Norm != null;
171+
return l2Norms.get(discriminator) == 0.0 ? 0.0 : score / l2Norm;
172+
}
173+
174+
@Override
175+
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
176+
for (Page inputPage : inputPages) {
177+
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
178+
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
179+
180+
for (int i = 0; i < inputPage.getPositionCount(); i++) {
181+
double score = scoreBlock.getDouble(i);
182+
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
183+
184+
l2Norms.compute(discriminator, (k, v) -> v == null ? score * score : v + score * score);
185+
}
186+
}
187+
188+
l2Norms.replaceAll((k, v) -> Math.sqrt(v));
189+
}
190+
}
191+
192+
private class MinMaxNormalizer extends Normalizer {
193+
private Map<String, Double> minScores = new HashMap<>();
194+
private Map<String, Double> maxScores = new HashMap<>();
195+
196+
@Override
197+
public double normalize(double score, String discriminator) {
198+
var min = minScores.get(discriminator);
199+
var max = maxScores.get(discriminator);
200+
201+
assert min != null;
202+
assert max != null;
203+
204+
if (min.equals(max)) {
205+
return 0.0;
206+
}
207+
208+
return (score - min) / (max - min);
209+
}
210+
211+
@Override
212+
public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
213+
for (Page inputPage : inputPages) {
214+
DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
215+
BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
216+
217+
for (int i = 0; i < inputPage.getPositionCount(); i++) {
218+
double score = scoreBlock.getDouble(i);
219+
String discriminator = discriminatorBlock.getBytesRef(i, new BytesRef()).utf8ToString();
220+
221+
minScores.compute(discriminator, (key, value) -> value == null ? score : Math.min(value, score));
222+
maxScores.compute(discriminator, (key, value) -> value == null ? score : Math.max(value, score));
223+
}
224+
}
225+
}
226+
}
227+
}

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/RrfScoreEvalOperator.java

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.compute.data.Page;
1515

1616
import java.util.HashMap;
17+
import java.util.Map;
1718

1819
/**
1920
* Updates the score column with new scores using the RRF formula.
@@ -23,10 +24,19 @@
2324
*/
2425
public class RrfScoreEvalOperator extends AbstractPageMappingOperator {
2526

26-
public record Factory(int forkPosition, int scorePosition) implements OperatorFactory {
27+
private final double rankConstant;
28+
private final Map<String, Double> weights;
29+
private final int scorePosition;
30+
private final int forkPosition;
31+
32+
private HashMap<String, Integer> counters = new HashMap<>();
33+
34+
public record Factory(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights)
35+
implements
36+
OperatorFactory {
2737
@Override
2838
public Operator get(DriverContext driverContext) {
29-
return new RrfScoreEvalOperator(forkPosition, scorePosition);
39+
return new RrfScoreEvalOperator(forkPosition, scorePosition, rankConstant, weights);
3040
}
3141

3242
@Override
@@ -36,14 +46,11 @@ public String describe() {
3646

3747
}
3848

39-
private final int scorePosition;
40-
private final int forkPosition;
41-
42-
private HashMap<String, Integer> counters = new HashMap<>();
43-
44-
public RrfScoreEvalOperator(int forkPosition, int scorePosition) {
49+
public RrfScoreEvalOperator(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights) {
4550
this.scorePosition = scorePosition;
4651
this.forkPosition = forkPosition;
52+
this.rankConstant = rankConstant;
53+
this.weights = weights;
4754
}
4855

4956
@Override
@@ -57,7 +64,10 @@ protected Page process(Page page) {
5764

5865
int rank = counters.getOrDefault(fork, 1);
5966
counters.put(fork, rank + 1);
60-
scores.appendDouble(1.0 / (60 + rank));
67+
68+
var weight = weights.get(fork) == null ? 1.0 : weights.get(fork);
69+
70+
scores.appendDouble(1.0 / (rankConstant + rank) * weight);
6171
}
6272

6373
Block scoreBlock = scores.build().asBlock();

0 commit comments

Comments
 (0)