Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@
exports org.elasticsearch.xpack.core.watcher.watch;
exports org.elasticsearch.xpack.core.watcher;
exports org.elasticsearch.xpack.core.common.chunks;
exports org.elasticsearch.xpack.core.inference.chunking;

provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
with
Expand Down

Large diffs are not rendered by default.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to make this PR as small as possible, I surgically moved only those methods necessary from ServiceUtils and kept callers there for actual implementation messages. I did not move tests either. If I move things more aggressively this PR gets exponentially bigger.

Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Strings;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;

public class InferenceUtils {

private InferenceUtils() {}

/**
* Remove the object from the map and cast to the expected type.
* If the object cannot be cast to type and error is added to the
* {@code validationException} parameter
*
* @param sourceMap Map containing fields
* @param key The key of the object to remove
* @param type The expected type of the removed object
* @param validationException If the value is not of type {@code type}
* @return {@code null} if not present else the object cast to type T
* @param <T> The expected type
*/
@SuppressWarnings("unchecked")
public static <T> T removeAsType(Map<String, Object> sourceMap, String key, Class<T> type, ValidationException validationException) {
if (sourceMap == null) {
validationException.addValidationError(Strings.format("Encountered a null input map while parsing field [%s]", key));
return null;
}

Object o = sourceMap.remove(key);
if (o == null) {
return null;
}

if (type.isAssignableFrom(o.getClass())) {
return (T) o;
} else {
validationException.addValidationError(invalidTypeErrorMsg(key, o, type.getSimpleName()));
return null;
}
}

public static String extractOptionalString(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
String optionalField = removeAsType(map, settingName, String.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
// new validation error occurred
return null;
}

if (optionalField != null && optionalField.isEmpty()) {
validationException.addValidationError(mustBeNonEmptyString(settingName, scope));
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return optionalField;
}

public static Integer extractRequiredPositiveInteger(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
Integer field = InferenceUtils.removeAsType(map, settingName, Integer.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

if (field == null) {
validationException.addValidationError(InferenceUtils.missingSettingErrorMsg(settingName, scope));
} else if (field <= 0) {
validationException.addValidationError(InferenceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, field));
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return field;
}

public static Integer extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Map<String, Object> map,
String settingName,
int minValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (field != null && field < minValue) {
validationException.addValidationError(
InferenceUtils.mustBeGreaterThanOrEqualNumberErrorMessage(settingName, scope, field, minValue)
);
return null;
}

return field;
}

public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
Map<String, Object> map,
String settingName,
int maxValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (field != null && field > maxValue) {
validationException.addValidationError(
InferenceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, field, maxValue)
);
}

return field;
}

@SuppressWarnings("unchecked")
public static <T> List<T> extractOptionalList(
Map<String, Object> map,
String settingName,
Class<T> type,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
var optionalField = InferenceUtils.removeAsType(map, settingName, List.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

if (optionalField != null) {
for (Object o : optionalField) {
if (o.getClass().equals(type) == false) {
validationException.addValidationError(InferenceUtils.invalidTypeErrorMsg(settingName, o, "String"));
}
}
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return (List<T>) optionalField;
}

public static <E extends Enum<E>> E extractOptionalEnum(
Map<String, Object> map,
String settingName,
String scope,
EnumConstructor<E> constructor,
EnumSet<E> validValues,
ValidationException validationException
) {
var enumString = extractOptionalString(map, settingName, scope, validationException);
if (enumString == null) {
return null;
}

try {
var createdEnum = constructor.apply(enumString);
validateEnumValue(createdEnum, validValues);

return createdEnum;
} catch (IllegalArgumentException e) {
var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings));
}

return null;
}

private static <E extends Enum<E>> void validateEnumValue(E enumValue, EnumSet<E> validValues) {
if (validValues.contains(enumValue) == false) {
throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString()));
}
}

public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}

public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredValues) {
var copyOfRequiredValues = requiredValues.clone();
Arrays.sort(copyOfRequiredValues);

return Strings.format(
"[%s] Invalid value [%s] received. [%s] must be one of [%s]",
scope,
invalidType,
settingName,
String.join(", ", copyOfRequiredValues)
);
}

public static String invalidTypeErrorMsg(String settingName, Object foundObject, String expectedType) {
return Strings.format(
"field [%s] is not of the expected type. The value [%s] cannot be converted to a [%s]",
settingName,
foundObject,
expectedType
);
}

public static String missingSettingErrorMsg(String settingName, String scope) {
return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
}

public static String mustBeGreaterThanOrEqualNumberErrorMessage(String settingName, String scope, double value, double minValue) {
return format("[%s] Invalid value [%s]. [%s] must be a greater than or equal to [%s]", scope, value, settingName, minValue);
}

public static String mustBeLessThanOrEqualNumberErrorMessage(String settingName, String scope, double value, double maxValue) {
return format("[%s] Invalid value [%s]. [%s] must be a less than or equal to [%s]", scope, value, settingName, maxValue);
}

public static String mustBeAPositiveIntegerErrorMessage(String settingName, String scope, int value) {
return format("[%s] Invalid value [%s]. [%s] must be a positive integer", scope, value, settingName);
}

/**
* Functional interface for creating an enum from a string.
* @param <E>
*/
@FunctionalInterface
public interface EnumConstructor<E extends Enum<E>> {
E apply(String name) throws IllegalArgumentException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

public enum ChunkingSettingsOptions {
STRATEGY("strategy"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
Expand All @@ -16,7 +16,7 @@
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.core.inference.InferenceUtils;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -81,15 +81,15 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Integer maxChunkSize = InferenceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

SeparatorGroup separatorGroup = ServiceUtils.extractOptionalEnum(
SeparatorGroup separatorGroup = InferenceUtils.extractOptionalEnum(
map,
ChunkingSettingsOptions.SEPARATOR_GROUP.toString(),
ModelConfigurations.CHUNKING_SETTINGS,
Expand All @@ -98,7 +98,7 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
validationException
);

List<String> separators = ServiceUtils.extractOptionalList(
List<String> separators = InferenceUtils.extractOptionalList(
map,
ChunkingSettingsOptions.SEPARATORS.toString(),
String.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
Expand All @@ -18,7 +18,7 @@
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.core.inference.InferenceUtils;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -57,6 +57,10 @@ public Integer maxChunkSize() {
return maxChunkSize;
}

public int sentenceOverlap() {
return sentenceOverlap;
}

@Override
public void validate() {
ValidationException validationException = new ValidationException();
Expand Down Expand Up @@ -100,15 +104,15 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Integer maxChunkSize = InferenceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

Integer sentenceOverlap = ServiceUtils.removeAsType(
Integer sentenceOverlap = InferenceUtils.removeAsType(
map,
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
Integer.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import java.util.List;
import java.util.Locale;
Expand Down
Loading