Merged
Changes from 8 commits
5 changes: 5 additions & 0 deletions docs/changelog/125694.yaml
@@ -0,0 +1,5 @@
pr: 125694
summary: LTR score bounding
area: Ranking
type: bug
issues: []
BoundedInferenceModel.java (new file, 14 additions & 0 deletions)
@@ -0,0 +1,14 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

public interface BoundedInferenceModel extends InferenceModel {

    double getMinPredictedValue();

    double getMaxPredictedValue();
}
BoundedWindowInferenceModel.java (new file, 127 additions & 0 deletions)
@@ -0,0 +1,127 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.util.Map;

public class BoundedWindowInferenceModel implements BoundedInferenceModel {

    public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;

    private final BoundedInferenceModel model;
    private final double minPredictedValue;
    private final double maxPredictedValue;
    private final double adjustmentValue;

    public BoundedWindowInferenceModel(BoundedInferenceModel model) {
        this(model, model.getMinPredictedValue(), model.getMaxPredictedValue());
    }

    public BoundedWindowInferenceModel(BoundedInferenceModel model, double minPredictedValue, double maxPredictedValue) {
Contributor
The absolute minimum of 0 is only for LTR models, but we can imagine situations where we want to scale the result of a regression model into [-1; 1].

I would personally remove the adjustment from the constructor because it does not make sense there:

public BoundedWindowInferenceModel(BoundedInferenceModel model, double minPredictedValue, double maxPredictedValue) {
    this.model = model;
    this.minPredictedValue = minPredictedValue;
    this.maxPredictedValue = maxPredictedValue;
}

Then you can create a static method (scaleLtrModel?) that would do the following:

public static BoundedWindowInferenceModel scaleLtrModel(BoundedInferenceModel model) {
    double adjustment = LTR_MIN_PREDICTED_VALUE - model.getMinPredictedValue();
    return new BoundedWindowInferenceModel(model, model.getMinPredictedValue() + adjustment, model.getMaxPredictedValue() + adjustment);
}

Then you can use the formula from the POC to scale the prediction:

double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
// First we scale the data to [0, 1]
predictedValue = (predictedValue - model.getMinPredictedValue()) / (model.getMaxPredictedValue() - model.getMinPredictedValue());

// Then we scale the data to the desired interval
predictedValue = predictedValue * (getMaxPredictedValue() - getMinPredictedValue()) + getMinPredictedValue();

Also, I would rename the class to something like MinMaxScaledInferenceModel.

WDYT?
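[Editor's note: a minimal standalone sketch of the two-step rescaling proposed above; the class name, method, and numbers are illustrative and not part of the PR.]

public final class MinMaxRescaleSketch {
    // Maps a raw prediction from [rawMin, rawMax] into [targetMin, targetMax].
    static double rescale(double value, double rawMin, double rawMax, double targetMin, double targetMax) {
        double normalized = (value - rawMin) / (rawMax - rawMin); // step 1: [rawMin, rawMax] -> [0, 1]
        return normalized * (targetMax - targetMin) + targetMin;  // step 2: [0, 1] -> [targetMin, targetMax]
    }

    public static void main(String[] args) {
        // A model predicting in [-3, 5], rescaled into [-1, 1]:
        System.out.println(rescale(1.0, -3, 5, -1, 1)); // 0.0 (the midpoint maps to the midpoint)
        System.out.println(rescale(5.0, -3, 5, -1, 1)); // 1.0
    }
}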

Contributor Author

> The absolute minimum of 0 is only for LTR models, but we can imagine situations where we want to scale the result of a regression model into [-1; 1].

I thought we had decided with @jimczi that we would not perform scaling, but rather slide the scores up to ensure they are all positive (and only if the minimum score was negative). Correct?

> The absolute minimum of 0 is only for LTR models, but we can imagine situations where we want to scale the result of a regression model into [-1; 1].

That makes sense. Is there a need for this now to fix this bug, though? I was under the impression that the bug was only about negative scores being returned and us having to deal with that if true.

Contributor

I proposed "sliding" the score because we can apply it to the minimum and maximum values of the model as a whole (instead of per query). This means that scores will still be comparable across queries.
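[Editor's note: a toy illustration of that property, with hypothetical numbers not taken from the PR. Because the shift is derived from the model-wide minimum, every query applies the same offset, so score gaps and cross-query ordering are untouched.]

public class ModelWideShiftSketch {
    public static void main(String[] args) {
        double modelMin = -2.0;            // minimum the model can ever predict
        double shift = 0.0 - modelMin;     // +2.0, computed once per model, not per query

        double queryAScore = 1.5 + shift;  // 3.5
        double queryBScore = 0.5 + shift;  // 2.5

        // The 1.0 gap between the two queries' scores is preserved exactly.
        // A per-query shift (e.g. by each query's own lowest hit) would apply
        // different offsets and make scores incomparable across queries.
        System.out.println(queryAScore - queryBScore); // 1.0
    }
}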

Contributor Author

Thanks @jimczi - what are your thoughts on the PR as-is, then?

Contributor

Sounds good. Let's use the new minPredictedValue and maxPredictedValue from the BoundedInferenceModel directly, no need to make it configurable for now.

Contributor Author
@markjhoy Apr 1, 2025

@afoucret -

> Then you can use the formula from the POC to scale the prediction:
>
> double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
> // First we scale the data to [0, 1]
> predictedValue = (predictedValue - model.getMinPredictedValue()) / (model.getMaxPredictedValue() - model.getMinPredictedValue());

Looking at this here -- wouldn't scaling in this method (normalizing to [0, 1.0] first) fall victim to exactly what we're trying to avoid? That is, it may compress the score space, lose precision, and cause closely scored items to end up with equal rank.

I think it's a good idea overall, and certainly more flexible for the developer... so I'm on the fence about it. cc: @jimczi
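[Editor's note: a quick numeric illustration of the compression concern; the bounds and scores are hypothetical, not from the PR. Min-max scaling divides every score gap by the width of the model's predicted range, while sliding leaves gaps untouched.]

public class ScaleVsSlideSketch {
    public static void main(String[] args) {
        double min = -1000.0, max = 1000.0;  // model-wide prediction bounds
        double s1 = 4.10, s2 = 4.11;         // two closely ranked raw scores

        // Min-max scaling divides the 0.01 gap by the 2000-point range width:
        double g1 = (s1 - min) / (max - min);
        double g2 = (s2 - min) / (max - min);
        System.out.println(g2 - g1);         // ~5.0e-6

        // Sliding by a constant keeps the gap intact:
        System.out.println((s2 + 1000.0) - (s1 + 1000.0)); // ~0.01
    }
}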

Contributor Author

@afoucret - the more I think about this: for the time being, let's go with this solution, and if we decide to add your suggestion, let's do that afterwards so we can unblock your work.

        this.model = model;
        this.minPredictedValue = minPredictedValue;
        this.maxPredictedValue = maxPredictedValue;

        if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
            this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
        } else {
            this.adjustmentValue = 0.0;
        }
    }

    @Override
    public String[] getFeatureNames() {
        return model.getFeatureNames();
    }

    @Override
    public TargetType targetType() {
        return model.targetType();
    }

    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
    }

    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return boundInferenceResultScores(model.infer(features, config));
    }

    @Override
    public boolean supportsFeatureImportance() {
        return model.supportsFeatureImportance();
    }

    @Override
    public String getName() {
        return "bounded_window[" + model.getName() + "]";
    }

    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        model.rewriteFeatureIndices(newFeatureIndexMapping);
    }

    @Override
    public long ramBytesUsed() {
        return model.ramBytesUsed();
    }

    @Override
    public double getMinPredictedValue() {
        return minPredictedValue;
    }

    @Override
    public double getMaxPredictedValue() {
        return maxPredictedValue;
    }

    private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
        // if the min value < the default minimum, slide the values up by the adjustment value
        if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
            double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();

            predictedValue += this.adjustmentValue;

            return new RegressionInferenceResults(
                predictedValue,
                inferenceResult.getResultsField(),
                regressionInferenceResults.getFeatureImportance()
            );
        }

        throw new IllegalStateException(
            LoggerMessageFormat.format(
                "Model used within a {} should return a {} but got {} instead",
                BoundedWindowInferenceModel.class.getSimpleName(),
                RegressionInferenceResults.class.getSimpleName(),
                inferenceResult.getClass().getSimpleName()
            )
        );
    }

    @Override
    public String toString() {
        return "BoundedWindowInferenceModel{"
            + "model="
            + model
            + ", minPredictedValue="
            + getMinPredictedValue()
            + ", maxPredictedValue="
            + getMaxPredictedValue()
            + '}';
    }
}
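[Editor's note: to make the class's behavior concrete, a minimal standalone sketch of the sliding logic, using plain doubles with hypothetical bounds instead of the InferenceResults types used in the PR.]

public class SlidingWindowSketch {
    public static void main(String[] args) {
        double minPredictedValue = -2.0;  // model-wide lower bound
        double defaultMin = 0.0;          // mirrors DEFAULT_MIN_PREDICTED_VALUE

        // Same computation as the BoundedWindowInferenceModel constructor above:
        double adjustment = minPredictedValue < defaultMin ? defaultMin - minPredictedValue : 0.0;

        // Every prediction is shifted by the same amount, so ordering is preserved
        // and the lowest possible score becomes exactly 0:
        System.out.println(-2.0 + adjustment); // 0.0
        System.out.println(-0.5 + adjustment); // 1.5
        System.out.println(3.0 + adjustment);  // 5.0
    }
}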
EnsembleInferenceModel.java
@@ -11,6 +11,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.CachedSupplier;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -52,7 +54,7 @@
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;

-public class EnsembleInferenceModel implements InferenceModel {
+public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {

    public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
    private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) {
    private final List<String> classificationLabels;
    private final double[] classificationWeights;
    private volatile boolean preparedForInference = false;
    private final Supplier<double[]> predictedValuesBoundariesSupplier;

    private EnsembleInferenceModel(
        List<InferenceModel> models,
@@ -112,6 +115,7 @@ private EnsembleInferenceModel(
        this.classificationWeights = classificationWeights == null
            ? null
            : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
        this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
    }

    @Override
Expand Down Expand Up @@ -328,21 +332,57 @@ public double[] getClassificationWeights() {

    @Override
    public String toString() {
-        return "EnsembleInferenceModel{"
-            + "featureNames="
-            + Arrays.toString(featureNames)
-            + ", models="
-            + models
-            + ", outputAggregator="
-            + outputAggregator
-            + ", targetType="
-            + targetType
-            + ", classificationLabels="
-            + classificationLabels
-            + ", classificationWeights="
-            + Arrays.toString(classificationWeights)
-            + ", preparedForInference="
-            + preparedForInference
-            + '}';
+        StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
+
+        builder.append("featureNames=")
+            .append(Arrays.toString(featureNames))
+            .append(", models=")
+            .append(models)
+            .append(", outputAggregator=")
+            .append(outputAggregator)
+            .append(", targetType=")
+            .append(targetType);
+
+        if (targetType == TargetType.CLASSIFICATION) {
+            builder.append(", classificationLabels=")
+                .append(classificationLabels)
+                .append(", classificationWeights=")
+                .append(Arrays.toString(classificationWeights));
+        } else if (targetType == TargetType.REGRESSION) {
+            builder.append(", minPredictedValue=")
+                .append(getMinPredictedValue())
+                .append(", maxPredictedValue=")
+                .append(getMaxPredictedValue());
+        }
+
+        builder.append(", preparedForInference=").append(preparedForInference);
+
+        return builder.append('}').toString();
    }

    @Override
    public double getMinPredictedValue() {
        return this.predictedValuesBoundariesSupplier.get()[0];
    }

    @Override
    public double getMaxPredictedValue() {
        return this.predictedValuesBoundariesSupplier.get()[1];
    }

    private double[] initModelBoundaries() {
        double[] modelsMinBoundaries = new double[models.size()];
        double[] modelsMaxBoundaries = new double[models.size()];
        int i = 0;
        for (InferenceModel model : models) {
            if (model instanceof BoundedInferenceModel boundedInferenceModel) {
                modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
                modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
            } else {
                throw new IllegalStateException("All submodels have to be bounded");
            }
        }

        return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
    }
}
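[Editor's note: a standalone sketch of how per-model bounds compose into ensemble bounds, under the assumption that the output aggregator is a weighted sum with non-negative weights, and therefore monotonic in each input; that monotonicity is what makes feeding the aggregator the per-model minima/maxima yield valid ensemble bounds. The helper and numbers here are hypothetical, not the PR's OutputAggregator API.]

class BoundarySketch {
    // A weighted sum with non-negative weights is monotonic in each input,
    // so aggregating the per-model minima yields the lowest possible ensemble score.
    static double weightedSum(double[] values, double[] weights) {
        double sum = 0;
        for (int i = 0; i < values.length; i++) {
            sum += values[i] * weights[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] treeMins = { -1.0, -0.5, 0.0 }; // per-tree minimum leaf values
        double[] treeMaxs = { 2.0, 1.5, 1.0 };   // per-tree maximum leaf values
        double[] weights = { 0.5, 0.3, 0.2 };

        System.out.println(weightedSum(treeMins, weights)); // -0.65: ensemble minimum
        System.out.println(weightedSum(treeMaxs, weights)); // 1.65: ensemble maximum
    }
}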
InferenceDefinition.java
@@ -14,6 +14,7 @@
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

@@ -79,13 +80,21 @@ private void preProcess(Map<String, Object> fields) {

    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
        preProcess(fields);

        InferenceModel inferenceModel = trainedModel;

        if (config instanceof LearningToRankConfig) {
            assert trainedModel instanceof BoundedInferenceModel;
            inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel);
        }

        if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
            throw ExceptionsHelper.badRequestException(
                "Feature importance is not supported for the configured model of type [{}]",
                trainedModel.getName()
            );
        }
-        return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
+        return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
    }

    public TargetType getTargetType() {