Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.compute.operator;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.compute.data.Block;
import org.elasticsearch.compute.data.BytesRefBlock;
import org.elasticsearch.compute.data.DoubleVector;
import org.elasticsearch.compute.data.DoubleVectorBlock;
import org.elasticsearch.compute.data.Page;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Deque;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * Operator that rewrites the score column of each page by applying a per-discriminator
 * weight and a normalization strategy ({@code minmax}, {@code l2_norm}, or {@code none}).
 * <p>
 * All input pages are buffered until {@link #finish()} because the normalizers need a
 * full pass over every score (e.g. to find min/max or the L2 norm per discriminator)
 * before any score can be rewritten.
 */
public class LinearScoreEvalOperator implements Operator {
    /**
     * Factory producing {@link LinearScoreEvalOperator} instances.
     *
     * @param scorePosition         channel of the score column in incoming pages
     * @param discriminatorPosition channel of the discriminator (group key) column
     * @param weights               per-discriminator multipliers; missing keys default to 1.0
     * @param normalizer            normalizer name: "minmax", "l2_norm", "none", or null (treated as none)
     */
    public record Factory(int scorePosition, int discriminatorPosition, Map<String, Double> weights, String normalizer)
        implements
            OperatorFactory {

        @Override
        public Operator get(DriverContext driverContext) {
            return new LinearScoreEvalOperator(scorePosition, discriminatorPosition, weights, normalizer);
        }

        @Override
        public String describe() {
            return "LinearScoreEvalOperator";
        }
    }

    private final int scorePosition;
    private final int discriminatorPosition;
    private final Map<String, Double> weights;
    private final Normalizer normalizer;

    // Pages are held here until finish(); normalization requires seeing all input first.
    private final Deque<Page> inputPages;
    private final Deque<Page> outputPages;
    private boolean finished;

    /**
     * @param scorePosition         channel of the score column
     * @param discriminatorPosition channel of the discriminator column
     * @param weights               per-discriminator multipliers; missing keys default to 1.0
     * @param normalizerType        normalizer name; see {@link #createNormalizer(String)}
     * @throws IllegalArgumentException if {@code normalizerType} is unknown
     */
    public LinearScoreEvalOperator(int scorePosition, int discriminatorPosition, Map<String, Double> weights, String normalizerType) {
        this.scorePosition = scorePosition;
        this.discriminatorPosition = discriminatorPosition;
        this.weights = weights;
        this.normalizer = createNormalizer(normalizerType);

        finished = false;
        inputPages = new ArrayDeque<>();
        outputPages = new ArrayDeque<>();
    }

    @Override
    public boolean needsInput() {
        return finished == false;
    }

    @Override
    public void addInput(Page page) {
        inputPages.add(page);
    }

    @Override
    public void finish() {
        if (finished == false) {
            finished = true;
            // All input is now visible, so the normalizer can do its global pass.
            createOutputPages();
        }
    }

    /**
     * Drains {@link #inputPages} into {@link #outputPages}, replacing the score column of
     * each page with {@code weight * normalize(score)} per row. The new score vector is
     * appended as a trailing block and then projected back into the score channel.
     */
    private void createOutputPages() {
        normalizer.preprocess(inputPages, scorePosition, discriminatorPosition);

        // Scratch BytesRef reused across rows to avoid a per-row allocation.
        BytesRef scratch = new BytesRef();

        while (inputPages.isEmpty() == false) {
            Page inputPage = inputPages.peek();

            BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);
            DoubleVectorBlock initialScoreBlock = inputPage.getBlock(scorePosition);

            DoubleVector.Builder scores = discriminatorBlock.blockFactory().newDoubleVectorBuilder(discriminatorBlock.getPositionCount());

            for (int i = 0; i < inputPage.getPositionCount(); i++) {
                String discriminator = discriminatorBlock.getBytesRef(i, scratch).utf8ToString();

                // Single map lookup; unmapped discriminators get a neutral weight of 1.0.
                Double configured = weights.get(discriminator);
                double weight = configured == null ? 1.0 : configured;

                double score = initialScoreBlock.getDouble(i);
                scores.appendDouble(weight * normalizer.normalize(score, discriminator));
            }
            Block scoreBlock = scores.build().asBlock();
            inputPage = inputPage.appendBlock(scoreBlock);

            // Project all original channels except the old score column, which is
            // replaced by the freshly appended block (last channel).
            int[] projections = new int[inputPage.getBlockCount() - 1];

            for (int i = 0; i < inputPage.getBlockCount() - 1; i++) {
                projections[i] = i == scorePosition ? inputPage.getBlockCount() - 1 : i;
            }
            inputPages.removeFirst();
            outputPages.add(inputPage.projectBlocks(projections));
            inputPage.releaseBlocks();
        }
    }

    @Override
    public boolean isFinished() {
        return finished && outputPages.isEmpty();
    }

    @Override
    public Page getOutput() {
        if (finished == false || outputPages.isEmpty()) {
            return null;
        }
        return outputPages.removeFirst();
    }

    @Override
    public void close() {
        // Release anything still buffered, whether or not finish() was reached.
        for (Page page : inputPages) {
            page.releaseBlocks();
        }
        for (Page page : outputPages) {
            page.releaseBlocks();
        }
    }

    @Override
    public String toString() {
        return "LinearScoreEvalOperator[" + scorePosition + ", " + discriminatorPosition + "]";
    }

    /**
     * Maps a normalizer name (case-insensitive) to an implementation.
     * A {@code null} name means no normalization.
     *
     * @throws IllegalArgumentException for unknown names
     */
    private Normalizer createNormalizer(String normalizer) {
        return switch (normalizer == null ? null : normalizer.toLowerCase(Locale.ROOT)) {
            case "minmax" -> new MinMaxNormalizer();
            case "l2_norm" -> new L2NormNormalizer();
            case "none" -> new NoneNormalizer();
            case null -> new NoneNormalizer();
            default -> throw new IllegalArgumentException("Unknown normalizer: " + normalizer);
        };
    }

    /**
     * Strategy for rescaling scores. {@link #preprocess} runs one full pass over the
     * buffered input to collect per-discriminator statistics; {@link #normalize} is then
     * applied per row. Static nested classes: no enclosing-instance state is needed.
     */
    private abstract static class Normalizer {
        public abstract double normalize(double score, String discriminator);

        public abstract void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition);
    }

    /** Identity normalizer: scores pass through unchanged. */
    private static class NoneNormalizer extends Normalizer {
        @Override
        public double normalize(double score, String discriminator) {
            return score;
        }

        @Override
        public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {}
    }

    /** Divides each score by the L2 norm of all scores sharing its discriminator. */
    private static class L2NormNormalizer extends Normalizer {
        private final Map<String, Double> l2Norms = new HashMap<>();

        @Override
        public double normalize(double score, String discriminator) {
            var l2Norm = l2Norms.get(discriminator);
            assert l2Norm != null;
            // Guard against a zero norm (all scores for this discriminator were 0).
            return l2Norm == 0.0 ? 0.0 : score / l2Norm;
        }

        @Override
        public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
            BytesRef scratch = new BytesRef();
            for (Page inputPage : inputPages) {
                DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
                BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);

                for (int i = 0; i < inputPage.getPositionCount(); i++) {
                    double score = scoreBlock.getDouble(i);
                    String discriminator = discriminatorBlock.getBytesRef(i, scratch).utf8ToString();

                    // Accumulate sum of squares per discriminator.
                    l2Norms.merge(discriminator, score * score, Double::sum);
                }
            }

            // Convert each sum of squares into the L2 norm.
            l2Norms.replaceAll((k, v) -> Math.sqrt(v));
        }
    }

    /** Rescales each score to [0, 1] within its discriminator group via min-max scaling. */
    private static class MinMaxNormalizer extends Normalizer {
        private final Map<String, Double> minScores = new HashMap<>();
        private final Map<String, Double> maxScores = new HashMap<>();

        @Override
        public double normalize(double score, String discriminator) {
            var min = minScores.get(discriminator);
            var max = maxScores.get(discriminator);

            assert min != null;
            assert max != null;

            // Degenerate group (single value or all equal): avoid division by zero.
            if (min.equals(max)) {
                return 0.0;
            }

            return (score - min) / (max - min);
        }

        @Override
        public void preprocess(Collection<Page> inputPages, int scorePosition, int discriminatorPosition) {
            BytesRef scratch = new BytesRef();
            for (Page inputPage : inputPages) {
                DoubleVectorBlock scoreBlock = inputPage.getBlock(scorePosition);
                BytesRefBlock discriminatorBlock = inputPage.getBlock(discriminatorPosition);

                for (int i = 0; i < inputPage.getPositionCount(); i++) {
                    double score = scoreBlock.getDouble(i);
                    String discriminator = discriminatorBlock.getBytesRef(i, scratch).utf8ToString();

                    minScores.merge(discriminator, score, Math::min);
                    maxScores.merge(discriminator, score, Math::max);
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.compute.data.Page;

import java.util.HashMap;
import java.util.Map;

/**
* Updates the score column with new scores using the RRF formula.
Expand All @@ -23,10 +24,19 @@
*/
public class RrfScoreEvalOperator extends AbstractPageMappingOperator {

public record Factory(int forkPosition, int scorePosition) implements OperatorFactory {
private final double rankConstant;
private final Map<String, Double> weights;
private final int scorePosition;
private final int forkPosition;

private HashMap<String, Integer> counters = new HashMap<>();

public record Factory(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights)
implements
OperatorFactory {
@Override
public Operator get(DriverContext driverContext) {
return new RrfScoreEvalOperator(forkPosition, scorePosition);
return new RrfScoreEvalOperator(forkPosition, scorePosition, rankConstant, weights);
}

@Override
Expand All @@ -36,14 +46,11 @@ public String describe() {

}

private final int scorePosition;
private final int forkPosition;

private HashMap<String, Integer> counters = new HashMap<>();

public RrfScoreEvalOperator(int forkPosition, int scorePosition) {
public RrfScoreEvalOperator(int forkPosition, int scorePosition, double rankConstant, Map<String, Double> weights) {
this.scorePosition = scorePosition;
this.forkPosition = forkPosition;
this.rankConstant = rankConstant;
this.weights = weights;
}

@Override
Expand All @@ -57,7 +64,10 @@ protected Page process(Page page) {

int rank = counters.getOrDefault(fork, 1);
counters.put(fork, rank + 1);
scores.appendDouble(1.0 / (60 + rank));

var weight = weights.get(fork) == null ? 1.0 : weights.get(fork);

scores.appendDouble(1.0 / (rankConstant + rank) * weight);
}

Block scoreBlock = scores.build().asBlock();
Expand Down
Loading