Skip to content

Commit 76ae142

Browse files
authored
[LTR] Fix feature display order when using explain. (elastic#137671) (elastic#137686)
1 parent 0d8befd commit 76ae142

File tree

3 files changed

+32
-27
lines changed

3 files changed

+32
-27
lines changed

docs/changelog/137671.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137671
2+
summary: "[LTR] Fix feature display order when using explain"
3+
area: Search
4+
type: bug
5+
issues: []

x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void testLearningToRankRescoreWithExplain() throws Exception {
8383
}
8484
}""");
8585
var response = client().performRequest(request);
86-
assertExplainExtractedFeatures(response, List.of("type_tv", "cost", "two"));
86+
assertExplainExtractedFeatures(response, List.of("cost", "type_tv", "two"));
8787
}
8888

8989
public void testLearningToRankRescoreSmallWindow() throws Exception {
@@ -192,27 +192,25 @@ private static void assertExplainExtractedFeatures(Response response, List<Strin
192192
assertThat(queryDetails.get(1).get("description"), equalTo("extracted features"));
193193

194194
var featureDetails = new ArrayList<>((ArrayList<Map<String, Object>>) queryDetails.get(1).get("details"));
195-
assertThat(featureDetails.size(), equalTo(3));
196-
197-
var missingKeys = new ArrayList<String>();
198-
for (String expectedFeature : expectedFeatures) {
199-
var expectedDescription = Strings.format("feature value for [%s]", expectedFeature);
200-
201-
var wasFound = false;
202-
for (Map<String, Object> detailItem : featureDetails) {
203-
if (detailItem.get("description").equals(expectedDescription)) {
204-
featureDetails.remove(detailItem);
205-
wasFound = true;
206-
break;
207-
}
208-
}
209-
210-
if (wasFound == false) {
211-
missingKeys.add(expectedFeature);
195+
assertThat(featureDetails.size(), equalTo(expectedFeatures.size()));
196+
197+
// Extract feature names in the order they appear in the explanation
198+
List<String> actualFeatureOrder = new ArrayList<>();
199+
for (Map<String, Object> detailItem : featureDetails) {
200+
String description = (String) detailItem.get("description");
201+
// Extract feature name from "feature value for [featureName]"
202+
if (description != null && description.startsWith("feature value for [") && description.endsWith("]")) {
203+
String featureName = description.substring("feature value for [".length(), description.length() - 1);
204+
actualFeatureOrder.add(featureName);
212205
}
213206
}
214207

215-
assertThat(Strings.format("Could not find features: [%s]", String.join(", ", missingKeys)), featureDetails.size(), equalTo(0));
208+
// Verify that features appear in the expected order
209+
assertThat(
210+
"Features should appear in the expected order. Expected: " + expectedFeatures + ", Actual: " + actualFeatureOrder,
211+
actualFeatureOrder,
212+
equalTo(expectedFeatures)
213+
);
216214
}
217215
}
218216

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.util.Comparator;
2929
import java.util.List;
3030
import java.util.Map;
31-
import java.util.Objects;
3231
import java.util.Set;
3332

3433
import static java.util.stream.Collectors.toUnmodifiableSet;
@@ -169,21 +168,24 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon
169168
List<FeatureExtractor> featureExtractors = ltrContext.buildFeatureExtractors(searcher);
170169
int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
171170

172-
Map<String, Object> features = Maps.newMapWithExpectedSize(featureSize);
171+
Map<String, Object> extractedFeatures = Maps.newMapWithExpectedSize(featureSize);
173172

174173
for (FeatureExtractor featureExtractor : featureExtractors) {
175174
featureExtractor.setNextReader(currentSegment);
176-
featureExtractor.addFeatures(features, targetDoc);
175+
featureExtractor.addFeatures(extractedFeatures, targetDoc);
177176
}
178177

179178
// Predicting the value
180-
var ltrScore = ((Number) localModelDefinition.inferLtr(features, ltrContext.learningToRankConfig).predictedValue()).floatValue();
179+
var ltrScore = ((Number) localModelDefinition.inferLtr(extractedFeatures, ltrContext.learningToRankConfig).predictedValue())
180+
.floatValue();
181181

182182
List<Explanation> featureExplanations = new ArrayList<>();
183-
for (String featureName : features.keySet()) {
184-
Number featureValue = Objects.requireNonNullElse((Number) features.get(featureName), 0);
185-
featureExplanations.add(Explanation.match(featureValue, "feature value for [" + featureName + "]"));
186-
}
183+
ltrContext.learningToRankConfig.getFeatureExtractorBuilders().forEach(featureExtractor -> {
184+
String featureName = featureExtractor.featureName();
185+
if (extractedFeatures.containsKey(featureName) && extractedFeatures.get(featureName) instanceof Number featureValue) {
186+
featureExplanations.add(Explanation.match(featureValue, "feature value for [" + featureName + "]"));
187+
}
188+
});
187189

188190
return Explanation.match(
189191
ltrScore,

0 commit comments

Comments
 (0)