Skip to content

Commit e337ce6

Browse files
Adding ChunkingSettings logic and enabling ChunkingSettings for OpenAI embedding endpoints (#112074) (#113604)
* Adding ChunkingSettings logic and enabling ChunkingSettings for OpenAI embedding endpoints
* Cleaning up naming in ChunkingSettings logic
* Incrementing InferenceIndex version
* Removing DefaultChunkingSettings, cleaning up chunking settings class and related tests, adding chunking strategy to inference index
* Adding check for up-to-date index mappings when creating an inference endpoint
* Fixing transport version conflict
* Adding validation for invalid chunking settings inputs and improving error messaging
* Reverting SystemIndexMappingUpdateService changes and adding error messaging on mixed cluster exception
1 parent ac55e58 commit e337ce6

37 files changed

+1566
-34
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ static TransportVersion def(int id) {
224224
public static final TransportVersion ML_TELEMETRY_MEMORY_ADDED = def(8_748_00_0);
225225
public static final TransportVersion ILM_ADD_SEARCHABLE_SNAPSHOT_TOTAL_SHARDS_PER_NODE = def(8_749_00_0);
226226
public static final TransportVersion SEMANTIC_TEXT_SEARCH_INFERENCE_ID = def(8_750_00_0);
227+
public static final TransportVersion ML_INFERENCE_CHUNKING_SETTINGS = def(8_751_00_0);
227228

228229
/*
229230
* STOP! READ THIS FIRST! No, really,
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
13+
import org.elasticsearch.xcontent.ToXContentObject;
14+
15+
/**
 * Settings that describe how input is chunked.
 * <p>
 * Implementations can be rendered as an XContent object ({@link ToXContentObject}) and
 * serialized as a versioned named writeable ({@link VersionedNamedWriteable}).
 */
public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
16+
/**
 * @return the {@link ChunkingStrategy} these settings apply to
 */
ChunkingStrategy getChunkingStrategy();
17+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
10+
package org.elasticsearch.inference;
11+
12+
import org.elasticsearch.common.Strings;
13+
14+
import java.util.EnumSet;
15+
16+
public enum ChunkingStrategy {
17+
WORD("word"),
18+
SENTENCE("sentence");
19+
20+
private final String chunkingStrategy;
21+
22+
ChunkingStrategy(String strategy) {
23+
this.chunkingStrategy = strategy;
24+
}
25+
26+
@Override
27+
public String toString() {
28+
return chunkingStrategy;
29+
}
30+
31+
public static ChunkingStrategy fromString(String strategy) {
32+
return EnumSet.allOf(ChunkingStrategy.class)
33+
.stream()
34+
.filter(cs -> cs.chunkingStrategy.equals(strategy))
35+
.findFirst()
36+
.orElseThrow(() -> new IllegalArgumentException(Strings.format("Invalid chunkingStrategy %s", strategy)));
37+
}
38+
}

server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ public class ModelConfigurations implements ToFilteredXContentObject, VersionedN
2929
public static final String SERVICE = "service";
3030
public static final String SERVICE_SETTINGS = "service_settings";
3131
public static final String TASK_SETTINGS = "task_settings";
32+
public static final String CHUNKING_SETTINGS = "chunking_settings";
3233
private static final String NAME = "inference_model";
3334

3435
public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
@@ -40,7 +41,8 @@ public static ModelConfigurations of(Model model, TaskSettings taskSettings) {
4041
model.getConfigurations().getTaskType(),
4142
model.getConfigurations().getService(),
4243
model.getServiceSettings(),
43-
taskSettings
44+
taskSettings,
45+
model.getConfigurations().getChunkingSettings()
4446
);
4547
}
4648

@@ -53,7 +55,8 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
5355
model.getConfigurations().getTaskType(),
5456
model.getConfigurations().getService(),
5557
serviceSettings,
56-
model.getTaskSettings()
58+
model.getTaskSettings(),
59+
model.getConfigurations().getChunkingSettings()
5760
);
5861
}
5962

@@ -62,6 +65,7 @@ public static ModelConfigurations of(Model model, ServiceSettings serviceSetting
6265
private final String service;
6366
private final ServiceSettings serviceSettings;
6467
private final TaskSettings taskSettings;
68+
private final ChunkingSettings chunkingSettings;
6569

6670
/**
6771
* Allows no task settings to be defined. This will default to the {@link EmptyTaskSettings} object.
@@ -82,6 +86,23 @@ public ModelConfigurations(
8286
this.service = Objects.requireNonNull(service);
8387
this.serviceSettings = Objects.requireNonNull(serviceSettings);
8488
this.taskSettings = Objects.requireNonNull(taskSettings);
89+
this.chunkingSettings = null;
90+
}
91+
92+
/**
 * Creates a configuration that additionally carries chunking settings.
 *
 * @param inferenceEntityId the inference entity id; must not be null
 * @param taskType          the task type; must not be null
 * @param service           the service name; must not be null
 * @param serviceSettings   the service settings; must not be null
 * @param taskSettings      the task settings; must not be null
 * @param chunkingSettings  the chunking settings; stored as given and may be null
 * @throws NullPointerException if any argument other than {@code chunkingSettings} is null
 */
public ModelConfigurations(
93+
String inferenceEntityId,
94+
TaskType taskType,
95+
String service,
96+
ServiceSettings serviceSettings,
97+
TaskSettings taskSettings,
98+
ChunkingSettings chunkingSettings
99+
) {
100+
this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId);
101+
this.taskType = Objects.requireNonNull(taskType);
102+
this.service = Objects.requireNonNull(service);
103+
this.serviceSettings = Objects.requireNonNull(serviceSettings);
104+
this.taskSettings = Objects.requireNonNull(taskSettings);
105+
// Deliberately not null-checked: chunking settings are optional.
this.chunkingSettings = chunkingSettings;
85106
}
86107

87108
public ModelConfigurations(StreamInput in) throws IOException {
@@ -90,6 +111,9 @@ public ModelConfigurations(StreamInput in) throws IOException {
90111
this.service = in.readString();
91112
this.serviceSettings = in.readNamedWriteable(ServiceSettings.class);
92113
this.taskSettings = in.readNamedWriteable(TaskSettings.class);
114+
this.chunkingSettings = in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)
115+
? in.readOptionalNamedWriteable(ChunkingSettings.class)
116+
: null;
93117
}
94118

95119
@Override
@@ -99,6 +123,9 @@ public void writeTo(StreamOutput out) throws IOException {
99123
out.writeString(service);
100124
out.writeNamedWriteable(serviceSettings);
101125
out.writeNamedWriteable(taskSettings);
126+
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CHUNKING_SETTINGS)) {
127+
out.writeOptionalNamedWriteable(chunkingSettings);
128+
}
102129
}
103130

104131
public String getInferenceEntityId() {
@@ -121,6 +148,10 @@ public TaskSettings getTaskSettings() {
121148
return taskSettings;
122149
}
123150

151+
/**
 * @return the chunking settings of this configuration, or {@code null} when none were set
 */
public ChunkingSettings getChunkingSettings() {
152+
return chunkingSettings;
153+
}
154+
124155
@Override
125156
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
126157
builder.startObject();
@@ -133,6 +164,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
133164
builder.field(SERVICE, service);
134165
builder.field(SERVICE_SETTINGS, serviceSettings);
135166
builder.field(TASK_SETTINGS, taskSettings);
167+
if (chunkingSettings != null) {
168+
builder.field(CHUNKING_SETTINGS, chunkingSettings);
169+
}
136170
builder.endObject();
137171
return builder;
138172
}
@@ -149,6 +183,9 @@ public XContentBuilder toFilteredXContent(XContentBuilder builder, Params params
149183
builder.field(SERVICE, service);
150184
builder.field(SERVICE_SETTINGS, serviceSettings.getFilteredXContentObject());
151185
builder.field(TASK_SETTINGS, taskSettings);
186+
if (chunkingSettings != null) {
187+
builder.field(CHUNKING_SETTINGS, chunkingSettings);
188+
}
152189
builder.endObject();
153190
return builder;
154191
}

test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
*/
1818
public enum FeatureFlag {
1919
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
20-
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null);
20+
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
21+
CHUNKING_SETTINGS_ENABLED("es.inference_chunking_settings_feature_flag_enabled=true", Version.fromString("8.16.0"), null);
2122

2223
public final String systemProperty;
2324
public final Version from;
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference;
9+
10+
import org.elasticsearch.common.util.FeatureFlag;
11+
12+
/**
13+
* chunking_settings feature flag. When the feature is complete, this flag will be removed.
14+
*/
15+
public class ChunkingSettingsFeatureFlag {
16+
17+
private ChunkingSettingsFeatureFlag() {}
18+
19+
private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_chunking_settings");
20+
21+
public static boolean isEnabled() {
22+
return FEATURE_FLAG.isEnabled();
23+
}
24+
}

x-pack/plugin/inference/qa/mixed-cluster/src/javaRestTest/java/org/elasticsearch/xpack/inference/qa/mixed/OpenAIServiceMixedIT.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import java.util.List;
2020
import java.util.Map;
2121

22-
import static org.elasticsearch.xpack.inference.qa.mixed.MixedClusterSpecTestCase.bwcVersion;
22+
import static org.hamcrest.Matchers.containsString;
2323
import static org.hamcrest.Matchers.empty;
2424
import static org.hamcrest.Matchers.hasEntry;
2525
import static org.hamcrest.Matchers.hasSize;
@@ -29,6 +29,7 @@ public class OpenAIServiceMixedIT extends BaseMixedTestCase {
2929

3030
private static final String OPEN_AI_EMBEDDINGS_ADDED = "8.12.0";
3131
private static final String OPEN_AI_EMBEDDINGS_MODEL_SETTING_MOVED = "8.13.0";
32+
private static final String OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED = "8.16.0";
3233
private static final String OPEN_AI_COMPLETIONS_ADDED = "8.14.0";
3334
private static final String MINIMUM_SUPPORTED_VERSION = "8.15.0";
3435

@@ -50,6 +51,7 @@ public static void shutdown() {
5051
openAiChatCompletionsServer.close();
5152
}
5253

54+
@AwaitsFix(bugUrl = "Backport #112074 to 8.16")
5355
@SuppressWarnings("unchecked")
5456
public void testOpenAiEmbeddings() throws IOException {
5557
var openAiEmbeddingsSupported = bwcVersion.onOrAfter(Version.fromString(OPEN_AI_EMBEDDINGS_ADDED));
@@ -64,7 +66,23 @@ public void testOpenAiEmbeddings() throws IOException {
6466
String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig();
6567
// queue a response as PUT will call the service
6668
openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse()));
67-
put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);
69+
70+
try {
71+
put(inferenceId, inferenceConfig, TaskType.TEXT_EMBEDDING);
72+
} catch (Exception e) {
73+
if (getOldClusterTestVersion().before(OPEN_AI_EMBEDDINGS_CHUNKING_SETTINGS_ADDED)) {
74+
// Chunking settings were added in 8.16.0. If the version is before that, an exception will be thrown if the index mapping
75+
// was created based on a mapping from an old node
76+
assertThat(
77+
e.getMessage(),
78+
containsString(
79+
"One or more nodes in your cluster does not support chunking_settings. "
80+
+ "Please update all nodes in your cluster to use chunking_settings."
81+
)
82+
);
83+
return;
84+
}
85+
}
6886

6987
var configs = (List<Map<String, Object>>) get(TaskType.TEXT_EMBEDDING, inferenceId).get("endpoints");
7088
assertThat(configs, hasSize(1));

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceIndex.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ private InferenceIndex() {}
2424

2525
public static final String INDEX_NAME = ".inference";
2626
public static final String INDEX_PATTERN = INDEX_NAME + "*";
27+
public static final String INDEX_ALIAS = ".inference-alias";
2728

2829
// Increment this version number when the mappings change
29-
private static final int INDEX_MAPPING_VERSION = 1;
30+
private static final int INDEX_MAPPING_VERSION = 2;
3031

3132
public static Settings settings() {
3233
return Settings.builder()
@@ -84,6 +85,50 @@ public static XContentBuilder mappings() {
8485
.startObject("properties")
8586
.endObject()
8687
.endObject()
88+
.startObject("chunking_settings")
89+
.field("dynamic", "false")
90+
.startObject("properties")
91+
.startObject("strategy")
92+
.field("type", "keyword")
93+
.endObject()
94+
.endObject()
95+
.endObject()
96+
.endObject()
97+
.endObject()
98+
.endObject();
99+
} catch (IOException e) {
100+
throw new UncheckedIOException("Failed to build mappings for index " + INDEX_NAME, e);
101+
}
102+
}
103+
104+
public static XContentBuilder mappingsV1() {
105+
try {
106+
return jsonBuilder().startObject()
107+
.startObject(SINGLE_MAPPING_NAME)
108+
.startObject("_meta")
109+
.field(SystemIndexDescriptor.VERSION_META_KEY, 1)
110+
.endObject()
111+
.field("dynamic", "strict")
112+
.startObject("properties")
113+
.startObject("model_id")
114+
.field("type", "keyword")
115+
.endObject()
116+
.startObject("task_type")
117+
.field("type", "keyword")
118+
.endObject()
119+
.startObject("service")
120+
.field("type", "keyword")
121+
.endObject()
122+
.startObject("service_settings")
123+
.field("dynamic", "false")
124+
.startObject("properties")
125+
.endObject()
126+
.endObject()
127+
.startObject("task_settings")
128+
.field("dynamic", "false")
129+
.startObject("properties")
130+
.endObject()
131+
.endObject()
87132
.endObject()
88133
.endObject()
89134
.endObject();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference;
99

1010
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
11+
import org.elasticsearch.inference.ChunkingSettings;
1112
import org.elasticsearch.inference.EmptySecretSettings;
1213
import org.elasticsearch.inference.EmptyTaskSettings;
1314
import org.elasticsearch.inference.InferenceResults;
@@ -26,6 +27,8 @@
2627
import org.elasticsearch.xpack.core.inference.results.RankedDocsResults;
2728
import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults;
2829
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
30+
import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings;
31+
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;
2932
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings;
3033
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettings;
3134
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettings;
@@ -108,6 +111,8 @@ public static List<NamedWriteableRegistry.Entry> getNamedWriteables() {
108111
// Empty default task settings
109112
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new));
110113

114+
addChunkingSettingsNamedWriteables(namedWriteables);
115+
111116
// Empty default secret settings
112117
namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new));
113118

@@ -444,6 +449,19 @@ private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteabl
444449
);
445450
}
446451

452+
private static void addChunkingSettingsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
453+
namedWriteables.add(
454+
new NamedWriteableRegistry.Entry(ChunkingSettings.class, WordBoundaryChunkingSettings.NAME, WordBoundaryChunkingSettings::new)
455+
);
456+
namedWriteables.add(
457+
new NamedWriteableRegistry.Entry(
458+
ChunkingSettings.class,
459+
SentenceBoundaryChunkingSettings.NAME,
460+
SentenceBoundaryChunkingSettings::new
461+
)
462+
);
463+
}
464+
447465
private static void addInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
448466
namedWriteables.add(
449467
new NamedWriteableRegistry.Entry(InferenceServiceResults.class, SparseEmbeddingResults.NAME, SparseEmbeddingResults::new)

0 commit comments

Comments
 (0)