Skip to content

Commit 856a4d7

Browse files
authored
LTR - Fix explain failure when index has multiple shards (#120717)
1 parent 66db8c7 commit 856a4d7

File tree

4 files changed

+269
-24
lines changed

4 files changed

+269
-24
lines changed

docs/changelog/120717.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 120717
2+
summary: Fix LTR rescorer throws 'local model reference is null' on multi-shards index when explained is enabled
3+
area: Ranking
4+
type: bug
5+
issues:
6+
- 120739

server/src/main/java/org/elasticsearch/search/SearchService.java

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -957,32 +957,38 @@ public void executeFetchPhase(ShardFetchRequest request, SearchShardTask task, A
957957
final ReaderContext readerContext = findReaderContext(request.contextId(), request);
958958
final ShardSearchRequest shardSearchRequest = readerContext.getShardSearchRequest(request.getShardSearchRequest());
959959
final Releasable markAsUsed = readerContext.markAsUsed(getKeepAlive(shardSearchRequest));
960-
runAsync(getExecutor(readerContext.indexShard()), () -> {
961-
try (SearchContext searchContext = createContext(readerContext, shardSearchRequest, task, ResultsType.FETCH, false)) {
962-
if (request.lastEmittedDoc() != null) {
963-
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc();
964-
}
965-
searchContext.assignRescoreDocIds(readerContext.getRescoreDocIds(request.getRescoreDocIds()));
966-
searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(request.getAggregatedDfs()));
967-
try (
968-
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(searchContext, true, System.nanoTime())
969-
) {
970-
fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks());
971-
if (readerContext.singleSession()) {
972-
freeReaderContext(request.contextId());
960+
rewriteAndFetchShardRequest(readerContext.indexShard(), shardSearchRequest, listener.delegateFailure((l, rewritten) -> {
961+
runAsync(getExecutor(readerContext.indexShard()), () -> {
962+
try (SearchContext searchContext = createContext(readerContext, rewritten, task, ResultsType.FETCH, false)) {
963+
if (request.lastEmittedDoc() != null) {
964+
searchContext.scrollContext().lastEmittedDoc = request.lastEmittedDoc();
973965
}
974-
executor.success();
966+
searchContext.assignRescoreDocIds(readerContext.getRescoreDocIds(request.getRescoreDocIds()));
967+
searchContext.searcher().setAggregatedDfs(readerContext.getAggregatedDfs(request.getAggregatedDfs()));
968+
try (
969+
SearchOperationListenerExecutor executor = new SearchOperationListenerExecutor(
970+
searchContext,
971+
true,
972+
System.nanoTime()
973+
)
974+
) {
975+
fetchPhase.execute(searchContext, request.docIds(), request.getRankDocks());
976+
if (readerContext.singleSession()) {
977+
freeReaderContext(request.contextId());
978+
}
979+
executor.success();
980+
}
981+
var fetchResult = searchContext.fetchResult();
982+
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
983+
fetchResult.incRef();
984+
return fetchResult;
985+
} catch (Exception e) {
986+
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
987+
// we handle the failure in the failure listener below
988+
throw e;
975989
}
976-
var fetchResult = searchContext.fetchResult();
977-
// inc-ref fetch result because we close the SearchContext that references it in this try-with-resources block
978-
fetchResult.incRef();
979-
return fetchResult;
980-
} catch (Exception e) {
981-
assert TransportActions.isShardNotAvailableException(e) == false : new AssertionError(e);
982-
// we handle the failure in the failure listener below
983-
throw e;
984-
}
985-
}, wrapFailureListener(listener, readerContext, markAsUsed));
990+
}, wrapFailureListener(l, readerContext, markAsUsed));
991+
}));
986992
}
987993

988994
protected void checkCancelled(SearchShardTask task) {
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.integration;
9+
10+
import org.elasticsearch.action.bulk.BulkRequestBuilder;
11+
import org.elasticsearch.action.bulk.BulkResponse;
12+
import org.elasticsearch.action.index.IndexRequest;
13+
import org.elasticsearch.action.support.WriteRequest;
14+
import org.elasticsearch.cluster.metadata.IndexMetadata;
15+
import org.elasticsearch.common.settings.Settings;
16+
import org.elasticsearch.core.Predicates;
17+
import org.elasticsearch.index.query.QueryBuilders;
18+
import org.elasticsearch.search.SearchHit;
19+
import org.elasticsearch.search.builder.SearchSourceBuilder;
20+
import org.elasticsearch.xcontent.XContentParser;
21+
import org.elasticsearch.xcontent.json.JsonXContent;
22+
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
23+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
24+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
25+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LearningToRankConfig;
26+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
27+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
28+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ltr.QueryExtractorBuilder;
29+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
30+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeNode;
31+
import org.elasticsearch.xpack.core.ml.job.config.Operator;
32+
import org.elasticsearch.xpack.core.ml.utils.QueryProvider;
33+
import org.elasticsearch.xpack.ml.support.BaseMlIntegTestCase;
34+
import org.junit.Before;
35+
36+
import java.io.IOException;
37+
import java.util.Collections;
38+
import java.util.List;
39+
40+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertResponse;
41+
import static org.hamcrest.Matchers.equalTo;
42+
import static org.hamcrest.Matchers.is;
43+
import static org.hamcrest.Matchers.notNullValue;
44+
45+
public class LearningToRankExplainIT extends BaseMlIntegTestCase {
46+
47+
private static final String LTR_SEARCH_INDEX = "ltr-search-index";
48+
private static final String LTR_MODEL = "ltr-model";
49+
private static final int NUMBER_OF_NODES = 3;
50+
private static final String DEFAULT_SEARCH_REQUEST_BODY = """
51+
{
52+
"query": {
53+
"match": { "product": { "query": "TV" } }
54+
},
55+
"rescore": {
56+
"window_size": 10,
57+
"learning_to_rank": {
58+
"model_id": "ltr-model",
59+
"params": { "keyword": "TV" }
60+
}
61+
}
62+
}""";
63+
64+
@Before
65+
public void setupCluster() throws IOException {
66+
internalCluster().ensureAtLeastNumDataNodes(NUMBER_OF_NODES);
67+
ensureStableCluster();
68+
createLtrModel();
69+
}
70+
71+
public void testLtrExplainWithSingleShard() throws IOException {
72+
runLtrExplainTest(1, 1, 2, new float[] { 15f, 11f });
73+
}
74+
75+
public void testLtrExplainWithMultipleShards() throws IOException {
76+
runLtrExplainTest(randomIntBetween(2, NUMBER_OF_NODES), 0, 2, new float[] { 15f, 11f });
77+
}
78+
79+
public void testLtrExplainWithReplicas() throws IOException {
80+
runLtrExplainTest(1, randomIntBetween(1, NUMBER_OF_NODES - 1), 2, new float[] { 15f, 11f });
81+
}
82+
83+
public void testLtrExplainWithMultipleShardsAndReplicas() throws IOException {
84+
runLtrExplainTest(randomIntBetween(2, NUMBER_OF_NODES), randomIntBetween(1, NUMBER_OF_NODES - 1), 2, new float[] { 15f, 11f });
85+
}
86+
87+
private void runLtrExplainTest(int numberOfShards, int numberOfReplicas, long expectedTotalHits, float[] expectedScores)
88+
throws IOException {
89+
createLtrIndex(numberOfShards, numberOfReplicas);
90+
try (XContentParser parser = createParser(JsonXContent.jsonXContent, DEFAULT_SEARCH_REQUEST_BODY)) {
91+
assertResponse(
92+
client().prepareSearch(LTR_SEARCH_INDEX)
93+
.setSource(new SearchSourceBuilder().parseXContent(parser, true, Predicates.always()))
94+
.setExplain(true),
95+
searchResponse -> {
96+
assertThat(searchResponse.getHits().getTotalHits().value(), equalTo(expectedTotalHits));
97+
for (int i = 0; i < expectedScores.length; i++) {
98+
// Check expected score
99+
SearchHit hit = searchResponse.getHits().getHits()[i];
100+
assertThat(hit.getScore(), equalTo(expectedScores[i]));
101+
102+
// Check explanation is present and contains the right data
103+
assertThat(hit.getExplanation(), notNullValue());
104+
assertThat(hit.getExplanation().getValue().floatValue(), equalTo(hit.getScore()));
105+
assertThat(hit.getExplanation().getDescription(), equalTo("rescored using LTR model ltr-model"));
106+
}
107+
}
108+
);
109+
}
110+
}
111+
112+
private void createLtrIndex(int numberOfShards, int numberOfReplicas) {
113+
client().admin()
114+
.indices()
115+
.prepareCreate(LTR_SEARCH_INDEX)
116+
.setSettings(
117+
Settings.builder()
118+
.put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, numberOfShards)
119+
.put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, numberOfReplicas)
120+
.build()
121+
)
122+
.setMapping("product", "type=keyword", "best_seller", "type=boolean")
123+
.get();
124+
125+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
126+
IndexRequest indexRequest = new IndexRequest(LTR_SEARCH_INDEX);
127+
indexRequest.source("product", "TV", "best_seller", true);
128+
bulkRequestBuilder.add(indexRequest);
129+
130+
indexRequest = new IndexRequest(LTR_SEARCH_INDEX);
131+
indexRequest.source("product", "TV", "best_seller", false);
132+
bulkRequestBuilder.add(indexRequest);
133+
134+
indexRequest = new IndexRequest(LTR_SEARCH_INDEX);
135+
indexRequest.source("product", "VCR", "best_seller", true);
136+
bulkRequestBuilder.add(indexRequest);
137+
138+
indexRequest = new IndexRequest(LTR_SEARCH_INDEX);
139+
indexRequest.source("product", "VCR", "best_seller", true);
140+
bulkRequestBuilder.add(indexRequest);
141+
142+
indexRequest = new IndexRequest(LTR_SEARCH_INDEX);
143+
indexRequest.source("product", "Laptop", "best_seller", true);
144+
bulkRequestBuilder.add(indexRequest);
145+
146+
BulkResponse bulkResponse = bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE).get();
147+
assertThat(bulkResponse.hasFailures(), is(false));
148+
}
149+
150+
private void createLtrModel() throws IOException {
151+
client().execute(
152+
PutTrainedModelAction.INSTANCE,
153+
new PutTrainedModelAction.Request(
154+
TrainedModelConfig.builder()
155+
.setModelId(LTR_MODEL)
156+
.setInferenceConfig(
157+
LearningToRankConfig.builder(LearningToRankConfig.EMPTY_PARAMS)
158+
.setLearningToRankFeatureExtractorBuilders(
159+
List.of(
160+
new QueryExtractorBuilder(
161+
"best_seller",
162+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("best_seller", "true"))
163+
),
164+
new QueryExtractorBuilder(
165+
"product_match",
166+
QueryProvider.fromParsedQuery(QueryBuilders.termQuery("product", "{{keyword}}"))
167+
)
168+
)
169+
)
170+
.build()
171+
)
172+
.setParsedDefinition(
173+
new TrainedModelDefinition.Builder().setPreProcessors(Collections.emptyList())
174+
.setTrainedModel(
175+
Ensemble.builder()
176+
.setFeatureNames(List.of("best_seller", "product_bm25"))
177+
.setTargetType(TargetType.REGRESSION)
178+
.setTrainedModels(
179+
List.of(
180+
Tree.builder()
181+
.setFeatureNames(List.of("best_seller"))
182+
.setTargetType(TargetType.REGRESSION)
183+
.setRoot(
184+
TreeNode.builder(0)
185+
.setSplitFeature(0)
186+
.setSplitGain(12d)
187+
.setThreshold(1d)
188+
.setOperator(Operator.GTE)
189+
.setDefaultLeft(true)
190+
.setLeftChild(1)
191+
.setRightChild(2)
192+
)
193+
.addLeaf(1, 1)
194+
.addLeaf(2, 5)
195+
.build(),
196+
Tree.builder()
197+
.setFeatureNames(List.of("product_match"))
198+
.setTargetType(TargetType.REGRESSION)
199+
.setRoot(
200+
TreeNode.builder(0)
201+
.setSplitFeature(0)
202+
.setSplitGain(12d)
203+
.setThreshold(1d)
204+
.setOperator(Operator.LT)
205+
.setDefaultLeft(true)
206+
.setLeftChild(1)
207+
.setRightChild(2)
208+
)
209+
.addLeaf(1, 10)
210+
.addLeaf(2, 1)
211+
.build()
212+
)
213+
)
214+
.build()
215+
)
216+
)
217+
.validate(true)
218+
.build(),
219+
false
220+
)
221+
).actionGet();
222+
}
223+
}

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/LocalStateMachineLearning.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ protected SSLService getSslService() {
9595
});
9696
}
9797

98+
@Override
99+
public List<QuerySpec<?>> getQueries() {
100+
return mlPlugin.getQueries();
101+
}
102+
103+
@Override
104+
public List<RescorerSpec<?>> getRescorers() {
105+
return mlPlugin.getRescorers();
106+
}
107+
98108
@Override
99109
public List<AggregationSpec> getAggregations() {
100110
return mlPlugin.getAggregations();

0 commit comments

Comments
 (0)