-
Notifications
You must be signed in to change notification settings - Fork 25.6k
Add Bounded Window to Inference Models for Rescoring to Ensure Positive Score Range #125694
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
markjhoy
merged 12 commits into
elastic:main
from
markjhoy:markjhoy/fix_ltr_rescore_retriever_bug
Apr 2, 2025
Merged
Changes from 8 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
920b6b5
apply bounded window inference model
markjhoy bde0af8
linting
markjhoy 6270943
add unit tests
markjhoy 661ca2b
[CI] Auto commit changes from spotless
5e6ab73
add additional tests
markjhoy 55a1989
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy ab02bce
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy aa1b7da
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy d82f4d2
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy ce2d87e
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy 1dd7230
Merge branch 'main' into markjhoy/fix_ltr_rescore_retriever_bug
markjhoy 3d97708
remove unused constructor
markjhoy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 125694 | ||
summary: LTR score bounding | ||
area: Ranking | ||
type: bug | ||
issues: [] |
14 changes: 14 additions & 0 deletions
14
...g/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedInferenceModel.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; | ||
|
||
public interface BoundedInferenceModel extends InferenceModel { | ||
double getMinPredictedValue(); | ||
|
||
double getMaxPredictedValue(); | ||
} |
127 changes: 127 additions & 0 deletions
127
...ticsearch/xpack/core/ml/inference/trainedmodel/inference/BoundedWindowInferenceModel.java
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference; | ||
|
||
import org.elasticsearch.common.logging.LoggerMessageFormat; | ||
import org.elasticsearch.inference.InferenceResults; | ||
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; | ||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; | ||
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; | ||
|
||
import java.util.Map; | ||
|
||
public class BoundedWindowInferenceModel implements BoundedInferenceModel { | ||
public static final double DEFAULT_MIN_PREDICTED_VALUE = 0; | ||
|
||
private final BoundedInferenceModel model; | ||
private final double minPredictedValue; | ||
private final double maxPredictedValue; | ||
private final double adjustmentValue; | ||
|
||
public BoundedWindowInferenceModel(BoundedInferenceModel model) { | ||
this(model, model.getMinPredictedValue(), model.getMaxPredictedValue()); | ||
} | ||
|
||
public BoundedWindowInferenceModel(BoundedInferenceModel model, double minPredictedValue, double maxPredictedValue) { | ||
this.model = model; | ||
this.minPredictedValue = minPredictedValue; | ||
this.maxPredictedValue = maxPredictedValue; | ||
|
||
if (this.minPredictedValue < DEFAULT_MIN_PREDICTED_VALUE) { | ||
this.adjustmentValue = DEFAULT_MIN_PREDICTED_VALUE - this.minPredictedValue; | ||
} else { | ||
this.adjustmentValue = 0.0; | ||
} | ||
} | ||
|
||
@Override | ||
public String[] getFeatureNames() { | ||
return model.getFeatureNames(); | ||
} | ||
|
||
@Override | ||
public TargetType targetType() { | ||
return model.targetType(); | ||
} | ||
|
||
@Override | ||
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config, Map<String, String> featureDecoderMap) { | ||
return boundInferenceResultScores(model.infer(fields, config, featureDecoderMap)); | ||
} | ||
|
||
@Override | ||
public InferenceResults infer(double[] features, InferenceConfig config) { | ||
return boundInferenceResultScores(model.infer(features, config)); | ||
} | ||
|
||
@Override | ||
public boolean supportsFeatureImportance() { | ||
return model.supportsFeatureImportance(); | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return "bounded_window[" + model.getName() + "]"; | ||
} | ||
|
||
@Override | ||
public void rewriteFeatureIndices(Map<String, Integer> newFeatureIndexMapping) { | ||
model.rewriteFeatureIndices(newFeatureIndexMapping); | ||
} | ||
|
||
@Override | ||
public long ramBytesUsed() { | ||
return model.ramBytesUsed(); | ||
} | ||
|
||
@Override | ||
public double getMinPredictedValue() { | ||
return minPredictedValue; | ||
} | ||
|
||
@Override | ||
public double getMaxPredictedValue() { | ||
return maxPredictedValue; | ||
} | ||
|
||
private InferenceResults boundInferenceResultScores(InferenceResults inferenceResult) { | ||
// if the min value < the default minimum, slide the values up by the adjustment value | ||
if (inferenceResult instanceof RegressionInferenceResults regressionInferenceResults) { | ||
double predictedValue = ((Number) regressionInferenceResults.predictedValue()).doubleValue(); | ||
|
||
predictedValue += this.adjustmentValue; | ||
|
||
return new RegressionInferenceResults( | ||
predictedValue, | ||
inferenceResult.getResultsField(), | ||
((RegressionInferenceResults) inferenceResult).getFeatureImportance() | ||
); | ||
} | ||
|
||
throw new IllegalStateException( | ||
LoggerMessageFormat.format( | ||
"Model used within a {} should return a {} but got {} instead", | ||
BoundedWindowInferenceModel.class.getSimpleName(), | ||
RegressionInferenceResults.class.getSimpleName(), | ||
inferenceResult.getClass().getSimpleName() | ||
) | ||
); | ||
} | ||
|
||
@Override | ||
public String toString() { | ||
return "BoundedWindowInferenceModel{" | ||
+ "model=" | ||
+ model | ||
+ ", minPredictedValue=" | ||
+ getMinPredictedValue() | ||
+ ", maxPredictedValue=" | ||
+ getMaxPredictedValue() | ||
+ '}'; | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The absolute minimum of
0
is only for LTR models but we can imagine situations were we want to scale the result of a regression model in [-1; 1].I would personally remove the adjusment from the constructor cause it does not make sense:
Then you can create a static method (
scaleLtrModel
?) that would do the following:Then you can use the formula of the POC to scale the prediction:
Also I would rename the class into something like
MinMaxScaledInferenceModel
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought we had decided with @jimczi that we would not perform scaling, but rather slide the scores up to ensure they are all positive (only if the minimum score was negative). Correct?
That makes sense. Is there a need for this now to fix this bug though? I was under the impression that the bug was only about the negative scores being returned and us having to deal with that if true.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I proposed "sliding" the score because we can apply it on the minimum and maximum value for a model entirely (instead of per query). This means that scores will still be comparable between queries.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @jimczi - what are you thoughts on the PR as-is then?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good. Let's use the new
minPredictedValue
andmaxPredictedValue
from theBoundedInferenceModel
directly, no need to make it configurable for now.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@afoucret -
Looking at this here -- would scaling in this method (normalizing it to 0 -> 1.0 first) fall victim to what we're trying to avoid here - that is, doing this may compress the space and lose precision for the scores and might cause closely scored items to have equal rank?
I think it's a good idea overall, and certainly more flexible for the developer though... so, I'm on the fence about it. cc: @jimczi
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@afoucret - the more I think about this, for the time being, let's go with this solution, and if we decide to add in your suggestion, let's do that afterwords so we can unblock your work.