Skip to content

Commit 3a447f1

Browse files
committed
Remove ElasticInferenceServiceRerankTaskSettings and override validateRerankParameters in ElasticInferenceService
1 parent f43417b commit 3a447f1

File tree

11 files changed

Lines changed: 64 additions & 316 deletions

11 files changed

Lines changed: 64 additions & 316 deletions

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekChatCompletionModel;
6363
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
65-
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankTaskSettings;
6665
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6766
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6867
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
@@ -705,12 +704,5 @@ private static void addElasticNamedWriteables(List<NamedWriteableRegistry.Entry>
705704
ElasticInferenceServiceRerankServiceSettings::new
706705
)
707706
);
708-
namedWriteables.add(
709-
new NamedWriteableRegistry.Entry(
710-
TaskSettings.class,
711-
ElasticInferenceServiceRerankTaskSettings.NAME,
712-
ElasticInferenceServiceRerankTaskSettings::new
713-
)
714-
);
715707
}
716708
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/elastic/rerank/ElasticInferenceServiceRerankRequest.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,22 @@ public class ElasticInferenceServiceRerankRequest extends ElasticInferenceServic
3030

3131
private final String query;
3232
private final List<String> documents;
33+
private final Integer topN;
3334
private final TraceContextHandler traceContextHandler;
3435
private final ElasticInferenceServiceRerankModel model;
3536

3637
public ElasticInferenceServiceRerankRequest(
3738
String query,
3839
List<String> documents,
40+
Integer topN,
3941
ElasticInferenceServiceRerankModel model,
4042
TraceContext traceContext,
4143
ElasticInferenceServiceRequestMetadata metadata
4244
) {
4345
super(metadata);
4446
this.query = query;
4547
this.documents = documents;
48+
this.topN = topN;
4649
this.model = Objects.requireNonNull(model);
4750
this.traceContextHandler = new TraceContextHandler(traceContext);
4851
}
@@ -51,12 +54,7 @@ public ElasticInferenceServiceRerankRequest(
5154
public HttpRequestBase createHttpRequestBase() {
5255
var httpPost = new HttpPost(getURI());
5356
var requestEntity = Strings.toString(
54-
new ElasticInferenceServiceRerankRequestEntity(
55-
query,
56-
documents,
57-
model.getServiceSettings().modelId(),
58-
model.getTaskSettings().getTopNDocumentsOnly()
59-
)
57+
new ElasticInferenceServiceRerankRequestEntity(query, documents, model.getServiceSettings().modelId(), topN)
6058
);
6159

6260
ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,18 @@ public void onNodeStarted() {
168168
authorizationHandler.init();
169169
}
170170

171+
@Override
172+
protected void validateRerankParameters(Boolean returnDocuments, Integer topN, ValidationException validationException) {
173+
if (returnDocuments != null) {
174+
validationException.addValidationError(
175+
org.elasticsearch.core.Strings.format(
176+
"Invalid return_documents [%s]. The return_documents option is not supported by this service",
177+
returnDocuments
178+
)
179+
);
180+
}
181+
}
182+
171183
/**
172184
* Only use this in tests.
173185
*

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public ExecutableAction create(ElasticInferenceServiceRerankModel model) {
6868
(rerankInput) -> new ElasticInferenceServiceRerankRequest(
6969
rerankInput.getQuery(),
7070
rerankInput.getChunks(),
71+
rerankInput.getTopN(),
7172
model,
7273
traceContext,
7374
extractRequestMetadataFromThreadContext(threadPool.getThreadContext())

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.core.Nullable;
1212
import org.elasticsearch.inference.EmptySecretSettings;
13+
import org.elasticsearch.inference.EmptyTaskSettings;
1314
import org.elasticsearch.inference.ModelConfigurations;
1415
import org.elasticsearch.inference.ModelSecrets;
1516
import org.elasticsearch.inference.SecretSettings;
17+
import org.elasticsearch.inference.TaskSettings;
1618
import org.elasticsearch.inference.TaskType;
1719
import org.elasticsearch.rest.RestStatus;
1820
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
@@ -44,7 +46,7 @@ public ElasticInferenceServiceRerankModel(
4446
taskType,
4547
service,
4648
ElasticInferenceServiceRerankServiceSettings.fromMap(serviceSettings, context),
47-
ElasticInferenceServiceRerankTaskSettings.fromMap(taskSettings),
49+
EmptyTaskSettings.INSTANCE,
4850
EmptySecretSettings.INSTANCE,
4951
elasticInferenceServiceComponents
5052
);
@@ -55,7 +57,7 @@ public ElasticInferenceServiceRerankModel(
5557
TaskType taskType,
5658
String service,
5759
ElasticInferenceServiceRerankServiceSettings serviceSettings,
58-
@Nullable ElasticInferenceServiceRerankTaskSettings taskSettings,
60+
@Nullable TaskSettings taskSettings,
5961
@Nullable SecretSettings secretSettings,
6062
ElasticInferenceServiceComponents elasticInferenceServiceComponents
6163
) {
@@ -78,11 +80,6 @@ public ElasticInferenceServiceRerankServiceSettings getServiceSettings() {
7880
return (ElasticInferenceServiceRerankServiceSettings) super.getServiceSettings();
7981
}
8082

81-
@Override
82-
public ElasticInferenceServiceRerankTaskSettings getTaskSettings() {
83-
return (ElasticInferenceServiceRerankTaskSettings) super.getTaskSettings();
84-
}
85-
8683
public URI uri() {
8784
return uri;
8885
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankTaskSettings.java

Lines changed: 0 additions & 140 deletions
This file was deleted.

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/elastic/ElasticInferenceServiceRerankRequestTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,12 @@ private ElasticInferenceServiceRerankRequest createRequest(
7474
List<String> documents,
7575
Integer topN
7676
) {
77-
var rerankModel = ElasticInferenceServiceRerankModelTests.createModel(url, modelId, topN);
77+
var rerankModel = ElasticInferenceServiceRerankModelTests.createModel(url, modelId);
7878

7979
return new ElasticInferenceServiceRerankRequest(
8080
query,
8181
documents,
82+
topN,
8283
rerankModel,
8384
new TraceContext(randomAlphaOfLength(10), randomAlphaOfLength(10)),
8485
randomElasticInferenceServiceRequestMetadata()

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.ElasticsearchStatusException;
1212
import org.elasticsearch.action.ActionListener;
1313
import org.elasticsearch.action.support.PlainActionFuture;
14+
import org.elasticsearch.common.ValidationException;
1415
import org.elasticsearch.common.bytes.BytesArray;
1516
import org.elasticsearch.common.bytes.BytesReference;
1617
import org.elasticsearch.common.settings.Settings;
@@ -387,6 +388,39 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException
387388
verifyNoMoreInteractions(sender);
388389
}
389390

391+
public void testInfer_ThrowsValidationErrorForInvalidRerankParams() throws IOException {
392+
var sender = mock(Sender.class);
393+
394+
var factory = mock(HttpRequestSender.Factory.class);
395+
when(factory.createSender()).thenReturn(sender);
396+
397+
try (var service = createServiceWithMockSender()) {
398+
var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), "my-rerank-model-id");
399+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
400+
401+
var thrownException = expectThrows(
402+
ValidationException.class,
403+
() -> service.infer(
404+
model,
405+
"search query",
406+
Boolean.TRUE,
407+
10,
408+
List.of("doc1", "doc2", "doc3"),
409+
false,
410+
new HashMap<>(),
411+
InputType.SEARCH,
412+
InferenceAction.Request.DEFAULT_TIMEOUT,
413+
listener
414+
)
415+
);
416+
417+
assertThat(
418+
thrownException.getMessage(),
419+
is("Validation Failed: 1: Invalid return_documents [true]. The return_documents option is not supported by this service;")
420+
);
421+
}
422+
}
423+
390424
public void testInfer_ThrowsErrorWhenTaskTypeIsNotValid() throws IOException {
391425
var sender = mock(Sender.class);
392426

@@ -544,14 +578,14 @@ public void testRerank_SendsRerankRequest() throws IOException {
544578

545579
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
546580

547-
var model = ElasticInferenceServiceRerankModelTests.createModel(elasticInferenceServiceURL, modelId, topN);
581+
var model = ElasticInferenceServiceRerankModelTests.createModel(elasticInferenceServiceURL, modelId);
548582
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
549583

550584
service.infer(
551585
model,
552586
"search query",
553587
null,
554-
null,
588+
topN,
555589
List.of("doc1", "doc2", "doc3"),
556590
false,
557591
new HashMap<>(),

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,12 @@ public void testExecute_ReturnsSuccessfulResponse_ForRerankAction() throws IOExc
213213
var query = "query";
214214
var documents = List.of("document 1", "document 2", "document 3");
215215

216-
var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId, topN);
216+
var model = ElasticInferenceServiceRerankModelTests.createModel(getUrl(webServer), modelId);
217217
var actionCreator = new ElasticInferenceServiceActionCreator(sender, createWithEmptySettings(threadPool), createTraceContext());
218218
var action = actionCreator.create(model);
219219

220220
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
221-
action.execute(new QueryAndDocsInputs(query, documents, null, null, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
221+
action.execute(new QueryAndDocsInputs(query, documents, null, topN, false), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
222222

223223
var result = listener.actionGet(TIMEOUT);
224224

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,20 @@
88
package org.elasticsearch.xpack.inference.services.elastic.rerank;
99

1010
import org.elasticsearch.inference.EmptySecretSettings;
11+
import org.elasticsearch.inference.EmptyTaskSettings;
1112
import org.elasticsearch.inference.TaskType;
1213
import org.elasticsearch.test.ESTestCase;
1314
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents;
1415

1516
public class ElasticInferenceServiceRerankModelTests extends ESTestCase {
1617

17-
public static ElasticInferenceServiceRerankModel createModel(String url, String modelId, Integer topN) {
18+
public static ElasticInferenceServiceRerankModel createModel(String url, String modelId) {
1819
return new ElasticInferenceServiceRerankModel(
1920
"id",
2021
TaskType.RERANK,
2122
"service",
2223
new ElasticInferenceServiceRerankServiceSettings(modelId, null),
23-
new ElasticInferenceServiceRerankTaskSettings(topN),
24+
EmptyTaskSettings.INSTANCE,
2425
EmptySecretSettings.INSTANCE,
2526
ElasticInferenceServiceComponents.of(url)
2627
);

0 commit comments

Comments
 (0)