
Commit f964659

Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService (#113623)
* Adding chunking settings to MistralService, GoogleAiStudioService, and HuggingFaceService
* Update docs/changelog/113623.yaml
* Removing chunking settings from HuggingFaceElser model inputs
1 parent eaf2377 commit f964659

File tree: 17 files changed, +992 / -40 lines changed

docs/changelog/113623.yaml

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+pr: 113623
+summary: "Adding chunking settings to `MistralService,` `GoogleAiStudioService,` and\
+  \ `HuggingFaceService`"
+area: Machine Learning
+type: enhancement
+issues: []

server/src/main/java/org/elasticsearch/inference/ModelConfigurations.java

Lines changed: 10 additions & 0 deletions
@@ -74,6 +74,16 @@ public ModelConfigurations(String inferenceEntityId, TaskType taskType, String s
         this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE);
     }
 
+    public ModelConfigurations(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        ServiceSettings serviceSettings,
+        ChunkingSettings chunkingSettings
+    ) {
+        this(inferenceEntityId, taskType, service, serviceSettings, EmptyTaskSettings.INSTANCE, chunkingSettings);
+    }
+
     public ModelConfigurations(
         String inferenceEntityId,
         TaskType taskType,
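For orientation, a minimal, hypothetical sketch of a caller using the new convenience constructor: the endpoint id and service name are placeholders, the ServiceSettings import path is assumed to sit alongside the other org.elasticsearch.inference types shown in this diff, and task settings fall back to EmptyTaskSettings.INSTANCE through the delegating call above.

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskType;

final class ModelConfigurationsSketch {
    // Hypothetical helper: configurations for a text_embedding endpoint with explicit
    // chunking settings and default (empty) task settings via the new constructor.
    static ModelConfigurations forTextEmbedding(
        String inferenceEntityId,       // placeholder, e.g. "my-embedding-endpoint"
        String serviceName,             // placeholder, e.g. a service's NAME constant
        ServiceSettings serviceSettings,
        ChunkingSettings chunkingSettings
    ) {
        return new ModelConfigurations(inferenceEntityId, TaskType.TEXT_EMBEDDING, serviceName, serviceSettings, chunkingSettings);
    }
}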

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java

Lines changed: 43 additions & 5 deletions
@@ -15,6 +15,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
@@ -23,6 +24,8 @@
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -71,11 +74,19 @@ public void parseRequestConfig(
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+            );
+        }
+
         GoogleAiStudioModel model = createModel(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
+            chunkingSettings,
             serviceSettingsMap,
             TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME),
             ConfigurationParseContext.REQUEST
@@ -97,6 +108,7 @@ private static GoogleAiStudioModel createModel(
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secretSettings,
         String failureMessage,
         ConfigurationParseContext context
@@ -117,6 +129,7 @@ private static GoogleAiStudioModel createModel(
                 NAME,
                 serviceSettings,
                 taskSettings,
+                chunkingSettings,
                 secretSettings,
                 context
             );
@@ -135,11 +148,17 @@ public GoogleAiStudioModel parsePersistedConfigWithSecrets(
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
+            chunkingSettings,
             secretSettingsMap,
             parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
         );
@@ -150,6 +169,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         Map<String, Object> secretSettings,
         String failureMessage
     ) {
@@ -158,6 +178,7 @@ private static GoogleAiStudioModel createModelFromPersistent(
             taskType,
             serviceSettings,
             taskSettings,
+            chunkingSettings,
             secretSettings,
             failureMessage,
             ConfigurationParseContext.PERSISTENT
@@ -169,11 +190,17 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModelFromPersistent(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
+            chunkingSettings,
             null,
             parsePersistedConfigErrorMsg(inferenceEntityId, NAME)
         );
@@ -245,11 +272,22 @@ protected void doChunkedInfer(
         GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model;
         var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents());
 
-        var batchedRequests = new EmbeddingRequestChunker(
-            inputs.getInputs(),
-            EMBEDDING_MAX_BATCH_SIZE,
-            EmbeddingRequestChunker.EmbeddingType.FLOAT
-        ).batchRequestsWithListeners(listener);
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
+        if (ChunkingSettingsFeatureFlag.isEnabled()) {
+            batchedRequests = new EmbeddingRequestChunker(
+                inputs.getInputs(),
+                EMBEDDING_MAX_BATCH_SIZE,
+                EmbeddingRequestChunker.EmbeddingType.FLOAT,
+                googleAiStudioModel.getConfigurations().getChunkingSettings()
+            ).batchRequestsWithListeners(listener);
+        } else {
+            batchedRequests = new EmbeddingRequestChunker(
+                inputs.getInputs(),
+                EMBEDDING_MAX_BATCH_SIZE,
+                EmbeddingRequestChunker.EmbeddingType.FLOAT
+            ).batchRequestsWithListeners(listener);
+        }
+
         for (var request : batchedRequests) {
             var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType);
             action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
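The flag-gated block above is repeated in parseRequestConfig, parsePersistedConfigWithSecrets, and parsePersistedConfig (and again in HuggingFaceBaseService below). As a reading aid only, not code from this commit, the shared snippet condenses to a helper like the following, built solely from calls that appear in the diff:

import java.util.Map;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;

final class ChunkingSettingsExtractionSketch {
    // Chunking settings are parsed only when the feature flag is enabled and the endpoint
    // is a text_embedding task; otherwise the model keeps a null ChunkingSettings and
    // behaves as before. A missing CHUNKING_SETTINGS entry becomes an empty map, which
    // ChunkingSettingsBuilder.fromMap presumably turns into its defaults.
    static ChunkingSettings extract(TaskType taskType, Map<String, Object> config) {
        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
            return ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
        }
        return null;
    }
}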

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java

Lines changed: 28 additions & 1 deletion
@@ -9,6 +9,7 @@
 
 import org.apache.http.client.utils.URIBuilder;
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.EmptyTaskSettings;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.ModelConfigurations;
@@ -38,6 +39,7 @@ public GoogleAiStudioEmbeddingsModel(
         String service,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
+        ChunkingSettings chunkingSettings,
         Map<String, Object> secrets,
         ConfigurationParseContext context
     ) {
@@ -47,6 +49,7 @@ public GoogleAiStudioEmbeddingsModel(
             service,
             GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings, context),
             EmptyTaskSettings.INSTANCE,
+            chunkingSettings,
             DefaultSecretSettings.fromMap(secrets)
         );
     }
@@ -62,10 +65,11 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
         String service,
         GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
         TaskSettings taskSettings,
+        ChunkingSettings chunkingSettings,
         @Nullable DefaultSecretSettings secrets
     ) {
         super(
-            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings),
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
             new ModelSecrets(secrets),
             serviceSettings
         );
@@ -98,6 +102,29 @@ public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, Google
         }
     }
 
+    // Should only be used directly for testing
+    GoogleAiStudioEmbeddingsModel(
+        String inferenceEntityId,
+        TaskType taskType,
+        String service,
+        String uri,
+        GoogleAiStudioEmbeddingsServiceSettings serviceSettings,
+        TaskSettings taskSettings,
+        ChunkingSettings chunkingsettings,
+        @Nullable DefaultSecretSettings secrets
+    ) {
+        super(
+            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingsettings),
+            new ModelSecrets(secrets),
+            serviceSettings
+        );
+        try {
+            this.uri = new URI(uri);
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     @Override
     public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() {
         return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings();
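Since the constructors above store the chunking settings on ModelConfigurations, the service layer can read them back generically at inference time, which is what the doChunkedInfer changes do through getConfigurations().getChunkingSettings(). A minimal illustration, assuming Model#getConfigurations as used in the service diffs:

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;

final class ChunkingSettingsLookupSketch {
    // Returns whatever chunking settings were captured when the model was parsed,
    // or null when none were stored (for example, with the feature flag disabled).
    static ChunkingSettings chunkingSettingsOf(Model model) {
        return model.getConfigurations().getChunkingSettings();
    }
}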

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

Lines changed: 25 additions & 0 deletions
@@ -9,12 +9,15 @@
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
 import org.elasticsearch.inference.TaskType;
+import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
 import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
@@ -26,6 +29,7 @@
 
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
 
@@ -52,10 +56,18 @@ public void parseRequestConfig(
         try {
             Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
 
+            ChunkingSettings chunkingSettings = null;
+            if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+                chunkingSettings = ChunkingSettingsBuilder.fromMap(
+                    removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS)
+                );
+            }
+
             var model = createModel(
                 inferenceEntityId,
                 taskType,
                 serviceSettingsMap,
+                chunkingSettings,
                 serviceSettingsMap,
                 TaskType.unsupportedTaskTypeErrorMsg(taskType, name()),
                 ConfigurationParseContext.REQUEST
@@ -80,10 +92,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModel(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
+            chunkingSettings,
             secretSettingsMap,
             parsePersistedConfigErrorMsg(inferenceEntityId, name()),
             ConfigurationParseContext.PERSISTENT
@@ -94,10 +112,16 @@ public HuggingFaceModel parsePersistedConfigWithSecrets(
     public HuggingFaceModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
 
+        ChunkingSettings chunkingSettings = null;
+        if (ChunkingSettingsFeatureFlag.isEnabled() && TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            chunkingSettings = ChunkingSettingsBuilder.fromMap(removeFromMapOrDefaultEmpty(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
         return createModel(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
+            chunkingSettings,
             null,
             parsePersistedConfigErrorMsg(inferenceEntityId, name()),
             ConfigurationParseContext.PERSISTENT
@@ -108,6 +132,7 @@ protected abstract HuggingFaceModel createModel(
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
         Map<String, Object> secretSettings,
         String failureMessage,
         ConfigurationParseContext context

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java

Lines changed: 20 additions & 5 deletions
@@ -15,11 +15,13 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
 import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator;
 import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
@@ -48,6 +50,7 @@ protected HuggingFaceModel createModel(
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secretSettings,
         String failureMessage,
         ConfigurationParseContext context
@@ -58,6 +61,7 @@ protected HuggingFaceModel createModel(
             taskType,
             NAME,
             serviceSettings,
+            chunkingSettings,
             secretSettings,
             context
         );
@@ -111,11 +115,22 @@ protected void doChunkedInfer(
         var huggingFaceModel = (HuggingFaceModel) model;
         var actionCreator = new HuggingFaceActionCreator(getSender(), getServiceComponents());
 
-        var batchedRequests = new EmbeddingRequestChunker(
-            inputs.getInputs(),
-            EMBEDDING_MAX_BATCH_SIZE,
-            EmbeddingRequestChunker.EmbeddingType.FLOAT
-        ).batchRequestsWithListeners(listener);
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests;
+        if (ChunkingSettingsFeatureFlag.isEnabled()) {
+            batchedRequests = new EmbeddingRequestChunker(
+                inputs.getInputs(),
+                EMBEDDING_MAX_BATCH_SIZE,
+                EmbeddingRequestChunker.EmbeddingType.FLOAT,
+                huggingFaceModel.getConfigurations().getChunkingSettings()
+            ).batchRequestsWithListeners(listener);
+        } else {
+            batchedRequests = new EmbeddingRequestChunker(
+                inputs.getInputs(),
+                EMBEDDING_MAX_BATCH_SIZE,
+                EmbeddingRequestChunker.EmbeddingType.FLOAT
+            ).batchRequestsWithListeners(listener);
+        }
+
         for (var request : batchedRequests) {
             var action = huggingFaceModel.accept(actionCreator);
             action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener());
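Both doChunkedInfer changes (here and in GoogleAiStudioService) choose the chunker the same way. A condensed sketch of that selection follows; the List<String> element type for the inputs is an assumption, since the diff only shows inputs.getInputs(), and callers would still invoke batchRequestsWithListeners(listener) on the result exactly as before.

import java.util.List;

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.core.inference.ChunkingSettingsFeatureFlag;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;

final class ChunkerSelectionSketch {
    // With the flag on, the model's stored ChunkingSettings reach the chunker through the
    // new four-argument constructor; with it off, the existing three-argument constructor
    // preserves the previous behaviour.
    static EmbeddingRequestChunker chunkerFor(
        List<String> inputs,          // assumed element type (not shown in the diff)
        int maxBatchSize,             // e.g. EMBEDDING_MAX_BATCH_SIZE
        ChunkingSettings chunkingSettings
    ) {
        if (ChunkingSettingsFeatureFlag.isEnabled()) {
            return new EmbeddingRequestChunker(inputs, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT, chunkingSettings);
        }
        return new EmbeddingRequestChunker(inputs, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT);
    }
}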

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,7 @@
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInferenceServiceResults;
 import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
@@ -56,6 +57,7 @@ protected HuggingFaceModel createModel(
         String inferenceEntityId,
         TaskType taskType,
         Map<String, Object> serviceSettings,
+        ChunkingSettings chunkingSettings,
         @Nullable Map<String, Object> secretSettings,
         String failureMessage,
         ConfigurationParseContext context

0 commit comments
