Skip to content

Commit d6d323a

Browse files
markjhoy and elasticsearchmachine authored
Add Bounded Window to Inference Models for Rescoring to Ensure Positive Score Range (#125694) (#127345)
* apply bounded window inference model
* linting
* add unit tests
* [CI] Auto commit changes from spotless
* add additional tests
* remove unused constructor

---------

Co-authored-by: elasticsearchmachine <[email protected]>

(cherry picked from commit e77bf80)
1 parent d4dc01c commit d6d323a

File tree

10 files changed

+471
-48
lines changed

10 files changed

+471
-48
lines changed

docs/changelog/125694.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125694
2+
summary: LTR score bounding
3+
area: Ranking
4+
type: bug
5+
issues: []
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
9+
10+
public interface BoundedInferenceModel extends InferenceModel {
11+
double getMinPredictedValue();
12+
13+
double getMaxPredictedValue();
14+
}
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
9+
10+
import org.elasticsearch.common.logging.LoggerMessageFormat;
11+
import org.elasticsearch.inference.InferenceResults;
12+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
13+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
14+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
15+
16+
import java.util.Map;
17+
18+
public class BoundedWindowInferenceModel implements BoundedInferenceModel {
19+
public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;
20+
21+
private final BoundedInferenceModel model;
22+
private final double minPredictedValue;
23+
private final double maxPredictedValue;
24+
private final double adjustmentValue;
25+
26+
public BoundedWindowInferenceModel(BoundedInferenceModel model) {
27+
this.model = model;
28+
this.minPredictedValue = model.getMinPredictedValue();
29+
this.maxPredictedValue = model.getMaxPredictedValue();
30+
31+
if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) {
32+
this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue;
33+
} else {
34+
this.adjustmentValue = 0.0;
35+
}
36+
}
37+
38+
@Override
39+
public String[] getFeatureNames() {
40+
return model.getFeatureNames();
41+
}
42+
43+
@Override
44+
public TargetType targetType() {
45+
return model.targetType();
46+
}
47+
48+
@Override
49+
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
50+
return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap));
51+
}
52+
53+
@Override
54+
public InferenceResults infer(double[] features, InferenceConfig config) {
55+
return boundInferenceResultScores(model.infer(features, config));
56+
}
57+
58+
@Override
59+
public boolean supportsFeatureImportance() {
60+
return model.supportsFeatureImportance();
61+
}
62+
63+
@Override
64+
public String getName() {
65+
return "bounded_window[" + model.getName() + "]";
66+
}
67+
68+
@Override
69+
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
70+
model.rewriteFeatureIndices(newFeatureIndexMapping);
71+
}
72+
73+
@Override
74+
public long ramBytesUsed() {
75+
return model.ramBytesUsed();
76+
}
77+
78+
@Override
79+
public double getMinPredictedValue() {
80+
return minPredictedValue;
81+
}
82+
83+
@Override
84+
public double getMaxPredictedValue() {
85+
return maxPredictedValue;
86+
}
87+
88+
private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) {
89+
// if the min value < the default minimum, slide the values up by the adjustment value
90+
if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
91+
double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();
92+
93+
predictedValue += this.adjustmentValue;
94+
95+
return new RegressionInferenceResults(
96+
predictedValue,
97+
inferenceResult.getResultsField(),
98+
((RegressionInferenceResults) inferenceResult).getFeatureImportance()
99+
);
100+
}
101+
102+
throw new IllegalStateException(
103+
LoggerMessageFormat.format(
104+
"Model used within a {} should return a {} but got {} instead",
105+
BoundedWindowInferenceModel.class.getSimpleName(),
106+
RegressionInferenceResults.class.getSimpleName(),
107+
inferenceResult.getClass().getSimpleName()
108+
)
109+
);
110+
}
111+
112+
@Override
113+
public String toString() {
114+
return "BoundedWindowInferenceModel{"
115+
+ "model="
116+
+ model
117+
+ ", minPredictedValue="
118+
+ getMinPredictedValue()
119+
+ ", maxPredictedValue="
120+
+ getMaxPredictedValue()
121+
+ '}';
122+
}
123+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.RamUsageEstimator;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.util.CachedSupplier;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.core.Tuple;
1617
import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@
3637
import java.util.List;
3738
import java.util.Map;
3839
import java.util.Set;
40+
import java.util.function.Supplier;
3941
import java.util.stream.Collectors;
4042
import java.util.stream.IntStream;
4143

@@ -52,7 +54,7 @@
5254
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
5355
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
5456

55-
public class EnsembleInferenceModel implements InferenceModel {
57+
public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
5658

5759
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
5860
private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) {
9799
private final List<String> classificationLabels;
98100
private final double[] classificationWeights;
99101
private volatile boolean preparedForInference = false;
102+
private final Supplier<double[]> predictedValuesBoundariesSupplier;
100103

101104
private EnsembleInferenceModel(
102105
List<InferenceModel> models,
@@ -112,6 +115,7 @@ private EnsembleInferenceModel(
112115
this.classificationWeights = classificationWeights == null
113116
? null
114117
: classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
118+
this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
115119
}
116120

117121
@Override
@@ -328,21 +332,57 @@ public double[] getClassificationWeights() {
328332

329333
@Override
330334
public String toString() {
331-
return "EnsembleInferenceModel{"
332-
+ "featureNames="
333-
+ Arrays.toString(featureNames)
334-
+ ", models="
335-
+ models
336-
+ ", outputAggregator="
337-
+ outputAggregator
338-
+ ", targetType="
339-
+ targetType
340-
+ ", classificationLabels="
341-
+ classificationLabels
342-
+ ", classificationWeights="
343-
+ Arrays.toString(classificationWeights)
344-
+ ", preparedForInference="
345-
+ preparedForInference
346-
+ '}';
335+
StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
336+
337+
builder.append("featureNames=")
338+
.append(Arrays.toString(featureNames))
339+
.append(", models=")
340+
.append(models)
341+
.append(", outputAggregator=")
342+
.append(outputAggregator)
343+
.append(", targetType=")
344+
.append(targetType);
345+
346+
if (targetType == TargetType.CLASSIFICATION) {
347+
builder.append(", classificationLabels=")
348+
.append(classificationLabels)
349+
.append(", classificationWeights=")
350+
.append(Arrays.toString(classificationWeights));
351+
} else if (targetType == TargetType.REGRESSION) {
352+
builder.append(", minPredictedValue=")
353+
.append(getMinPredictedValue())
354+
.append(", maxPredictedValue=")
355+
.append(getMaxPredictedValue());
356+
}
357+
358+
builder.append(", preparedForInference=").append(preparedForInference);
359+
360+
return builder.append('}').toString();
361+
}
362+
363+
@Override
364+
public double getMinPredictedValue() {
365+
return this.predictedValuesBoundariesSupplier.get()[0];
366+
}
367+
368+
@Override
369+
public double getMaxPredictedValue() {
370+
return this.predictedValuesBoundariesSupplier.get()[1];
371+
}
372+
373+
private double[] initModelBoundaries() {
374+
double[] modelsMinBoundaries = new double[models.size()];
375+
double[] modelsMaxBoundaries = new double[models.size()];
376+
int i = 0;
377+
for (InferenceModel model : models) {
378+
if (model instanceof BoundedInferenceModel boundedInferenceModel) {
379+
modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
380+
modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
381+
} else {
382+
throw new IllegalStateException("All submodels have to be bounded");
383+
}
384+
}
385+
386+
return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) };
347387
}
348388
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
1515
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
1616
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
17+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
1718
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
1819
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1920

@@ -79,13 +80,21 @@ private void preProcess(Map<String, Object> fields) {
7980

8081
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
8182
preProcess(fields);
83+
84+
InferenceModel inferenceModel = trainedModel;
85+
86+
if (config instanceof LearningToRankConfig) {
87+
assert trainedModel instanceof BoundedInferenceModel;
88+
inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel);
89+
}
90+
8291
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
8392
throw ExceptionsHelper.badRequestException(
8493
"Feature importance is not supported for the configured model of type [{}]",
8594
trainedModel.getName()
8695
);
8796
}
88-
return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
97+
return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
8998
}
9099

91100
public TargetType getTargetType() {

0 commit comments

Comments
 (0)