Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/133718.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 133718
summary: Remove upper limit for chunking settings
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import java.util.Map;

/**
* Contract shared by the chunking settings implementations. Exposes the chunking strategy, a map
* representation of the settings and the (optional) maximum chunk size.
*/
public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {

/**
* @return The chunking strategy these settings belong to
*/
ChunkingStrategy getChunkingStrategy();

/**
* @return These settings as a map (presumably keyed by setting name; confirm against the implementations)
*/
Map<String, Object> asMap();

/**
* @return The max chunk size specified, or null if not specified
*/
Integer maxChunkSize();
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ public ChunkingStrategy getChunkingStrategy() {
return STRATEGY;
}

/**
* @return always {@code null}; this implementation does not specify a max chunk size
*/
@Override
public Integer maxChunkSize() {
return null;
}

@Override
public String getWriteableName() {
return NAME;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public List<ChunkOffset> chunk(String input, ChunkingSettings chunkingSettings)
input,
new ChunkOffset(0, input.length()),
recursiveChunkingSettings.getSeparators(),
recursiveChunkingSettings.getMaxChunkSize(),
recursiveChunkingSettings.maxChunkSize(),
0
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ public class RecursiveChunkingSettings implements ChunkingSettings {
public static final String NAME = "RecursiveChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.RECURSIVE;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;

private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
Expand Down Expand Up @@ -63,11 +62,10 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Expand Down Expand Up @@ -105,7 +103,8 @@ public static RecursiveChunkingSettings fromMap(Map<String, Object> map) {
return new RecursiveChunkingSettings(maxChunkSize, separators);
}

public int getMaxChunkSize() {
@Override
public Integer maxChunkSize() {
return maxChunkSize;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ public class SentenceBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "SentenceBoundaryChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.SENTENCE;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 20;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;
private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
Expand All @@ -55,6 +54,11 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException {
}
}

/**
* @return the max chunk size held by these settings
*/
@Override
public Integer maxChunkSize() {
return maxChunkSize;
}

@Override
public Map<String, Object> asMap() {
return Map.of(
Expand All @@ -77,11 +81,10 @@ public static SentenceBoundaryChunkingSettings fromMap(Map<String, Object> map)
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ public class WordBoundaryChunkingSettings implements ChunkingSettings {
public static final String NAME = "WordBoundaryChunkingSettings";
private static final ChunkingStrategy STRATEGY = ChunkingStrategy.WORD;
private static final int MAX_CHUNK_SIZE_LOWER_LIMIT = 10;
private static final int MAX_CHUNK_SIZE_UPPER_LIMIT = 300;
private static final Set<String> VALID_KEYS = Set.of(
ChunkingSettingsOptions.STRATEGY.toString(),
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
Expand Down Expand Up @@ -61,7 +60,8 @@ public Map<String, Object> asMap() {
);
}

public int maxChunkSize() {
@Override
public Integer maxChunkSize() {
return maxChunkSize;
}

Expand All @@ -79,11 +79,10 @@ public static WordBoundaryChunkingSettings fromMap(Map<String, Object> map) {
);
}

Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerBetween(
Integer maxChunkSize = ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
map,
ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(),
MAX_CHUNK_SIZE_LOWER_LIMIT,
MAX_CHUNK_SIZE_UPPER_LIMIT,
ModelConfigurations.CHUNKING_SETTINGS,
validationException
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,25 @@ public static Integer extractRequiredPositiveIntegerLessThanOrEqualToMax(
return field;
}

public static Integer extractRequiredPositiveIntegerGreaterThanOrEqualToMin(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to add some tests for this new method?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in f104004

Map<String, Object> map,
String settingName,
int minValue,
String scope,
ValidationException validationException
) {
Integer field = extractRequiredPositiveInteger(map, settingName, scope, validationException);

if (field != null && field < minValue) {
validationException.addValidationError(
ServiceUtils.mustBeGreaterThanOrEqualNumberErrorMessage(settingName, scope, field, minValue)
);
return null;
}

return field;
}

public static Integer extractRequiredPositiveIntegerBetween(
Map<String, Object> map,
String settingName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

public class ElserInternalModel extends ElasticsearchInternalModel {

// Ensure that inference endpoints based on ELSER don't go past its truncation window of 512 tokens
public static final int ELSER_MAX_WINDOW_SIZE = 300;

public ElserInternalModel(
String inferenceEntityId,
TaskType taskType,
Expand All @@ -21,6 +24,14 @@ public ElserInternalModel(
ChunkingSettings chunkingSettings
) {
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings);
if (chunkingSettings != null && chunkingSettings.maxChunkSize() != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the same check to MultilingualE5SmallModel so multilingual-e5-small will also pick up the restriction.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in f19f2ff

if (chunkingSettings.maxChunkSize() > ELSER_MAX_WINDOW_SIZE) throw new IllegalArgumentException(
"ELSER based models do not support chunk sizes larger than "
+ ELSER_MAX_WINDOW_SIZE
+ ". Requested chunk size: "
+ chunkingSettings.maxChunkSize()
);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@

public class MultilingualE5SmallModel extends ElasticsearchInternalModel {

// Ensure that inference endpoints based on E5 small don't go past its window size
public static final int E5_SMALL_MAX_WINDOW_SIZE = 300;

public MultilingualE5SmallModel(
String inferenceEntityId,
TaskType taskType,
Expand All @@ -20,6 +23,15 @@ public MultilingualE5SmallModel(
ChunkingSettings chunkingSettings
) {
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
if (chunkingSettings != null && chunkingSettings.maxChunkSize() != null) {
if (chunkingSettings.maxChunkSize() > E5_SMALL_MAX_WINDOW_SIZE) throw new IllegalArgumentException(
serviceSettings.modelId()
+ " does not support chunk sizes larger than "
+ E5_SMALL_MAX_WINDOW_SIZE
+ ". Requested chunk size: "
+ chunkingSettings.maxChunkSize()
);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public void testFromMapValidSettingsWithSeparators() {

RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings);

assertEquals(maxChunkSize, settings.getMaxChunkSize());
assertEquals(maxChunkSize, (int) settings.maxChunkSize());
assertEquals(separators, settings.getSeparators());
}

Expand All @@ -39,7 +39,7 @@ public void testFromMapValidSettingsWithSeparatorGroup() {

RecursiveChunkingSettings settings = RecursiveChunkingSettings.fromMap(validSettings);

assertEquals(maxChunkSize, settings.getMaxChunkSize());
assertEquals(maxChunkSize, (int) settings.maxChunkSize());
assertEquals(separatorGroup.getSeparators(), settings.getSeparators());
}

Expand All @@ -49,12 +49,6 @@ public void testFromMapMaxChunkSizeTooSmall() {
assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
}

public void testFromMapMaxChunkSizeTooLarge() {
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(301, 500), Optional.empty(), Optional.empty());

assertThrows(ValidationException.class, () -> RecursiveChunkingSettings.fromMap(invalidSettings));
}

public void testFromMapInvalidSeparatorGroup() {
Map<String, Object> invalidSettings = buildChunkingSettingsMap(randomIntBetween(10, 300), Optional.of("invalid"), Optional.empty());

Expand Down Expand Up @@ -116,7 +110,7 @@ protected RecursiveChunkingSettings createTestInstance() {

@Override
protected RecursiveChunkingSettings mutateInstance(RecursiveChunkingSettings instance) throws IOException {
int maxChunkSize = randomValueOtherThan(instance.getMaxChunkSize(), () -> randomIntBetween(10, 300));
int maxChunkSize = randomValueOtherThan(instance.maxChunkSize(), () -> randomIntBetween(10, 300));
List<String> separators = instance.getSeparators();
separators.add(randomAlphaOfLength(1));
return new RecursiveChunkingSettings(maxChunkSize, separators);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalTimeValue;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveInteger;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerGreaterThanOrEqualToMin;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredPositiveIntegerLessThanOrEqualToMax;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
Expand Down Expand Up @@ -656,6 +657,69 @@ public void testExtractRequiredPositiveIntegerLessThanOrEqualToMax_AddsErrorWhen
assertThat(validation.validationErrors().get(1), is("[scope] does not contain the required setting [not_key]"));
}

// Boundary case: a value exactly equal to the minimum must be accepted.
public void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_ReturnsValueWhenValueIsEqualToMin() {
    final int minValue = 5;
    testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_Successful(minValue, minValue);
}

// A value one above the minimum must be accepted.
public void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_ReturnsValueWhenValueIsGreaterThanToMin() {
    final int minValue = 5;
    final int actualValue = 6;
    testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_Successful(minValue, actualValue);
}

/**
 * Shared driver for the success cases: extracts {@code "key"} from a map holding {@code actualValue} and
 * checks that the value is returned, the entry is consumed and no new validation error is recorded.
 */
private void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_Successful(int minValue, int actualValue) {
    // Seed one unrelated error so we can tell that the extraction itself added none.
    var errors = new ValidationException();
    errors.addValidationError("previous error");
    Map<String, Object> settings = modifiableMap(Map.of("key", actualValue));

    var extracted = extractRequiredPositiveIntegerGreaterThanOrEqualToMin(settings, "key", minValue, "scope", errors);

    assertThat(errors.validationErrors(), hasSize(1));
    assertNotNull(extracted);
    assertThat(extracted, is(actualValue));
    // The extracted entry is expected to have been removed from the map.
    assertTrue(settings.isEmpty());
}

// A value one below the minimum must be rejected with the "greater than or equal" message.
public void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsErrorWhenValueIsLessThanMin() {
    final String expectedError = "[scope] Invalid value [4.0]. [key] must be a greater than or equal to [5.0]";
    testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsError("key", 5, 4, expectedError);
}

// Looking up a setting name that is absent from the map must report the missing-setting error.
public void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsErrorWhenKeyIsMissing() {
    final String expectedError = "[scope] does not contain the required setting [not_key]";
    testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsError("not_key", 5, -1, expectedError);
}

// A negative value must be rejected with the positive-integer message.
public void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsErrorOnNegativeValue() {
    final String expectedError = "[scope] Invalid value [-1]. [key] must be a positive integer";
    testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsError("key", 5, -1, expectedError);
}

/**
 * Shared driver for the failure cases: the map always holds {@code actualValue} under {@code "key"}, while
 * {@code key} is the setting name that is looked up. Checks that nothing is returned and that exactly one
 * new validation error containing {@code error} is recorded.
 */
private void testExtractRequiredPositiveIntegerGreaterThanOrEqualToMin_AddsError(
    String key,
    int minValue,
    int actualValue,
    String error
) {
    // Seed one unrelated error; the failure under test is expected at index 1.
    var errors = new ValidationException();
    errors.addValidationError("previous error");
    Map<String, Object> settings = modifiableMap(Map.of("key", actualValue));

    var extracted = extractRequiredPositiveIntegerGreaterThanOrEqualToMin(settings, key, minValue, "scope", errors);

    assertThat(errors.validationErrors(), hasSize(2));
    assertNull(extracted);
    assertThat(errors.validationErrors().get(1), containsString(error));
}

public void testExtractRequiredPositiveIntegerBetween_ReturnsValueWhenValueIsBetweenMinAndMax() {
var minValue = randomNonNegativeInt();
var maxValue = randomIntBetween(minValue + 2, minValue + 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment;
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentTests;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;

import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -48,4 +49,23 @@ public void testUpdateNumAllocation() {
equalTo(trainedModelAssignment.getAdaptiveAllocationsSettings())
);
}

// Constructing an ELSER model with a max chunk size of 10000 must be rejected with a descriptive message.
public void testHugeChunkingSettings() {
    final String expectedMessage = "ELSER based models do not support chunk sizes larger than 300. Requested chunk size: 10000";

    Exception thrown = expectThrows(
        IllegalArgumentException.class,
        () -> new ElserInternalModel(
            "foo",
            TaskType.SPARSE_EMBEDDING,
            ElasticsearchInternalService.NAME,
            new ElserInternalServiceSettings(new ElasticsearchInternalServiceSettings(null, 1, "elser", null, null)),
            new ElserMlNodeTaskSettings(),
            new SentenceBoundaryChunkingSettings(10000, 0)
        )
    );

    assertThat(thrown.getMessage(), equalTo(expectedMessage));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;

import static org.hamcrest.Matchers.equalTo;

/**
 * Tests for the chunk size validation performed when constructing a {@link MultilingualE5SmallModel}.
 */
public class MultilingualE5SmallModelTests extends ESTestCase {

    // Constructing the model with a max chunk size of 10000 must be rejected with a descriptive message.
    public void testHugeChunkingSettings() {
        final String expectedMessage =
            ".multilingual-e5-small does not support chunk sizes larger than 300. Requested chunk size: 10000";

        Exception thrown = expectThrows(
            IllegalArgumentException.class,
            () -> new MultilingualE5SmallModel(
                "foo",
                TaskType.TEXT_EMBEDDING,
                ElasticsearchInternalService.NAME,
                MultilingualE5SmallInternalServiceSettings.defaultEndpointSettings(randomBoolean()),
                new SentenceBoundaryChunkingSettings(10000, 0)
            )
        );

        assertThat(thrown.getMessage(), equalTo(expectedMessage));
    }
}