Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ static TransportVersion def(int id) {
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED_BACKPORT_8_X = def(8_840_0_01);
public static final TransportVersion ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion COHERE_BIT_EMBEDDING_TYPE_SUPPORT_ADDED = def(9_001_0_00);
public static final TransportVersion COHERE_EMBEDDING_TYPES_SUPPORT_ADDED = def(9_003_0_00);

/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;

Expand All @@ -30,7 +31,7 @@ public class CohereEmbeddingsRequest extends CohereRequest {
private final List<String> input;
private final CohereEmbeddingsTaskSettings taskSettings;
private final String model;
private final CohereEmbeddingType embeddingType;
private final EnumSet<CohereEmbeddingType> embeddingTypes;
private final String inferenceEntityId;

public CohereEmbeddingsRequest(List<String> input, CohereEmbeddingsModel embeddingsModel) {
Expand All @@ -40,7 +41,7 @@ public CohereEmbeddingsRequest(List<String> input, CohereEmbeddingsModel embeddi
this.input = Objects.requireNonNull(input);
taskSettings = embeddingsModel.getTaskSettings();
model = embeddingsModel.getServiceSettings().getCommonSettings().modelId();
embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType();
embeddingTypes = embeddingsModel.getServiceSettings().getEmbeddingTypes();
inferenceEntityId = embeddingsModel.getInferenceEntityId();
}

Expand All @@ -49,7 +50,7 @@ public HttpRequest createHttpRequest() {
HttpPost httpPost = new HttpPost(account.uri());

ByteArrayEntity byteEntity = new ByteArrayEntity(
Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings, model, embeddingType)).getBytes(StandardCharsets.UTF_8)
Strings.toString(new CohereEmbeddingsRequestEntity(input, taskSettings, model, embeddingTypes)).getBytes(StandardCharsets.UTF_8)
);
httpPost.setEntity(byteEntity);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings;

import java.io.IOException;
import java.util.EnumSet;
import java.util.List;
import java.util.Objects;

Expand All @@ -26,7 +27,7 @@ public record CohereEmbeddingsRequestEntity(
List<String> input,
CohereEmbeddingsTaskSettings taskSettings,
@Nullable String model,
@Nullable CohereEmbeddingType embeddingType
@Nullable EnumSet<CohereEmbeddingType> embeddingTypes
) implements ToXContentObject {

private static final String SEARCH_DOCUMENT = "search_document";
Expand Down Expand Up @@ -54,8 +55,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType()));
}

if (embeddingType != null) {
builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString()));
if (embeddingTypes != null) {
builder.field(EMBEDDING_TYPES_FIELD, embeddingTypes.stream().map(CohereEmbeddingType::toRequestString).toList());
}

if (taskSettings.getTruncation() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.HashMap;
Expand Down Expand Up @@ -73,7 +74,7 @@ public static <T> T removeAsType(Map<String, Object> sourceMap, String key, Clas

/**
* Remove the object from the map and cast to the expected type.
* If the object cannot be cast to type and error is added to the
* If the object cannot be cast to type an error is added to the
* {@code validationException} parameter
*
* @param sourceMap Map containing fields
Expand All @@ -98,6 +99,45 @@ public static <T> T removeAsType(Map<String, Object> sourceMap, String key, Clas
}
}

/**
* Remove a list of objects from the map and cast each entry to the expected type.
* If the object cannot be cast to a List or any of the entries cannot be cast
* to the type an error is added to the {@code validationException} parameter
*
* @param sourceMap Map containing fields
* @param key The key of the object to remove
* @param type The expected type of each list item in the removed object
* @param validationException If the value is not of type {@code type}
* @return {@code null} if not present else the object cast to type List of type T
* @param <T> The expected type
*/
@SuppressWarnings("unchecked")
public static <T> List<T> removeAsListOfType(
Map<String, Object> sourceMap,
String key,
Class<T> type,
ValidationException validationException
) {
Object o = sourceMap.remove(key);
if (o == null) {
return null;
}

if (List.class.isAssignableFrom(o.getClass())) {
// check each list entry
for (Object listItem : (List) o) {
if (type.isAssignableFrom(listItem.getClass()) == false) {
validationException.addValidationError(invalidTypeErrorMsg(key, listItem, type.getSimpleName()));
}
}

return (List<T>) o;
} else {
validationException.addValidationError(invalidTypeErrorMsg(key, o, List.class.getSimpleName()));
return null;
}
}

/**
* Remove the object from the map and cast to first assignable type in the expected types list.
* If the object cannot be cast to one of the types an error is added to the
Expand Down Expand Up @@ -254,6 +294,10 @@ public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}

public static String mustBeNonEmptyList(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty list. [%s] must be a non-empty list", scope, settingName);
}

public static String invalidTimeValueMsg(String timeValueStr, String settingName, String scope, String exceptionMsg) {
return Strings.format(
"[%s] Invalid time value [%s]. [%s] must be a valid time value string: %s",
Expand Down Expand Up @@ -401,6 +445,31 @@ public static String extractOptionalString(
return optionalField;
}

public static List<String> extractOptionalStringArray(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
List<String> optionalField = ServiceUtils.removeAsListOfType(map, settingName, String.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
// new validation error occurred
return null;
}

if (optionalField != null && optionalField.isEmpty()) {
validationException.addValidationError(ServiceUtils.mustBeNonEmptyList(settingName, scope));
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return optionalField;
}

public static Integer extractRequiredPositiveInteger(
Map<String, Object> map,
String settingName,
Expand Down Expand Up @@ -611,6 +680,37 @@ public static <E extends Enum<E>> E extractOptionalEnum(
return null;
}

public static <E extends Enum<E>> EnumSet<E> extractOptionalEnumSet(
Map<String, Object> map,
String settingName,
String scope,
EnumConstructor<E> constructor,
EnumSet<E> validValues,
ValidationException validationException
) {
var enumStringArray = extractOptionalStringArray(map, settingName, scope, validationException);
if (enumStringArray == null) {
return null;
}

var createdEnums = new ArrayList<E>();
for (String enumString : enumStringArray) {
try {
var createdEnum = constructor.apply(enumString);
validateEnumValue(createdEnum, validValues);
createdEnums.add(createdEnum);
} catch (IllegalArgumentException e) {
var validValuesAsStrings = validValues.stream()
.map(value -> value.toString().toLowerCase(Locale.ROOT))
.toArray(String[]::new);
validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings));
return null;
}
}

return EnumSet.copyOf(createdEnums);
}

public static Boolean extractOptionalBoolean(Map<String, Object> map, String settingName, ValidationException validationException) {
return ServiceUtils.removeAsType(map, settingName, Boolean.class, validationException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
serviceSettings.getCommonSettings().modelId(),
serviceSettings.getCommonSettings().rateLimitSettings()
),
serviceSettings.getEmbeddingType()
serviceSettings.getEmbeddingTypes()
);

return new CohereEmbeddingsModel(embeddingsModel, updatedServiceSettings);
Expand Down
Loading