Skip to content

Commit 1b58c0d

Browse files
committed
LTR score normalization POC.
1 parent 67af069 commit 1b58c0d

File tree

5 files changed

+257
-50
lines changed

5 files changed

+257
-50
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;
9+
10+
public interface BoundedInferenceModel extends InferenceModel {
11+
double getMinPredictedValue();
12+
double getMaxPredictedValue();
13+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

Lines changed: 57 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.RamUsageEstimator;
1313
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.common.util.CachedSupplier;
1415
import org.elasticsearch.core.Nullable;
1516
import org.elasticsearch.core.Tuple;
1617
import org.elasticsearch.inference.InferenceResults;
@@ -36,6 +37,7 @@
3637
import java.util.List;
3738
import java.util.Map;
3839
import java.util.Set;
40+
import java.util.function.Supplier;
3941
import java.util.stream.Collectors;
4042
import java.util.stream.IntStream;
4143

@@ -52,7 +54,7 @@
5254
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
5355
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS;
5456

55-
public class EnsembleInferenceModel implements InferenceModel {
57+
public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel {
5658

5759
public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class);
5860
private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class);
@@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) {
9799
private final List<String> classificationLabels;
98100
private final double[] classificationWeights;
99101
private volatile boolean preparedForInference = false;
102+
private final Supplier<double[]> predictedValuesBoundariesSupplier;
100103

101104
private EnsembleInferenceModel(
102105
List<InferenceModel> models,
@@ -112,6 +115,7 @@ private EnsembleInferenceModel(
112115
this.classificationWeights = classificationWeights == null
113116
? null
114117
: classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
118+
this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries);
115119
}
116120

117121
@Override
@@ -328,21 +332,57 @@ public double[] getClassificationWeights() {
328332

329333
@Override
330334
public String toString() {
331-
return "EnsembleInferenceModel{"
332-
+ "featureNames="
333-
+ Arrays.toString(featureNames)
334-
+ ", models="
335-
+ models
336-
+ ", outputAggregator="
337-
+ outputAggregator
338-
+ ", targetType="
339-
+ targetType
340-
+ ", classificationLabels="
341-
+ classificationLabels
342-
+ ", classificationWeights="
343-
+ Arrays.toString(classificationWeights)
344-
+ ", preparedForInference="
345-
+ preparedForInference
346-
+ '}';
335+
StringBuilder builder = new StringBuilder("EnsembleInferenceModel{");
336+
337+
builder.append("featureNames=")
338+
.append(Arrays.toString(featureNames))
339+
.append(", models=")
340+
.append(models)
341+
.append(", outputAggregator=")
342+
.append(outputAggregator)
343+
.append(", targetType=")
344+
.append(targetType);
345+
346+
if (targetType == TargetType.CLASSIFICATION) {
347+
builder.append(", classificationLabels=")
348+
.append(classificationLabels)
349+
.append(", classificationWeights=")
350+
.append(Arrays.toString(classificationWeights));
351+
} else if (targetType == TargetType.REGRESSION) {
352+
builder.append(", minPredictedValue=")
353+
.append(getMinPredictedValue())
354+
.append(", maxPredictedValue=")
355+
.append(getMaxPredictedValue());
356+
}
357+
358+
builder.append(", preparedForInference=").append(preparedForInference);
359+
360+
return builder.append('}').toString();
361+
}
362+
363+
@Override
364+
public double getMinPredictedValue() {
365+
return this.predictedValuesBoundariesSupplier.get()[0];
366+
}
367+
368+
@Override
369+
public double getMaxPredictedValue() {
370+
return this.predictedValuesBoundariesSupplier.get()[1];
371+
}
372+
373+
private double[] initModelBoundaries() {
374+
double[] modelsMinBoundaries = new double[models.size()];
375+
double[] modelsMaxBoundaries = new double[models.size()];
376+
int i = 0;
377+
for (InferenceModel model : models) {
378+
if (model instanceof BoundedInferenceModel boundedInferenceModel) {
379+
modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue();
380+
modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue();
381+
} else {
382+
throw new IllegalStateException("All submodels have to be bounded");
383+
}
384+
}
385+
386+
return new double[] {outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries)};
347387
}
348388
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
1515
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
1616
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
17+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
1718
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
1819
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1920

@@ -79,13 +80,22 @@ private void preProcess(Map<String, Object> fields) {
7980

8081
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
8182
preProcess(fields);
82-
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
83+
84+
InferenceModel inferenceModel = trainedModel;
85+
86+
if (config instanceof LearningToRankConfig) {
87+
assert trainedModel instanceof BoundedInferenceModel;
88+
inferenceModel = new ScaledInferenceModel(BoundedInferenceModel.class.cast(trainedModel));
89+
}
90+
91+
if (config.requestingImportance() && inferenceModel.supportsFeatureImportance() == false) {
8392
throw ExceptionsHelper.badRequestException(
8493
"Feature importance is not supported for the configured model of type [{}]",
85-
trainedModel.getName()
94+
inferenceModel.getName()
8695
);
8796
}
88-
return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
97+
98+
return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
8999
}
90100

91101
public TargetType getTargetType() {
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference;

import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.inference.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;

import java.util.Map;

/**
 * Decorator around a {@link BoundedInferenceModel} that min-max rescales
 * regression results from the wrapped model's predicted-value range into a
 * target interval ([{@value #DEFAULT_MIN_PREDICTED_VALUE},
 * {@value #DEFAULT_MAX_PREDICTED_VALUE}] by default). All other behavior is
 * delegated to the wrapped model unchanged.
 */
public class ScaledInferenceModel implements BoundedInferenceModel {

    public static final double DEFAULT_MIN_PREDICTED_VALUE = 0;
    public static final double DEFAULT_MAX_PREDICTED_VALUE = 1;

    private final BoundedInferenceModel model;
    private final double minPredictedValue;
    private final double maxPredictedValue;

    /**
     * Wraps {@code model} so its predictions are rescaled into the default
     * [0, 1] interval.
     */
    public ScaledInferenceModel(BoundedInferenceModel model) {
        this(model, DEFAULT_MIN_PREDICTED_VALUE, DEFAULT_MAX_PREDICTED_VALUE);
    }

    /**
     * Wraps {@code model} so its predictions are rescaled into
     * [{@code minPredictedValue}, {@code maxPredictedValue}].
     */
    public ScaledInferenceModel(BoundedInferenceModel model, double minPredictedValue, double maxPredictedValue) {
        this.model = model;
        this.minPredictedValue = minPredictedValue;
        this.maxPredictedValue = maxPredictedValue;
    }

    @Override
    public String[] getFeatureNames() {
        return model.getFeatureNames();
    }

    @Override
    public TargetType targetType() {
        return model.targetType();
    }

    @Override
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) {
        return scaleInferenceResult(model.infer(fields, config, featureDecoderMap));
    }

    @Override
    public InferenceResults infer(double[] features, InferenceConfig config) {
        return scaleInferenceResult(model.infer(features, config));
    }

    @Override
    public boolean supportsFeatureImportance() {
        return model.supportsFeatureImportance();
    }

    @Override
    public String getName() {
        return "scaled[" + model.getName() + "]";
    }

    @Override
    public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) {
        model.rewriteFeatureIndices(newFeatureIndexMapping);
    }

    @Override
    public long ramBytesUsed() {
        return model.ramBytesUsed();
    }

    @Override
    public double getMinPredictedValue() {
        return minPredictedValue;
    }

    @Override
    public double getMaxPredictedValue() {
        return maxPredictedValue;
    }

    /**
     * Min-max rescales a regression result from the wrapped model's range into
     * this model's [minPredictedValue, maxPredictedValue] interval, preserving
     * the results field and feature importance.
     *
     * @throws IllegalStateException if the wrapped model did not produce a
     *         {@link RegressionInferenceResults}
     */
    private InferenceResults scaleInferenceResult(InferenceResults inferenceResult) {
        if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) {
            double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue();

            // First scale the prediction into [0, 1] using the wrapped model's bounds.
            // NOTE(review): if the wrapped model's min equals its max this divides by
            // zero and produces NaN/Infinity — confirm callers guarantee a
            // non-degenerate range, or guard here.
            predictedValue = (predictedValue - model.getMinPredictedValue())
                / (model.getMaxPredictedValue() - model.getMinPredictedValue());

            // Then map the [0, 1] value onto the configured target interval.
            predictedValue = predictedValue * (getMaxPredictedValue() - getMinPredictedValue()) + getMinPredictedValue();

            // Reuse the already-bound pattern variable; no second cast needed.
            return new RegressionInferenceResults(
                predictedValue,
                inferenceResult.getResultsField(),
                regressionInferenceResults.getFeatureImportance()
            );
        }

        throw new IllegalStateException(
            LoggerMessageFormat.format(
                "Model used within a {} should return a {} but got {} instead",
                ScaledInferenceModel.class.getSimpleName(),
                RegressionInferenceResults.class.getSimpleName(),
                inferenceResult.getClass().getSimpleName()
            )
        );
    }

    @Override
    public String toString() {
        return "ScaledInferenceModel{"
            + "model="
            + model
            + ", minPredictedValue="
            + getMinPredictedValue()
            + ", maxPredictedValue="
            + getMaxPredictedValue()
            + '}';
    }
}

0 commit comments

Comments
 (0)