diff --git a/docs/changelog/125694.yaml b/docs/changelog/125694.yaml new file mode 100644 index 0000000000000..c4c7a622dbdf6 --- /dev/null +++ b/docs/changelog/125694.yaml @@ -0,0 +1,5 @@ +pr: 125694 +summary: LTR score bounding +area: Ranking +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java new file mode 100644 index 0000000000000..f6516ec8c446f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java @@ -0,0 +1,14 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; + +public interface BoundedInferenceModel extends InferenceModel { + double getMinPredictedValue(); + + double getMaxPredictedValue(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModel.java new file mode 100644 index 0000000000000..0acdc0f6691cc --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModel.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; + +import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; + +import java.util.Map; + +public class BoundedWindowInferenceModel implements BoundedInferenceModel { + public static final double DEFAULT_MIN_PREDICTED_VALUE = 0; + + private final BoundedInferenceModel model; + private final double minPredictedValue; + private final double maxPredictedValue; + private final double adjustmentValue; + + public BoundedWindowInferenceModel(BoundedInferenceModel model) { + this.model = model; + this.minPredictedValue = model.getMinPredictedValue(); + this.maxPredictedValue = model.getMaxPredictedValue(); + + if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) { + this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue; + } else { + this.adjustmentValue = 0.0; + } + } + + @Override + public String[] getFeatureNames() { + return model.getFeatureNames(); + } + + @Override + public TargetType targetType() { + return model.targetType(); + } + + @Override + public InferenceResults infer(Map fields, InferenceConfig config, Map featureDecoderMap) { + return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap)); + } + + @Override + public InferenceResults infer(double[] features, InferenceConfig config) { + return boundInferenceResultScores(model.infer(features, config)); + } + + @Override + public boolean supportsFeatureImportance() { + return model.supportsFeatureImportance(); + } + + @Override + public String getName() { + return "bounded_window[" + model.getName() + "]"; + } + + @Override + public void rewriteFeatureIndices(Map newFeatureIndexMapping) { + 
model.rewriteFeatureIndices(newFeatureIndexMapping); + } + + @Override + public long ramBytesUsed() { + return model.ramBytesUsed(); + } + + @Override + public double getMinPredictedValue() { + return minPredictedValue; + } + + @Override + public double getMaxPredictedValue() { + return maxPredictedValue; + } + + private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) { + // if the min value < the default minimum, slide the values up by the adjustment value + if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) { + double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue(); + + predictedValue += this.adjustmentValue; + + return new RegressionInferenceResults( + predictedValue, + inferenceResult.getResultsField(), + ((RegressionInferenceResults) inferenceResult).getFeatureImportance() + ); + } + + throw new IllegalStateException( + LoggerMessageFormat.format( + "Model used within a {} should return a {} but got {} instead", + BoundedWindowInferenceModel.class.getSimpleName(), + RegressionInferenceResults.class.getSimpleName(), + inferenceResult.getClass().getSimpleName() + ) + ); + } + + @Override + public String toString() { + return "BoundedWindowInferenceModel{" + + "model=" + + model + + ", minPredictedValue=" + + getMinPredictedValue() + + ", maxPredictedValue=" + + getMaxPredictedValue() + + '}'; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java index c0e8610a357b0..20677b16c6a61 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java 
@@ -11,6 +11,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.RamUsageEstimator; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.CachedSupplier; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InferenceResults; @@ -36,6 +37,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -52,7 +54,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.TRAINED_MODELS; -public class EnsembleInferenceModel implements InferenceModel { +public class EnsembleInferenceModel implements InferenceModel, BoundedInferenceModel { public static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(EnsembleInferenceModel.class); private static final Logger LOGGER = LogManager.getLogger(EnsembleInferenceModel.class); @@ -97,6 +99,7 @@ public static EnsembleInferenceModel fromXContent(XContentParser parser) { private final List classificationLabels; private final double[] classificationWeights; private volatile boolean preparedForInference = false; + private final Supplier predictedValuesBoundariesSupplier; private EnsembleInferenceModel( List models, @@ -112,6 +115,7 @@ private EnsembleInferenceModel( this.classificationWeights = classificationWeights == null ? 
null : classificationWeights.stream().mapToDouble(Double::doubleValue).toArray(); + this.predictedValuesBoundariesSupplier = CachedSupplier.wrap(this::initModelBoundaries); } @Override @@ -328,21 +332,57 @@ public double[] getClassificationWeights() { @Override public String toString() { - return "EnsembleInferenceModel{" - + "featureNames=" - + Arrays.toString(featureNames) - + ", models=" - + models - + ", outputAggregator=" - + outputAggregator - + ", targetType=" - + targetType - + ", classificationLabels=" - + classificationLabels - + ", classificationWeights=" - + Arrays.toString(classificationWeights) - + ", preparedForInference=" - + preparedForInference - + '}'; + StringBuilder builder = new StringBuilder("EnsembleInferenceModel{"); + + builder.append("featureNames=") + .append(Arrays.toString(featureNames)) + .append(", models=") + .append(models) + .append(", outputAggregator=") + .append(outputAggregator) + .append(", targetType=") + .append(targetType); + + if (targetType == TargetType.CLASSIFICATION) { + builder.append(", classificationLabels=") + .append(classificationLabels) + .append(", classificationWeights=") + .append(Arrays.toString(classificationWeights)); + } else if (targetType == TargetType.REGRESSION) { + builder.append(", minPredictedValue=") + .append(getMinPredictedValue()) + .append(", maxPredictedValue=") + .append(getMaxPredictedValue()); + } + + builder.append(", preparedForInference=").append(preparedForInference); + + return builder.append('}').toString(); + } + + @Override + public double getMinPredictedValue() { + return this.predictedValuesBoundariesSupplier.get()[0]; + } + + @Override + public double getMaxPredictedValue() { + return this.predictedValuesBoundariesSupplier.get()[1]; + } + + private double[] initModelBoundaries() { + double[] modelsMinBoundaries = new double[models.size()]; + double[] modelsMaxBoundaries = new double[models.size()]; + int i = 0; + for (InferenceModel model : models) { + if (model instanceof 
BoundedInferenceModel boundedInferenceModel) { + modelsMinBoundaries[i] = boundedInferenceModel.getMinPredictedValue(); + modelsMaxBoundaries[i++] = boundedInferenceModel.getMaxPredictedValue(); + } else { + throw new IllegalStateException("All submodels have to be bounded"); + } + } + + return new double[] { outputAggregator.aggregate(modelsMinBoundaries), outputAggregator.aggregate(modelsMaxBoundaries) }; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java index 636cbcac725f4..c17bfecd6daf7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinition.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -79,13 +80,21 @@ private void preProcess(Map fields) { public InferenceResults infer(Map fields, InferenceConfig config) { preProcess(fields); + + InferenceModel inferenceModel = trainedModel; + + if (config instanceof LearningToRankConfig) { + assert trainedModel instanceof BoundedInferenceModel; + inferenceModel = new BoundedWindowInferenceModel((BoundedInferenceModel) trainedModel); + } + if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) { throw ExceptionsHelper.badRequestException( 
"Feature importance is not supported for the configured model of type [{}]", trainedModel.getName() ); } - return trainedModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap()); + return inferenceModel.infer(fields, config, config.requestingImportance() ? getDecoderMap() : Collections.emptyMap()); } public TargetType getTargetType() { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java index d23600383b34b..744a7defda337 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java @@ -58,7 +58,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.SPLIT_FEATURE; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode.THRESHOLD; -public class TreeInferenceModel implements InferenceModel { +public class TreeInferenceModel implements InferenceModel, BoundedInferenceModel { private static final Logger LOGGER = LogManager.getLogger(TreeInferenceModel.class); public static final long SHALLOW_SIZE = shallowSizeOfInstance(TreeInferenceModel.class); @@ -90,7 +90,7 @@ public static TreeInferenceModel fromXContent(XContentParser parser) { private String[] featureNames; private final TargetType targetType; private List classificationLabels; - private final double highOrderCategory; + private final double[] leafBoundaries; private final int maxDepth; private final int leafSize; private volatile boolean preparedForInference = false; @@ -108,7 +108,7 @@ public static TreeInferenceModel fromXContent(XContentParser parser) { this.nodes = 
nodes.stream().map(NodeBuilder::build).toArray(Node[]::new); this.targetType = targetType == null ? TargetType.REGRESSION : targetType; this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); - this.highOrderCategory = maxLeafValue(); + this.leafBoundaries = getLeafBoundaries(); int leafSize = 1; for (Node node : this.nodes) { if (node instanceof LeafNode leafNode) { @@ -218,7 +218,7 @@ private double[] classificationProbability(double[] inferenceValue) { } // If we are classification, we should assume that the inference return value is whole. assert inferenceValue[0] == Math.rint(inferenceValue[0]); - double maxCategory = this.highOrderCategory; + double maxCategory = getHighOrderCategory(); // If we are classification, we should assume that the largest leaf value is whole. assert maxCategory == Math.rint(maxCategory); double[] list = Collections.nCopies(Double.valueOf(maxCategory + 1).intValue(), 0.0) @@ -366,21 +366,20 @@ public long ramBytesUsed() { return size; } - private double maxLeafValue() { - if (targetType != TargetType.CLASSIFICATION) { - return Double.NaN; - } - double max = 0.0; + private double[] getLeafBoundaries() { + // Seed with the min/max identities: Double.MIN_VALUE is the smallest positive double, not the most negative. + double[] bounds = new double[] { Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY }; for (Node node : this.nodes) { if (node instanceof LeafNode leafNode) { if (leafNode.leafValue.length > 1) { - return leafNode.leafValue.length; + return new double[] { 0, leafNode.leafValue.length }; } else { - max = Math.max(leafNode.leafValue[0], max); + bounds[0] = Math.min(leafNode.leafValue[0], bounds[0]); + bounds[1] = Math.max(leafNode.leafValue[0], bounds[1]); } } } - return max; + return bounds; } public Node[] getNodes() { @@ -389,24 +388,35 @@ public Node[] getNodes() { @Override public String toString() { - return "TreeInferenceModel{" - + "nodes=" - + Arrays.toString(nodes) - + ", featureNames=" - + Arrays.toString(featureNames) - + ", targetType=" - + targetType - + ", 
classificationLabels=" - + classificationLabels - + ", highOrderCategory=" - + highOrderCategory - + ", maxDepth=" - + maxDepth - + ", leafSize=" - + leafSize - + ", preparedForInference=" - + preparedForInference - + '}'; + StringBuilder builder = new StringBuilder("TreeInferenceModel{"); + + builder.append("nodes=") + .append(Arrays.toString(nodes)) + .append(", featureNames=") + .append(Arrays.toString(featureNames)) + .append(", targetType=") + .append(targetType); + + if (targetType == TargetType.CLASSIFICATION) { + builder.append(", classificationLabels=") + .append(classificationLabels) + .append(", highOrderCategory=") + .append(getHighOrderCategory()); + } else if (targetType == TargetType.REGRESSION) { + builder.append(", minPredictedValue=") + .append(getMinPredictedValue()) + .append(", maxPredictedValue=") + .append(getMaxPredictedValue()); + } + + builder.append(", maxDepth=") + .append(maxDepth) + .append(", leafSize=") + .append(leafSize) + .append(", preparedForInference=") + .append(preparedForInference); + + return builder.append('}').toString(); } private static int getDepth(Node[] nodes, int nodeIndex, int depth) { @@ -420,6 +430,20 @@ private static int getDepth(Node[] nodes, int nodeIndex, int depth) { return Math.max(depthLeft, depthRight) + 1; } + @Override + public double getMinPredictedValue() { + return leafBoundaries[0]; + } + + @Override + public double getMaxPredictedValue() { + return leafBoundaries[1]; + } + + private double getHighOrderCategory() { + return getMaxPredictedValue(); + } + static class NodeBuilder { private static final ObjectParser PARSER = new ObjectParser<>( diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModelTests.java new file mode 100644 index 0000000000000..b22f85eb2ecdb --- /dev/null +++ 
b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModelTests.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.inference.results.SingleValueInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; +import static org.hamcrest.Matchers.equalTo; + +public class BoundedWindowInferenceModelTests extends ESTestCase { + + private static final List featureNames = Arrays.asList("foo", "bar"); + + public void testBoundsSetting() throws IOException { + BoundedWindowInferenceModel testModel = getModel(-2.0, 5.2, 10.5); + assertThat(testModel.getMinPredictedValue(), equalTo(-2.0)); + assertThat(testModel.getMaxPredictedValue(), equalTo(10.5)); + } + + public void testInferenceScoresWithoutAdjustment() throws IOException { + BoundedWindowInferenceModel testModel = getModel(1.0, 5.2, 10.5); + + List featureVector = Arrays.asList(0.4, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + Double lowResultValue = 
((SingleValueInferenceResults) testModel.infer( + featureMap, + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap() + )).value(); + assertThat(lowResultValue, equalTo(1.0)); + + featureVector = Arrays.asList(12.0, 0.0); + featureMap = zipObjMap(featureNames, featureVector); + Double highResultValue = ((SingleValueInferenceResults) testModel.infer( + featureMap, + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap() + )).value(); + assertThat(highResultValue, equalTo(10.5)); + + double[] featureArray = new double[2]; + featureArray[0] = 12.0; + featureArray[1] = 0.0; + Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS)) + .value(); + assertThat(highResultValueFromFeatures, equalTo(10.5)); + } + + public void testInferenceScoresWithAdjustment() throws IOException { + BoundedWindowInferenceModel testModel = getModel(-5.0, 1.2, 6.5); + + List featureVector = Arrays.asList(-10.0, 0.0); + Map featureMap = zipObjMap(featureNames, featureVector); + Double lowResultValue = ((SingleValueInferenceResults) testModel.infer( + featureMap, + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap() + )).value(); + assertThat(lowResultValue, equalTo(0.0)); + + featureVector = Arrays.asList(12.0, 0.0); + featureMap = zipObjMap(featureNames, featureVector); + Double highResultValue = ((SingleValueInferenceResults) testModel.infer( + featureMap, + RegressionConfig.EMPTY_PARAMS, + Collections.emptyMap() + )).value(); + assertThat(highResultValue, equalTo(11.5)); + + double[] featureArray = new double[2]; + featureArray[0] = 12.0; + featureArray[1] = 0.0; + Double highResultValueFromFeatures = ((SingleValueInferenceResults) testModel.infer(featureArray, RegressionConfig.EMPTY_PARAMS)) + .value(); + assertThat(highResultValueFromFeatures, equalTo(11.5)); + } + + private BoundedWindowInferenceModel getModel(double lowerBoundValue, double midValue, double upperBoundValue) throws IOException { + Tree.Builder 
builder = Tree.builder().setTargetType(TargetType.REGRESSION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), upperBoundValue); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), lowerBoundValue); + builder.addLeaf(leftChildNode.getRightChild(), midValue); + + List featureNames = Arrays.asList("foo", "bar"); + Tree treeObject = builder.setFeatureNames(featureNames).build(); + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); + + return new BoundedWindowInferenceModel(tree); + } + + private static Map zipObjMap(List keys, List values) { + return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java index b3fe5d2cebc6b..02e95e3d1786c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModelTests.java @@ -39,6 +39,7 @@ import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -537,6 +538,40 @@ public void testFeatureImportance() throws IOException { 
assertThat(featureImportance[1][0], closeTo(0.1451914, eps)); } + public void testMinAndMaxBoundaries() throws IOException { + List featureNames = Arrays.asList("foo", "bar"); + Tree tree1 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(0.3)) + .addNode(TreeNode.builder(2).setThreshold(0.8).setSplitFeature(1).setLeftChild(3).setRightChild(4)) + .addNode(TreeNode.builder(3).setLeafValue(0.1)) + .addNode(TreeNode.builder(4).setLeafValue(0.2)) + .build(); + Tree tree2 = Tree.builder() + .setFeatureNames(featureNames) + .setRoot(TreeNode.builder(0).setLeftChild(1).setRightChild(2).setSplitFeature(0).setThreshold(0.5)) + .addNode(TreeNode.builder(1).setLeafValue(1.5)) + .addNode(TreeNode.builder(2).setLeafValue(0.9)) + .build(); + Ensemble ensembleObject = Ensemble.builder() + .setTargetType(TargetType.REGRESSION) + .setFeatureNames(featureNames) + .setTrainedModels(Arrays.asList(tree1, tree2)) + .setOutputAggregator(new WeightedSum(new double[] { 0.5, 0.5 })) + .build(); + + EnsembleInferenceModel ensemble = deserializeFromTrainedModel( + ensembleObject, + xContentRegistry(), + EnsembleInferenceModel::fromXContent + ); + ensemble.rewriteFeatureIndices(Collections.emptyMap()); + + assertThat(ensemble.getMinPredictedValue(), equalTo(1.0)); + assertThat(ensemble.getMaxPredictedValue(), equalTo(1.8)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java index 9f2326d022eab..65789af110a8c 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java @@ -13,6 +13,8 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -25,17 +27,26 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilderTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode; import java.io.IOException; import java.text.ParseException; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.ENSEMBLE_MODEL; import static org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests.TREE_MODEL; +import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModelTestUtils.deserializeFromTrainedModel; import static 
org.hamcrest.Matchers.closeTo; import static org.hamcrest.Matchers.equalTo; @@ -176,6 +187,35 @@ public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws I } } + public void testWithLearningToRankConfiguration() throws IOException { + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), -2.0); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), 0.2); + builder.addLeaf(leftChildNode.getRightChild(), 1.5); + + List featureNames = Arrays.asList("foo", "bar"); + Tree treeObject = builder.setFeatureNames(featureNames).build(); + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); + + BoundedWindowInferenceModel testModel = new BoundedWindowInferenceModel(tree); + + InferenceDefinition definition = new InferenceDefinition(testModel, null); + LearningToRankConfig config = new LearningToRankConfig( + randomBoolean() ? null : randomIntBetween(0, 10), + randomBoolean() + ? null + : Stream.generate(QueryExtractorBuilderTests::randomInstance).limit(randomInt(5)).collect(Collectors.toList()), + randomBoolean() ? 
null : randomMap(0, 10, () -> Tuple.tuple(randomIdentifier(), randomIdentifier())) + ); + + InferenceResults results = definition.infer(Map.of("foo", 1.0, "bar", 0.0), config); + + assertThat(results.predictedValue(), equalTo(2.0)); + } + public static String getClassificationDefinition(boolean customPreprocessor) { return Strings.format(""" { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java index ef5c7ad3542d1..e3f9dedf5a8b6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModelTests.java @@ -284,6 +284,23 @@ public void testFeatureImportance() throws IOException { assertThat(featureImportance[1][0], closeTo(2.5, eps)); } + public void testMinAndMaxBoundaries() throws IOException { + Tree.Builder builder = Tree.builder().setTargetType(TargetType.REGRESSION); + TreeNode.Builder rootNode = builder.addJunction(0, 0, true, 0.5); + builder.addLeaf(rootNode.getRightChild(), 0.3); + TreeNode.Builder leftChildNode = builder.addJunction(rootNode.getLeftChild(), 1, true, 0.8); + builder.addLeaf(leftChildNode.getLeftChild(), 0.1); + builder.addLeaf(leftChildNode.getRightChild(), 0.2); + + List featureNames = Arrays.asList("foo", "bar"); + Tree treeObject = builder.setFeatureNames(featureNames).build(); + TreeInferenceModel tree = deserializeFromTrainedModel(treeObject, xContentRegistry(), TreeInferenceModel::fromXContent); + tree.rewriteFeatureIndices(Collections.emptyMap()); + + assertThat(tree.getMinPredictedValue(), equalTo(0.1)); + assertThat(tree.getMaxPredictedValue(), equalTo(0.3)); + } + private static Map zipObjMap(List keys, List values) { 
return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); }