Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@
exports org.elasticsearch.xpack.core.watcher.watch;
exports org.elasticsearch.xpack.core.watcher;
exports org.elasticsearch.xpack.core.common.chunks;
exports org.elasticsearch.xpack.core.inference.chunking;

provides org.elasticsearch.action.admin.cluster.node.info.ComponentVersionNumber
with
Expand Down

Large diffs are not rendered by default.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In order to make this PR as small as possible, I surgically moved only those methods necessary from ServiceUtils and kept callers there for actual implementation messages. I did not move tests either. If I move things more aggressively this PR gets exponentially bigger.

Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.core.Strings;

import java.util.Arrays;
import java.util.EnumSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.elasticsearch.core.Strings.format;

public class InferenceUtils {

private InferenceUtils() {}

/**
* Remove the object from the map and cast to the expected type.
* If the object cannot be cast to type and error is added to the
* {@code validationException} parameter
*
* @param sourceMap Map containing fields
* @param key The key of the object to remove
* @param type The expected type of the removed object
* @param validationException If the value is not of type {@code type}
* @return {@code null} if not present else the object cast to type T
* @param <T> The expected type
*/
@SuppressWarnings("unchecked")
public static <T> T removeAsType(Map<String, Object> sourceMap, String key, Class<T> type, ValidationException validationException) {
if (sourceMap == null) {
validationException.addValidationError(Strings.format("Encountered a null input map while parsing field [%s]", key));
return null;
}

Object o = sourceMap.remove(key);
if (o == null) {
return null;
}

if (type.isAssignableFrom(o.getClass())) {
return (T) o;
} else {
validationException.addValidationError(invalidTypeErrorMsg(key, o, type.getSimpleName()));
return null;
}
}

public static String extractOptionalString(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
String optionalField = removeAsType(map, settingName, String.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
// new validation error occurred
return null;
}

if (optionalField != null && optionalField.isEmpty()) {
validationException.addValidationError(mustBeNonEmptyString(settingName, scope));
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return optionalField;
}

public static Integer extractRequiredPositiveInteger(
Map<String, Object> map,
String settingName,
String scope,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
Integer field = InferenceUtils.removeAsType(map, settingName, Integer.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

if (field == null) {
validationException.addValidationError(InferenceUtils.missingSettingErrorMsg(settingName, scope));
} else if (field <= 0) {
validationException.addValidationError(InferenceUtils.mustBeAPositiveIntegerErrorMessage(settingName, scope, field));
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return field;
}

public static Integer extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Map<String, Object> map,
String settingName,
int minValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (field != null && field < minValue) {
validationException.addValidationError(
InferenceUtils.mustBeGreaterThanOrEqualNumberErrorMessage(settingName, scope, field, minValue)
);
return null;
}

return field;
}

public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
Map<String, Object> map,
String settingName,
int maxValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (field != null && field > maxValue) {
validationException.addValidationError(
InferenceUtils.mustBeLessThanOrEqualNumberErrorMessage(settingName, scope, field, maxValue)
);
}

return field;
}

@SuppressWarnings("unchecked")
public static <T> List<T> extractOptionalList(
Map<String, Object> map,
String settingName,
Class<T> type,
ValidationException validationException
) {
int initialValidationErrorCount = validationException.validationErrors().size();
var optionalField = InferenceUtils.removeAsType(map, settingName, List.class, validationException);

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

if (optionalField != null) {
for (Object o : optionalField) {
if (o.getClass().equals(type) == false) {
validationException.addValidationError(InferenceUtils.invalidTypeErrorMsg(settingName, o, "String"));
}
}
}

if (validationException.validationErrors().size() > initialValidationErrorCount) {
return null;
}

return (List<T>) optionalField;
}

public static <E extends Enum<E>> E extractOptionalEnum(
Map<String, Object> map,
String settingName,
String scope,
EnumConstructor<E> constructor,
EnumSet<E> validValues,
ValidationException validationException
) {
var enumString = extractOptionalString(map, settingName, scope, validationException);
if (enumString == null) {
return null;
}

try {
var createdEnum = constructor.apply(enumString);
validateEnumValue(createdEnum, validValues);

return createdEnum;
} catch (IllegalArgumentException e) {
var validValuesAsStrings = validValues.stream().map(value -> value.toString().toLowerCase(Locale.ROOT)).toArray(String[]::new);
validationException.addValidationError(invalidValue(settingName, scope, enumString, validValuesAsStrings));
}

return null;
}

private static <E extends Enum<E>> void validateEnumValue(E enumValue, EnumSet<E> validValues) {
if (validValues.contains(enumValue) == false) {
throw new IllegalArgumentException(Strings.format("Enum value [%s] is not one of the acceptable values", enumValue.toString()));
}
}

public static String mustBeNonEmptyString(String settingName, String scope) {
return Strings.format("[%s] Invalid value empty string. [%s] must be a non-empty string", scope, settingName);
}

public static String invalidValue(String settingName, String scope, String invalidType, String[] requiredValues) {
var copyOfRequiredValues = requiredValues.clone();
Arrays.sort(copyOfRequiredValues);

return Strings.format(
"[%s] Invalid value [%s] received. [%s] must be one of [%s]",
scope,
invalidType,
settingName,
String.join(", ", copyOfRequiredValues)
);
}

public static String invalidTypeErrorMsg(String settingName, Object foundObject, String expectedType) {
return Strings.format(
"field [%s] is not of the expected type. The value [%s] cannot be converted to a [%s]",
settingName,
foundObject,
expectedType
);
}

public static String missingSettingErrorMsg(String settingName, String scope) {
return Strings.format("[%s] does not contain the required setting [%s]", scope, settingName);
}

public static String mustBeGreaterThanOrEqualNumberErrorMessage(String settingName, String scope, double value, double minValue) {
return format("[%s] Invalid value [%s]. [%s] must be a greater than or equal to [%s]", scope, value, settingName, minValue);
}

public static String mustBeLessThanOrEqualNumberErrorMessage(String settingName, String scope, double value, double maxValue) {
return format("[%s] Invalid value [%s]. [%s] must be a less than or equal to [%s]", scope, value, settingName, maxValue);
}

public static String mustBeAPositiveIntegerErrorMessage(String settingName, String scope, int value) {
return format("[%s] Invalid value [%s]. [%s] must be a positive integer", scope, value, settingName);
}

/**
* Functional interface for creating an enum from a string.
* @param <E>
*/
@FunctionalInterface
public interface EnumConstructor<E extends Enum<E>> {
E apply(String name) throws IllegalArgumentException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ChunkingStrategy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

public enum ChunkingSettingsOptions {
STRATEGY("strategy"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.Strings;
Expand All @@ -16,7 +16,7 @@
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.core.inference.InferenceUtils;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -81,15 +81,15 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Integer maxChunkSize = InferenceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

SeparatorGroup separatorGroup = ServiceUtils.extractOptionalEnum(
SeparatorGroup separatorGroup = InferenceUtils.extractOptionalEnum(
map,
ChunkingSettingsOptions.SEPARATOR_GROUP.toString(),
ModelConfigurations.CHUNKING_SETTINGS,
Expand All @@ -98,7 +98,7 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
validationException
);

List<String> separators = ServiceUtils.extractOptionalList(
List<String> separators = InferenceUtils.extractOptionalList(
map,
ChunkingSettingsOptions.SEPARATORS.toString(),
String.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
Expand All @@ -18,7 +18,7 @@
import org.elasticsearch.inference.ChunkingStrategy;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.core.inference.InferenceUtils;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -57,6 +57,10 @@ public Integer maxChunkSize() {
return maxChunkSize;
}

public int sentenceOverlap() {
return sentenceOverlap;
}

@Override
public void validate() {
ValidationException validationException = new ValidationException();
Expand Down Expand Up @@ -100,15 +104,15 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Integer maxChunkSize = InferenceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);

Integer sentenceOverlap = ServiceUtils.removeAsType(
Integer sentenceOverlap = InferenceUtils.removeAsType(
map,
ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(),
Integer.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

package org.elasticsearch.xpack.inference.chunking;
package org.elasticsearch.xpack.core.inference.chunking;

import java.util.List;
import java.util.Locale;
Expand Down
Loading