Skip to content

Commit 638e5b6

Browse files
authored
Revert "Adding default endpoint for Elastic Rerank (#117939)" (#118221)
This reverts commit 54c320e.
1 parent 22e8f61 commit 638e5b6

File tree

11 files changed

+96
-185
lines changed

11 files changed

+96
-185
lines changed

docs/changelog/117939.yaml

Lines changed: 0 additions & 5 deletions
This file was deleted.

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ public void testGet() throws IOException {
5757

5858
var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
5959
assertDefaultE5Config(e5Model);
60-
61-
var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
62-
assertDefaultRerankConfig(rerankModel);
6360
}
6461

6562
@SuppressWarnings("unchecked")
@@ -128,42 +125,6 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
128125
assertDefaultChunkingSettings(modelConfig);
129126
}
130127

131-
@SuppressWarnings("unchecked")
132-
public void testInferDeploysDefaultRerank() throws IOException {
133-
var model = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
134-
assertDefaultRerankConfig(model);
135-
136-
var inputs = List.of("Hello World", "Goodnight moon");
137-
var query = "but why";
138-
var queryParams = Map.of("timeout", "120s");
139-
var results = infer(ElasticsearchInternalService.DEFAULT_RERANK_ID, TaskType.RERANK, inputs, query, queryParams);
140-
var embeddings = (List<Map<String, Object>>) results.get("rerank");
141-
assertThat(results.toString(), embeddings, hasSize(2));
142-
}
143-
144-
@SuppressWarnings("unchecked")
145-
private static void assertDefaultRerankConfig(Map<String, Object> modelConfig) {
146-
assertEquals(modelConfig.toString(), ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id"));
147-
assertEquals(modelConfig.toString(), ElasticsearchInternalService.NAME, modelConfig.get("service"));
148-
assertEquals(modelConfig.toString(), TaskType.RERANK.toString(), modelConfig.get("task_type"));
149-
150-
var serviceSettings = (Map<String, Object>) modelConfig.get("service_settings");
151-
assertThat(modelConfig.toString(), serviceSettings.get("model_id"), is(".rerank-v1"));
152-
assertEquals(modelConfig.toString(), 1, serviceSettings.get("num_threads"));
153-
154-
var adaptiveAllocations = (Map<String, Object>) serviceSettings.get("adaptive_allocations");
155-
assertThat(
156-
modelConfig.toString(),
157-
adaptiveAllocations,
158-
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
159-
);
160-
161-
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
162-
assertNull(chunkingSettings);
163-
var taskSettings = (Map<String, Object>) modelConfig.get("task_settings");
164-
assertThat(modelConfig.toString(), taskSettings, Matchers.is(Map.of("return_documents", true)));
165-
}
166-
167128
@SuppressWarnings("unchecked")
168129
private static void assertDefaultChunkingSettings(Map<String, Object> modelConfig) {
169130
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
@@ -198,7 +159,6 @@ public void onFailure(Exception exception) {
198159
var request = createInferenceRequest(
199160
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
200161
inputs,
201-
null,
202162
queryParams
203163
);
204164
client().performRequestAsync(request, listener);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ private List<Object> getInternalAsList(String endpoint) throws IOException {
336336

337337
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
338338
var endpoint = Strings.format("_inference/%s", modelId);
339-
return inferInternal(endpoint, input, null, Map.of());
339+
return inferInternal(endpoint, input, Map.of());
340340
}
341341

342342
protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
@@ -352,7 +352,7 @@ protected Deque<ServerSentEvent> unifiedCompletionInferOnMockService(String mode
352352

353353
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
354354
var request = new Request("POST", endpoint);
355-
request.setJsonEntity(jsonBody(input, null));
355+
request.setJsonEntity(jsonBody(input));
356356

357357
return execAsyncCall(request);
358358
}
@@ -394,60 +394,33 @@ private String createUnifiedJsonBody(List<String> input, String role) throws IOE
394394

395395
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
396396
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
397-
return inferInternal(endpoint, input, null, Map.of());
397+
return inferInternal(endpoint, input, Map.of());
398398
}
399399

400400
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
401401
throws IOException {
402402
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
403-
return inferInternal(endpoint, input, null, queryParameters);
403+
return inferInternal(endpoint, input, queryParameters);
404404
}
405405

406-
protected Map<String, Object> infer(
407-
String modelId,
408-
TaskType taskType,
409-
List<String> input,
410-
String query,
411-
Map<String, String> queryParameters
412-
) throws IOException {
413-
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
414-
return inferInternal(endpoint, input, query, queryParameters);
415-
}
416-
417-
protected Request createInferenceRequest(
418-
String endpoint,
419-
List<String> input,
420-
@Nullable String query,
421-
Map<String, String> queryParameters
422-
) {
406+
protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
423407
var request = new Request("POST", endpoint);
424-
request.setJsonEntity(jsonBody(input, query));
408+
request.setJsonEntity(jsonBody(input));
425409
if (queryParameters.isEmpty() == false) {
426410
request.addParameters(queryParameters);
427411
}
428412
return request;
429413
}
430414

431-
private Map<String, Object> inferInternal(
432-
String endpoint,
433-
List<String> input,
434-
@Nullable String query,
435-
Map<String, String> queryParameters
436-
) throws IOException {
437-
var request = createInferenceRequest(endpoint, input, query, queryParameters);
415+
private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
416+
var request = createInferenceRequest(endpoint, input, queryParameters);
438417
var response = client().performRequest(request);
439418
assertOkOrCreated(response);
440419
return entityAsMap(response);
441420
}
442421

443-
private String jsonBody(List<String> input, @Nullable String query) {
444-
final StringBuilder bodyBuilder = new StringBuilder("{");
445-
446-
if (query != null) {
447-
bodyBuilder.append("\"query\":\"").append(query).append("\",");
448-
}
449-
450-
bodyBuilder.append("\"input\": [");
422+
private String jsonBody(List<String> input) {
423+
var bodyBuilder = new StringBuilder("{\"input\": [");
451424
for (var in : input) {
452425
bodyBuilder.append('"').append(in).append('"').append(',');
453426
}

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public void testCRUD() throws IOException {
4949
}
5050

5151
var getAllModels = getAllModels();
52-
int numModels = 12;
52+
int numModels = 11;
5353
assertThat(getAllModels, hasSize(numModels));
5454

5555
var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
@@ -537,7 +537,7 @@ private static String expectedResult(String input) {
537537
}
538538

539539
public void testGetZeroModels() throws IOException {
540-
var models = getModels("_all", TaskType.COMPLETION);
540+
var models = getModels("_all", TaskType.RERANK);
541541
assertThat(models, empty());
542542
}
543543
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,12 @@
6363
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6565
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
66+
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
6667
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
6768
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
6869
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
6970
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
7071
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
71-
import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings;
7272
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
7373
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
7474
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
@@ -518,7 +518,9 @@ private static void addCustomElandWriteables(final List<NamedWriteableRegistry.E
518518
CustomElandInternalTextEmbeddingServiceSettings::new
519519
)
520520
);
521-
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, RerankTaskSettings.NAME, RerankTaskSettings::new));
521+
namedWriteables.add(
522+
new NamedWriteableRegistry.Entry(TaskSettings.class, CustomElandRerankTaskSettings.NAME, CustomElandRerankTaskSettings::new)
523+
);
522524
}
523525

524526
private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandRerankModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import java.util.HashMap;
1818
import java.util.Map;
1919

20-
import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS;
20+
import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS;
2121

2222
public class CustomElandRerankModel extends CustomElandModel {
2323

@@ -26,7 +26,7 @@ public CustomElandRerankModel(
2626
TaskType taskType,
2727
String service,
2828
CustomElandInternalServiceSettings serviceSettings,
29-
RerankTaskSettings taskSettings
29+
CustomElandRerankTaskSettings taskSettings
3030
) {
3131
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
3232
}
Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
/**
2727
* Defines the task settings for internal rerank service.
2828
*/
29-
public class RerankTaskSettings implements TaskSettings {
29+
public class CustomElandRerankTaskSettings implements TaskSettings {
3030

3131
public static final String NAME = "custom_eland_rerank_task_settings";
3232
public static final String RETURN_DOCUMENTS = "return_documents";
3333

34-
static final RerankTaskSettings DEFAULT_SETTINGS = new RerankTaskSettings(Boolean.TRUE);
34+
static final CustomElandRerankTaskSettings DEFAULT_SETTINGS = new CustomElandRerankTaskSettings(Boolean.TRUE);
3535

36-
public static RerankTaskSettings defaultsFromMap(Map<String, Object> map) {
36+
public static CustomElandRerankTaskSettings defaultsFromMap(Map<String, Object> map) {
3737
ValidationException validationException = new ValidationException();
3838

3939
if (map == null || map.isEmpty()) {
@@ -49,21 +49,21 @@ public static RerankTaskSettings defaultsFromMap(Map<String, Object> map) {
4949
returnDocuments = true;
5050
}
5151

52-
return new RerankTaskSettings(returnDocuments);
52+
return new CustomElandRerankTaskSettings(returnDocuments);
5353
}
5454

5555
/**
5656
* From map without any validation
5757
* @param map source map
5858
* @return Task settings
5959
*/
60-
public static RerankTaskSettings fromMap(Map<String, Object> map) {
60+
public static CustomElandRerankTaskSettings fromMap(Map<String, Object> map) {
6161
if (map == null || map.isEmpty()) {
6262
return DEFAULT_SETTINGS;
6363
}
6464

6565
Boolean returnDocuments = extractOptionalBoolean(map, RETURN_DOCUMENTS, new ValidationException());
66-
return new RerankTaskSettings(returnDocuments);
66+
return new CustomElandRerankTaskSettings(returnDocuments);
6767
}
6868

6969
/**
@@ -74,17 +74,20 @@ public static RerankTaskSettings fromMap(Map<String, Object> map) {
7474
* @param requestTaskSettings the settings passed in within the task_settings field of the request
7575
* @return Either {@code originalSettings} or {@code requestTaskSettings}
7676
*/
77-
public static RerankTaskSettings of(RerankTaskSettings originalSettings, RerankTaskSettings requestTaskSettings) {
77+
public static CustomElandRerankTaskSettings of(
78+
CustomElandRerankTaskSettings originalSettings,
79+
CustomElandRerankTaskSettings requestTaskSettings
80+
) {
7881
return requestTaskSettings.returnDocuments() != null ? requestTaskSettings : originalSettings;
7982
}
8083

8184
private final Boolean returnDocuments;
8285

83-
public RerankTaskSettings(StreamInput in) throws IOException {
86+
public CustomElandRerankTaskSettings(StreamInput in) throws IOException {
8487
this(in.readOptionalBoolean());
8588
}
8689

87-
public RerankTaskSettings(@Nullable Boolean doReturnDocuments) {
90+
public CustomElandRerankTaskSettings(@Nullable Boolean doReturnDocuments) {
8891
if (doReturnDocuments == null) {
8992
this.returnDocuments = true;
9093
} else {
@@ -130,7 +133,7 @@ public Boolean returnDocuments() {
130133
public boolean equals(Object o) {
131134
if (this == o) return true;
132135
if (o == null || getClass() != o.getClass()) return false;
133-
RerankTaskSettings that = (RerankTaskSettings) o;
136+
CustomElandRerankTaskSettings that = (CustomElandRerankTaskSettings) o;
134137
return Objects.equals(returnDocuments, that.returnDocuments);
135138
}
136139

@@ -141,7 +144,7 @@ public int hashCode() {
141144

142145
@Override
143146
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
144-
RerankTaskSettings updatedSettings = RerankTaskSettings.fromMap(new HashMap<>(newSettings));
147+
CustomElandRerankTaskSettings updatedSettings = CustomElandRerankTaskSettings.fromMap(new HashMap<>(newSettings));
145148
return of(this, updatedSettings);
146149
}
147150
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticRerankerModel.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import org.elasticsearch.ResourceNotFoundException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.inference.ChunkingSettings;
1213
import org.elasticsearch.inference.Model;
1314
import org.elasticsearch.inference.TaskType;
1415
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
@@ -21,9 +22,9 @@ public ElasticRerankerModel(
2122
TaskType taskType,
2223
String service,
2324
ElasticRerankerServiceSettings serviceSettings,
24-
RerankTaskSettings taskSettings
25+
ChunkingSettings chunkingSettings
2526
) {
26-
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
27+
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
2728
}
2829

2930
@Override

0 commit comments

Comments
 (0)