diff --git a/docs/changelog/124255.yaml b/docs/changelog/124255.yaml new file mode 100644 index 0000000000000..5b9d829b0dcba --- /dev/null +++ b/docs/changelog/124255.yaml @@ -0,0 +1,5 @@ +pr: 124255 +summary: LTR score normalization +area: Ranking +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java new file mode 100644 index 0000000000000..f6516ec8c446f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; + +public interface BoundedInferenceModel extends InferenceModel { + double getMinPredictedValue(); + + double getMaxPredictedValue(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index c0e8610a357b0..20677b16c6a61 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InferenceResults; @@ -36,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -52,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS; -public class EnsembleInferenceModel implements InferenceModel { +public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel { public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class); @@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) { 
private final List classificationLabels; private final double[] classificationWeights; private volatile boolean preparedForInference = false; + private final Supplier predictedValuesBoundariesSupplier; private EnsembleInferenceModel( List models, @@ -112,6 +115,7 @@ private EnsembleInferenceModel( this.classificationWeights = classificationWeights == null ? null : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray(); + this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries); } @Override @@ -328,21 +332,57 @@ public double[] getClassificationWeights() { @Override public String toString() { - return "EnsembleInferenceModel{" - + "featureNames=" - + Arrays.toString(featureNames) - + ", models=" - + models - + ", outputAggregator=" - + outputAggregator - + ", targetType=" - + targetType - + ", classificationLabels=" - + classificationLabels - + ", classificationWeights=" - + Arrays.toString(classificationWeights) - + ", preparedForInference=" - + preparedForInference - + '}'; + StringBuilder builder = new StringBuilder("EnsembleInferenceModel{"); + + builder.append("featureNames=") + .append(Arrays.toString(featureNames)) + .append(", models=") + .append(models) + .append(", outputAggregator=") + .append(outputAggregator) + .append(", targetType=") + .append(targetType); + + if (targetType == TargetType.CLASSIFICATION) { + builder.append(", classificationLabels=") + .append(classificationLabels) + .append(", classificationWeights=") + .append(Arrays.toString(classificationWeights)); + } else if (targetType == TargetType.REGRESSION) { + builder.append(", minPredictedValue=") + .append(getMinPredictedValue()) + .append(", maxPredictedValue=") + .append(getMaxPredictedValue()); + } + + builder.append(", preparedForInference=").append(preparedForInference); + + return builder.append('}').toString(); + } + + @Override + public double getMinPredictedValue() { + return this.predictedValuesBoundariesSupplier.get()[0]; + } + 
+ @Override + public double getMaxPredictedValue() { + return this.predictedValuesBoundariesSupplier.get()[1]; + } + + private double[] initModelBoundaries() { + double[] modelsMinBoundaries = new double[models.size()]; + double[] modelsMaxBoundaries = new double[models.size()]; + int i = 0; + for (InferenceModel model : models) { + if (model instanceof BoundedInferenceModel boundedInferenceModel) { + modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue(); + modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue(); + } else { + throw new IllegalStateException("All submodels have to be bounded"); + } + } + + return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) }; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 636cbcac725f4..1b0f92b92beaa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -79,13 +80,22 @@ private void preProcess(Map fields) { public InferenceResults infer(Map fields, InferenceConfig config) { preProcess(fields); - if 
(config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) { + + InferenceModel inferenceModel = trainedModel; + + if (config instanceof LearningToRankConfig) { + assert trainedModel instanceof BoundedInferenceModel; + inferenceModel = new ScaledInferenceModel(BoundedInferenceModel.class.cast(trainedModel)); + } + + if (config.requestingImportance() && inferenceModel.supportsFeatureImportance() == false) { throw ExceptionsHelper.badRequestException( "Feature importance is not supported for the configured model of type [{}]", - trainedModel.getName() + inferenceModel.getName() ); } - return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap()); + + return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap()); } public TargetType getTargetType() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/ScaledInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/ScaledInferenceModel.java new file mode 100644 index 0000000000000..21c8a1cbeac3e --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/ScaledInferenceModel.java @@ -0,0 +1,125 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; + +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.util.Map; + +public class ScaledInferenceModel implements BoundedInferenceModel { + + public static final double DEFAULT_MIN_PREDICTED_VALUE = 0; + public static final double DEFAULT_MAX_PREDICTED_VALUE = 1; + + private final BoundedInferenceModel model; + private final double minPredictedValue; + private final double maxPredictedValue; + + public ScaledInferenceModel(BoundedInferenceModel model) { + this(model, DEFAULT_MIN_PREDICTED_VALUE, DEFAULT_MAX_PREDICTED_VALUE); + } + + public ScaledInferenceModel(BoundedInferenceModel model, double minPredictedValue, double maxPredictedValue) { + this.model = model; + this.minPredictedValue = minPredictedValue; + this.maxPredictedValue = maxPredictedValue; + } + + @Override + public String[] getFeatureNames() { + return model.getFeatureNames(); + } + + @Override + public TargetType targetType() { + return model.targetType(); + } + + @Override + public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { + return scaleInferenceResult(model.infer(fields, config, featureDecoderMap)); + } + + @Override + public InferenceResults infer(double[] features, InferenceConfig config) { + return scaleInferenceResult(model.infer(features, config)); + } + + @Override + public boolean supportsFeatureImportance() { + return model.supportsFeatureImportance(); + } + + @Override + public String getName() { + return "scaled[" + model.getName() + "]"; + } + + @Override + public void rewriteFeatureIndices(Map newFeatureIndexMapping) { + 
model.rewriteFeatureIndices(newFeatureIndexMapping); + } + + @Override + public long ramBytesUsed() { + return model.ramBytesUsed(); + } + + @Override + public double getMinPredictedValue() { + return minPredictedValue; + } + + @Override + public double getMaxPredictedValue() { + return maxPredictedValue; + } + + private InferenceResults scaleInferenceResult(InferenceResults inferenceResult) { + if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) { + double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue(); + // First we scale the data to [0 ,1] + predictedValue = (predictedValue - model.getMinPredictedValue()) / (model.getMaxPredictedValue() - model + .getMinPredictedValue()); + + // Then we scale the data to the desired interval + predictedValue = predictedValue * (getMaxPredictedValue() - getMinPredictedValue()) + getMinPredictedValue(); + + return new RegressionInferenceResults( + predictedValue, + inferenceResult.getResultsField(), + ((RegressionInferenceResults) inferenceResult).getFeatureImportance() + ); + } + + throw new IllegalStateException( + LoggerMessageFormat.format( + "Model used within a {} should return a {} but got {} instead", + ScaledInferenceModel.class.getSimpleName(), + RegressionInferenceResults.class.getSimpleName(), + inferenceResult.getClass().getSimpleName() + ) + ); + } + + @Override + public String toString() { + return "ScaledInferenceModel{" + + "model=" + + model + + ", minPredictedValue=" + + getMinPredictedValue() + + ", maxPredictedValue=" + + getMaxPredictedValue() + + '}'; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index d23600383b34b..ca7ca6ffa2d82 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD; -public class TreeInferenceModel implements InferenceModel { +public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel { private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class); public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class); @@ -90,7 +90,7 @@ public static TreeInferenceModel fromXContent(XContentParser parser) { private String[] featureNames; private final TargetType targetType; private List classificationLabels; - private final double highOrderCategory; + private final double[] leafBoundaries; private final int maxDepth; private final int leafSize; private volatile boolean preparedForInference = false; @@ -108,7 +108,7 @@ public static TreeInferenceModel fromXContent(XContentParser parser) { this.nodes = nodes.stream().map(NodeBuilder::build).toArray(Node[]::new); this.targetType = targetType == null ? TargetType.REGRESSION : targetType; this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); - this.highOrderCategory = maxLeafValue(); + this.leafBoundaries = leafBoundaries(); int leafSize = 1; for (Node node : this.nodes) { if (node instanceof LeafNode leafNode) { @@ -218,7 +218,7 @@ private double[] classificationProbability(double[] inferenceValue) { } // If we are classification, we should assume that the inference return value is whole. 
assert inferenceValue[0] == Math.rint(inferenceValue[0]); - double maxCategory = this.highOrderCategory; + double maxCategory = highOrderCategory(); // If we are classification, we should assume that the largest leaf value is whole. assert maxCategory == Math.rint(maxCategory); double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0) @@ -229,6 +229,10 @@ private double[] classificationProbability(double[] inferenceValue) { return list; } + private double highOrderCategory() { + return getMaxPredictedValue(); + } + private double[] getLeaf(double[] features) { Node node = nodes[0]; while (node.isLeaf() == false) { @@ -237,6 +241,16 @@ private double[] getLeaf(double[] features) { return ((LeafNode) node).leafValue; } + @Override + public double getMinPredictedValue() { + return leafBoundaries[0]; + } + + @Override + public double getMaxPredictedValue() { + return leafBoundaries[1]; + } + public double[][] featureImportance(double[] fieldValues) { double[][] featureImportance = new double[fieldValues.length][leafSize]; for (int i = 0; i < fieldValues.length; i++) { @@ -366,21 +380,21 @@ public long ramBytesUsed() { return size; } - private double maxLeafValue() { - if (targetType != TargetType.CLASSIFICATION) { - return Double.NaN; - } - double max = 0.0; + private double[] leafBoundaries() { + double[] bounds = new double[] { Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY }; + for (Node node : this.nodes) { if (node instanceof LeafNode leafNode) { if (leafNode.leafValue.length > 1) { - return leafNode.leafValue.length; + return new double[] { 0, leafNode.leafValue.length }; } else { - bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]); + bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]); } } } - return max; + + return bounds; } public Node[] getNodes() { @@ -389,24 +403,35 @@ public Node[] getNodes() { @Override public String toString() { - return "TreeInferenceModel{" - + "nodes=" - + 
Arrays.toString(nodes) - + ", featureNames=" - + Arrays.toString(featureNames) - + ", targetType=" - + targetType - + ", classificationLabels=" - + classificationLabels - + ", highOrderCategory=" - + highOrderCategory - + ", maxDepth=" - + maxDepth - + ", leafSize=" - + leafSize - + ", preparedForInference=" - + preparedForInference - + '}'; + StringBuilder builder = new StringBuilder("TreeInferenceModel{"); + + builder.append("nodes=") + .append(Arrays.toString(nodes)) + .append(", featureNames=") + .append(Arrays.toString(featureNames)) + .append(", targetType=") + .append(targetType); + + if (targetType == TargetType.CLASSIFICATION) { + builder.append(", classificationLabels=") + .append(classificationLabels) + .append(", highOrderCategory=") + .append(highOrderCategory()); + } else if (targetType == TargetType.REGRESSION) { + builder.append(", minPredictedValue=") + .append(getMinPredictedValue()) + .append(", maxPredictedValue=") + .append(getMaxPredictedValue()); + } + + builder.append(", maxDepth=") + .append(maxDepth) + .append(", leafSize=") + .append(leafSize) + .append(", preparedForInference=") + .append(preparedForInference); + + return builder.append('}').toString(); } private static int getDepth(Node[] nodes, int nodeIndex, int depth) {