Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,3 @@ tests:
- class: org.elasticsearch.discovery.ec2.Ec2DiscoveryTests
method: testFilterByMultipleTags
issue: https://github.com/elastic/elasticsearch/issues/125166
- class: org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiEmbeddingsActionTests
method: testExecute_ReturnsSuccessfulResponse
issue: https://github.com/elastic/elasticsearch/issues/125057
- class: org.elasticsearch.xpack.inference.external.request.jinaai.JinaAIEmbeddingsRequestTests
method: testCreateRequest_TaskSettingsInputType
issue: https://github.com/elastic/elasticsearch/issues/125059
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

package org.elasticsearch.inference;

import org.elasticsearch.common.Strings;

import java.util.Locale;

import static org.elasticsearch.core.Strings.format;
Expand Down Expand Up @@ -51,4 +53,8 @@ public static boolean isInternalTypeOrUnspecified(InputType inputType) {
public static boolean isSpecified(InputType inputType) {
return inputType != null && inputType != InputType.UNSPECIFIED;
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.embeddings.AlibabaCloudSearchEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record AlibabaCloudSearchEmbeddingsRequestEntity(
List<String> input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.DIMENSIONS_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.INPUT_FIELD;
import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.USER_FIELD;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;

public record AzureAiStudioEmbeddingsRequestEntity(
List<String> input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record CohereEmbeddingsRequestEntity(
List<String> input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.util.Objects;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record GoogleAiStudioEmbeddingsRequestEntity(List<String> inputs, InputType inputType, String model, @Nullable Integer dimensions)
implements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record GoogleVertexAiEmbeddingsRequestEntity(
List<String> inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record JinaAIEmbeddingsRequestEntity(
List<String> input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.services.voyageai.embeddings.VoyageAIEmbeddingsTaskSettings.invalidInputTypeMessage;
import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record VoyageAIEmbeddingsRequestEntity(
List<String> input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -24,6 +23,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService.VALID_INPUT_TYPE_VALUES;

Expand Down Expand Up @@ -166,10 +166,6 @@ public int hashCode() {
return Objects.hash(inputType);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
AlibabaCloudSearchEmbeddingsTaskSettings newSettingsOnly = fromMap(new HashMap<>(newSettings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -24,6 +23,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService.VALID_INPUT_TYPE_VALUES;
Expand Down Expand Up @@ -180,10 +180,6 @@ public int hashCode() {
return Objects.hash(inputType, returnToken);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
AlibabaCloudSearchSparseTaskSettings updatedSettings = fromMap(new HashMap<>(newSettings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -26,6 +25,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.cohere.CohereService.VALID_INPUT_TYPE_VALUES;
import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.TRUNCATE;
Expand Down Expand Up @@ -193,10 +193,6 @@ public int hashCode() {
return Objects.hash(inputType, truncation);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
CohereEmbeddingsTaskSettings updatedSettings = CohereEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -167,10 +165,6 @@ public int hashCode() {
return Objects.hash(returnDocuments, topNDocumentsOnly, maxChunksPerDoc);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

public Boolean getDoesReturnDocuments() {
return returnDocuments;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -24,6 +23,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiService.VALID_INPUT_TYPE_VALUES;
Expand Down Expand Up @@ -171,10 +171,6 @@ public int hashCode() {
return Objects.hash(autoTruncate, inputType);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
GoogleVertexAiEmbeddingsRequestTaskSettings updatedSettings = GoogleVertexAiEmbeddingsRequestTaskSettings.fromMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -164,10 +162,6 @@ public int hashCode() {
return Objects.hash(returnDocuments, topNDocumentsOnly, truncateInputTokens);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

public Boolean getDoesReturnDocuments() {
return returnDocuments;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -24,6 +23,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.jinaai.JinaAIService.VALID_INPUT_TYPE_VALUES;

Expand Down Expand Up @@ -160,10 +160,6 @@ public int hashCode() {
return Objects.hash(inputType);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
JinaAIEmbeddingsTaskSettings updatedSettings = JinaAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
Expand Down Expand Up @@ -142,10 +140,6 @@ public int hashCode() {
return Objects.hash(returnDocuments, topNDocumentsOnly);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

public Boolean getDoesReturnDocuments() {
return returnDocuments;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand All @@ -24,6 +23,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.voyageai.VoyageAIService.VALID_INPUT_TYPE_VALUES;
Expand Down Expand Up @@ -184,10 +184,6 @@ public int hashCode() {
return Objects.hash(inputType, truncation);
}

public static String invalidInputTypeMessage(InputType inputType) {
return Strings.format("received invalid input type value [%s]", inputType.toString());
}

@Override
public TaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
VoyageAIEmbeddingsTaskSettings updatedSettings = VoyageAIEmbeddingsTaskSettings.fromMap(new HashMap<>(newSettings));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
Expand Down Expand Up @@ -126,10 +127,10 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException {
assertThat(webServer.requests().get(0).getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo("apikey"));

var requestMap = entityAsMap(webServer.requests().get(0).getBody());
assertThat(requestMap.size(), is(inputType != null ? 3 : 2));
assertThat(requestMap.size(), is(InputType.isSpecified(inputType) ? 3 : 2));
assertThat(requestMap.get("input"), is(List.of("abc")));
assertThat(requestMap.get("user"), is("user"));
if (inputType != null) {
if (InputType.isSpecified(inputType)) {
assertThat(requestMap.get("input_type"), is(inputType.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ public void testCreateRequest_WithEntraIdDefined() throws IOException {
assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer " + entraId));

var requestMap = entityAsMap(httpPost.getEntity().getContent());
assertThat(requestMap.size(), equalTo(inputType != null ? 3 : 2));
assertThat(requestMap.size(), equalTo(InputType.isSpecified(inputType) ? 3 : 2));
assertThat(requestMap.get("input"), is(List.of(input)));
assertThat(requestMap.get("user"), is(user));
if (inputType != null) {
if (InputType.isSpecified(inputType)) {
assertThat(requestMap.get("input_type"), is(inputType.toString()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException {
}

public void testCreateRequest_TaskSettingsInputType() throws IOException {
var inputType = InputTypeTests.randomWithNull();
var inputType = InputTypeTests.randomWithoutUnspecified();
var request = createRequest(
List.of("abc"),
null,
Expand Down