Skip to content

Commit dfe61f9

Browse files
authored
[LTR] Fix feature display order when using explain. (#137671) (#137684)
1 parent b8de8ba commit dfe61f9

File tree

3 files changed

+32
-27
lines changed

3 files changed

+32
-27
lines changed

docs/changelog/137671.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 137671
2+
summary: "[LTR] Fix feature display order when using explain"
3+
area: Search
4+
type: bug
5+
issues: []

x-pack/plugin/ml/qa/single-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/LearningToRankRescorerIT.java

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ public void testLearningToRankRescoreWithExplain() throws Exception {
8383
}
8484
}""");
8585
var response = client().performRequest(request);
86-
assertExplainExtractedFeatures(response, List.of("type_tv", "cost", "two"));
86+
assertExplainExtractedFeatures(response, List.of("cost", "type_tv", "two"));
8787
}
8888

8989
public void testLearningToRankRescoreSmallWindow() throws Exception {
@@ -192,27 +192,25 @@ private static void assertExplainExtractedFeatures(Response response, List<Strin
192192
assertThat(queryDetails.get(1).get("description"), equalTo("extracted features"));
193193

194194
var featureDetails = new ArrayList<>((ArrayList<Map<String, Object>>) queryDetails.get(1).get("details"));
195-
assertThat(featureDetails.size(), equalTo(3));
196-
197-
var missingKeys = new ArrayList<String>();
198-
for (String expectedFeature : expectedFeatures) {
199-
var expectedDescription = Strings.format("feature value for [%s]", expectedFeature);
200-
201-
var wasFound = false;
202-
for (Map<String, Object> detailItem : featureDetails) {
203-
if (detailItem.get("description").equals(expectedDescription)) {
204-
featureDetails.remove(detailItem);
205-
wasFound = true;
206-
break;
207-
}
208-
}
209-
210-
if (wasFound == false) {
211-
missingKeys.add(expectedFeature);
195+
assertThat(featureDetails.size(), equalTo(expectedFeatures.size()));
196+
197+
// Extract feature names in the order they appear in the explanation
198+
List<String> actualFeatureOrder = new ArrayList<>();
199+
for (Map<String, Object> detailItem : featureDetails) {
200+
String description = (String) detailItem.get("description");
201+
// Extract feature name from "feature value for [featureName]"
202+
if (description != null && description.startsWith("feature value for [") && description.endsWith("]")) {
203+
String featureName = description.substring("feature value for [".length(), description.length() - 1);
204+
actualFeatureOrder.add(featureName);
212205
}
213206
}
214207

215-
assertThat(Strings.format("Could not find features: [%s]", String.join(", ", missingKeys)), featureDetails.size(), equalTo(0));
208+
// Verify that features appear in the expected order
209+
assertThat(
210+
"Features should appear in the expected order. Expected: " + expectedFeatures + ", Actual: " + actualFeatureOrder,
211+
actualFeatureOrder,
212+
equalTo(expectedFeatures)
213+
);
216214
}
217215
}
218216

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import java.util.Comparator;
2929
import java.util.List;
3030
import java.util.Map;
31-
import java.util.Objects;
3231
import java.util.Set;
3332

3433
import static java.util.stream.Collectors.toUnmodifiableSet;
@@ -164,21 +163,24 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon
164163
List<FeatureExtractor> featureExtractors = ltrContext.buildFeatureExtractors(searcher);
165164
int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
166165

167-
Map<String, Object> features = Maps.newMapWithExpectedSize(featureSize);
166+
Map<String, Object> extractedFeatures = Maps.newMapWithExpectedSize(featureSize);
168167

169168
for (FeatureExtractor featureExtractor : featureExtractors) {
170169
featureExtractor.setNextReader(currentSegment);
171-
featureExtractor.addFeatures(features, targetDoc);
170+
featureExtractor.addFeatures(extractedFeatures, targetDoc);
172171
}
173172

174173
// Predicting the value
175-
var ltrScore = ((Number) localModelDefinition.inferLtr(features, ltrContext.learningToRankConfig).predictedValue()).floatValue();
174+
var ltrScore = ((Number) localModelDefinition.inferLtr(extractedFeatures, ltrContext.learningToRankConfig).predictedValue())
175+
.floatValue();
176176

177177
List<Explanation> featureExplanations = new ArrayList<>();
178-
for (String featureName : features.keySet()) {
179-
Number featureValue = Objects.requireNonNullElse((Number) features.get(featureName), 0);
180-
featureExplanations.add(Explanation.match(featureValue, "feature value for [" + featureName + "]"));
181-
}
178+
ltrContext.learningToRankConfig.getFeatureExtractorBuilders().forEach(featureExtractor -> {
179+
String featureName = featureExtractor.featureName();
180+
if (extractedFeatures.containsKey(featureName) && extractedFeatures.get(featureName) instanceof Number featureValue) {
181+
featureExplanations.add(Explanation.match(featureValue, "feature value for [" + featureName + "]"));
182+
}
183+
});
182184

183185
return Explanation.match(
184186
ltrScore,

0 commit comments

Comments
 (0)