Skip to content

Commit 3534ded

Browse files
jonathan-buttnerelasticsearchmachinebrendan-jugan-elastic
authored
[ML] Adding dynamic filtering for EIS configuration (#120235)
* Functionality for filtering task types based on acl info for EIS * Fixing compile and test errors * updating with chat_completion * Adding acl call * [CI] Auto commit changes from spotless * working run * Starting to rename * [CI] Auto commit changes from spotless * Writing some tests * rename authorizations endpoint and response fields * Fixing enabled task types bug * Adding timed listener tests * Fixing some test failures * Adding more tests and a mock gateway * Switch sparse embedding name from gateway * Adding supported streaming tasks tests * Trying to fix the javadoc * Removing commented code * Still fixing javadoc * Lets try a break this time * add AuthHandler, AuthRequest, and AuthResponseEntity tests * [CI] Auto commit changes from spotless * Adding tests * Speeding up test * Adding atomic ref * Refactoring * Addressing feedback * Forgot a fix * Removing todo --------- Co-authored-by: elasticsearchmachine <[email protected]> Co-authored-by: Brendan Jugan <[email protected]>
1 parent 6648a03 commit 3534ded

File tree

38 files changed

+2198
-374
lines changed

38 files changed

+2198
-374
lines changed

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ default void init(Client client) {}
7878
* Whether this service should be hidden from the API. Should be used for services
7979
* that are not ready to be used.
8080
*/
81-
default Boolean hideFromConfigurationApi() {
82-
return Boolean.FALSE;
81+
default boolean hideFromConfigurationApi() {
82+
return false;
8383
}
8484

8585
/**

server/src/main/java/org/elasticsearch/inference/InferenceServiceConfiguration.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@
2424
import org.elasticsearch.xcontent.XContentType;
2525

2626
import java.io.IOException;
27-
import java.util.ArrayList;
2827
import java.util.EnumSet;
2928
import java.util.HashMap;
3029
import java.util.List;
3130
import java.util.Map;
3231
import java.util.Objects;
33-
import java.util.stream.Collectors;
3432

3533
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
3634

@@ -80,14 +78,11 @@ public InferenceServiceConfiguration(StreamInput in) throws IOException {
8078
private static final ConstructingObjectParser<InferenceServiceConfiguration, Void> PARSER = new ConstructingObjectParser<>(
8179
"inference_service_configuration",
8280
true,
83-
args -> {
84-
List<String> taskTypes = (ArrayList<String>) args[2];
85-
return new InferenceServiceConfiguration.Builder().setService((String) args[0])
86-
.setName((String) args[1])
87-
.setTaskTypes(EnumSet.copyOf(taskTypes.stream().map(TaskType::fromString).collect(Collectors.toList())))
88-
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
89-
.build();
90-
}
81+
args -> new InferenceServiceConfiguration.Builder().setService((String) args[0])
82+
.setName((String) args[1])
83+
.setTaskTypes((List<String>) args[2])
84+
.setConfigurations((Map<String, SettingsConfiguration>) args[3])
85+
.build()
9186
);
9287

9388
static {
@@ -195,6 +190,16 @@ public Builder setTaskTypes(EnumSet<TaskType> taskTypes) {
195190
return this;
196191
}
197192

193+
public Builder setTaskTypes(List<String> taskTypes) {
194+
var enumTaskTypes = EnumSet.noneOf(TaskType.class);
195+
196+
for (var supportedTaskTypeString : taskTypes) {
197+
enumTaskTypes.add(TaskType.fromStringOrStatusException(supportedTaskTypeString));
198+
}
199+
this.taskTypes = enumTaskTypes;
200+
return this;
201+
}
202+
198203
public Builder setConfigurations(Map<String, SettingsConfiguration> configurations) {
199204
this.configurations = configurations;
200205
return this;

server/src/test/java/org/elasticsearch/inference/InferenceServiceConfigurationTests.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,47 @@ public void testToXContent() throws IOException {
6666
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
6767
}
6868

69+
public void testToXContent_EmptyTaskTypes() throws IOException {
70+
String content = XContentHelper.stripWhitespace("""
71+
{
72+
"service": "some_provider",
73+
"name": "Some Provider",
74+
"task_types": [],
75+
"configurations": {
76+
"text_field_configuration": {
77+
"description": "Wow, this tooltip is useful.",
78+
"label": "Very important field",
79+
"required": true,
80+
"sensitive": true,
81+
"updatable": false,
82+
"type": "str"
83+
},
84+
"numeric_field_configuration": {
85+
"default_value": 3,
86+
"description": "Wow, this tooltip is useful.",
87+
"label": "Very important numeric field",
88+
"required": true,
89+
"sensitive": false,
90+
"updatable": true,
91+
"type": "int"
92+
}
93+
}
94+
}
95+
""");
96+
97+
InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes(
98+
new BytesArray(content),
99+
XContentType.JSON
100+
);
101+
boolean humanReadable = true;
102+
BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable);
103+
InferenceServiceConfiguration parsed;
104+
try (XContentParser parser = createParser(XContentType.JSON.xContent(), originalBytes)) {
105+
parsed = InferenceServiceConfiguration.fromXContent(parser);
106+
}
107+
assertToXContentEquivalent(originalBytes, toXContent(parsed, XContentType.JSON, humanReadable), XContentType.JSON);
108+
}
109+
69110
public void testToMap() {
70111
InferenceServiceConfiguration configField = InferenceServiceConfigurationTestUtils.getRandomServiceConfigurationField();
71112
Map<String, Object> configFieldAsMap = configField.toMap();

x-pack/plugin/inference/qa/inference-service-tests/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ dependencies {
44
javaRestTestImplementation project(path: xpackModule('core'))
55
javaRestTestImplementation project(path: xpackModule('inference'))
66
clusterPlugins project(':x-pack:plugin:inference:qa:test-service-plugin')
7+
// Added this to have access to MockWebServer within the tests
8+
javaRestTestImplementation(testArtifact(project(xpackModule('core'))))
79
}
810

911
tasks.named("javaRestTest").configure {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/CreateFromDeploymentIT.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ public void testAttachToDeployment() throws IOException {
2929

3030
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
3131
var response = startMlNodeDeploymemnt(modelId, deploymentId);
32-
assertOkOrCreated(response);
32+
assertStatusOkOrCreated(response);
3333

3434
var inferenceId = "inference_on_existing_deployment";
3535
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
@@ -58,7 +58,7 @@ public void testAttachWithModelId() throws IOException {
5858

5959
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
6060
var response = startMlNodeDeploymemnt(modelId, deploymentId);
61-
assertOkOrCreated(response);
61+
assertStatusOkOrCreated(response);
6262

6363
var inferenceId = "inference_on_existing_deployment";
6464
var putModel = putModel(inferenceId, endpointConfig(modelId, deploymentId), TaskType.SPARSE_EMBEDDING);
@@ -93,7 +93,7 @@ public void testModelIdDoesNotMatch() throws IOException {
9393

9494
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
9595
var response = startMlNodeDeploymemnt(modelId, deploymentId);
96-
assertOkOrCreated(response);
96+
assertStatusOkOrCreated(response);
9797

9898
var inferenceId = "inference_on_existing_deployment";
9999
var e = expectThrows(
@@ -123,7 +123,7 @@ public void testNumAllocationsIsUpdated() throws IOException {
123123

124124
CustomElandModelIT.createMlNodeTextExpansionModel(modelId, client());
125125
var response = startMlNodeDeploymemnt(modelId, deploymentId);
126-
assertOkOrCreated(response);
126+
assertStatusOkOrCreated(response);
127127

128128
var inferenceId = "test_num_allocations_updated";
129129
var putModel = putModel(inferenceId, endpointConfig(deploymentId), TaskType.SPARSE_EMBEDDING);
@@ -145,7 +145,7 @@ public void testNumAllocationsIsUpdated() throws IOException {
145145
)
146146
);
147147

148-
assertOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
148+
assertStatusOkOrCreated(updateMlNodeDeploymemnt(deploymentId, 2));
149149

150150
var updatedServiceSettings = getModel(inferenceId).get("service_settings");
151151
assertThat(

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 14 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ public class InferenceBaseRestTest extends ESRestTestCase {
5252
.user("x_pack_rest_user", "x-pack-test-password")
5353
.feature(FeatureFlag.INFERENCE_UNIFIED_API_ENABLED)
5454
.build();
55-
5655
@ClassRule
5756
public static MlModelServer mlModelServer = new MlModelServer();
5857

@@ -175,20 +174,20 @@ static String mockDenseServiceModelConfig() {
175174
protected void deleteModel(String modelId) throws IOException {
176175
var request = new Request("DELETE", "_inference/" + modelId);
177176
var response = client().performRequest(request);
178-
assertOkOrCreated(response);
177+
assertStatusOkOrCreated(response);
179178
}
180179

181180
protected Response deleteModel(String modelId, String queryParams) throws IOException {
182181
var request = new Request("DELETE", "_inference/" + modelId + "?" + queryParams);
183182
var response = client().performRequest(request);
184-
assertOkOrCreated(response);
183+
assertStatusOkOrCreated(response);
185184
return response;
186185
}
187186

188187
protected void deleteModel(String modelId, TaskType taskType) throws IOException {
189188
var request = new Request("DELETE", Strings.format("_inference/%s/%s", taskType, modelId));
190189
var response = client().performRequest(request);
191-
assertOkOrCreated(response);
190+
assertStatusOkOrCreated(response);
192191
}
193192

194193
protected void putSemanticText(String endpointId, String indexName) throws IOException {
@@ -207,7 +206,7 @@ protected void putSemanticText(String endpointId, String indexName) throws IOExc
207206
""", endpointId);
208207
request.setJsonEntity(body);
209208
var response = client().performRequest(request);
210-
assertOkOrCreated(response);
209+
assertStatusOkOrCreated(response);
211210
}
212211

213212
protected void putSemanticText(String endpointId, String searchEndpointId, String indexName) throws IOException {
@@ -227,7 +226,7 @@ protected void putSemanticText(String endpointId, String searchEndpointId, Strin
227226
""", endpointId, searchEndpointId);
228227
request.setJsonEntity(body);
229228
var response = client().performRequest(request);
230-
assertOkOrCreated(response);
229+
assertStatusOkOrCreated(response);
231230
}
232231

233232
protected Map<String, Object> putModel(String modelId, String modelConfig, TaskType taskType) throws IOException {
@@ -260,7 +259,7 @@ protected Map<String, Object> putPipeline(String pipelineId, String modelId) thr
260259
protected void deletePipeline(String pipelineId) throws IOException {
261260
var request = new Request("DELETE", Strings.format("_ingest/pipeline/%s", pipelineId));
262261
var response = client().performRequest(request);
263-
assertOkOrCreated(response);
262+
assertStatusOkOrCreated(response);
264263
}
265264

266265
/**
@@ -275,15 +274,15 @@ Map<String, Object> putRequest(String endpoint, String body) throws IOException
275274
var request = new Request("PUT", endpoint);
276275
request.setJsonEntity(body);
277276
var response = client().performRequest(request);
278-
assertOkOrCreated(response);
277+
assertStatusOkOrCreated(response);
279278
return entityAsMap(response);
280279
}
281280

282281
Map<String, Object> postRequest(String endpoint, String body) throws IOException {
283282
var request = new Request("POST", endpoint);
284283
request.setJsonEntity(body);
285284
var response = client().performRequest(request);
286-
assertOkOrCreated(response);
285+
assertStatusOkOrCreated(response);
287286
return entityAsMap(response);
288287
}
289288

@@ -300,15 +299,15 @@ protected Map<String, Object> putE5TrainedModels() throws IOException {
300299

301300
request.setJsonEntity(body);
302301
var response = client().performRequest(request);
303-
assertOkOrCreated(response);
302+
assertStatusOkOrCreated(response);
304303
return entityAsMap(response);
305304
}
306305

307306
protected Map<String, Object> deployE5TrainedModels() throws IOException {
308307
var request = new Request("POST", "_ml/trained_models/.multilingual-e5-small/deployment/_start?wait_for=fully_allocated");
309308

310309
var response = client().performRequest(request);
311-
assertOkOrCreated(response);
310+
assertStatusOkOrCreated(response);
312311
return entityAsMap(response);
313312
}
314313

@@ -330,31 +329,13 @@ protected List<Map<String, Object>> getAllModels() throws IOException {
330329
return (List<Map<String, Object>>) getInternalAsMap("_inference/_all").get("endpoints");
331330
}
332331

333-
protected List<Object> getAllServices() throws IOException {
334-
var endpoint = Strings.format("_inference/_services");
335-
return getInternalAsList(endpoint);
336-
}
337-
338-
@SuppressWarnings("unchecked")
339-
protected List<Object> getServices(TaskType taskType) throws IOException {
340-
var endpoint = Strings.format("_inference/_services/%s", taskType);
341-
return getInternalAsList(endpoint);
342-
}
343-
344332
private Map<String, Object> getInternalAsMap(String endpoint) throws IOException {
345333
var request = new Request("GET", endpoint);
346334
var response = client().performRequest(request);
347-
assertOkOrCreated(response);
335+
assertStatusOkOrCreated(response);
348336
return entityAsMap(response);
349337
}
350338

351-
private List<Object> getInternalAsList(String endpoint) throws IOException {
352-
var request = new Request("GET", endpoint);
353-
var response = client().performRequest(request);
354-
assertOkOrCreated(response);
355-
return entityAsList(response);
356-
}
357-
358339
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
359340
var endpoint = Strings.format("_inference/%s", modelId);
360341
return inferInternal(endpoint, input, null, Map.of());
@@ -475,7 +456,7 @@ private Map<String, Object> inferInternal(
475456
) throws IOException {
476457
var request = createInferenceRequest(endpoint, input, query, queryParameters);
477458
var response = client().performRequest(request);
478-
assertOkOrCreated(response);
459+
assertStatusOkOrCreated(response);
479460
return entityAsMap(response);
480461
}
481462

@@ -511,7 +492,7 @@ protected void assertNonEmptyInferenceResults(Map<String, Object> resultMap, int
511492
}
512493
}
513494

514-
protected static void assertOkOrCreated(Response response) throws IOException {
495+
static void assertStatusOkOrCreated(Response response) throws IOException {
515496
int statusCode = response.getStatusLine().getStatusCode();
516497
// Once EntityUtils.toString(entity) is called the entity cannot be reused.
517498
// Avoid that call with check here.
@@ -527,7 +508,7 @@ protected Map<String, Object> getTrainedModel(String inferenceEntityId) throws I
527508
var endpoint = Strings.format("_ml/trained_models/%s/_stats", inferenceEntityId);
528509
var request = new Request("GET", endpoint);
529510
var response = client().performRequest(request);
530-
assertOkOrCreated(response);
511+
assertStatusOkOrCreated(response);
531512
return entityAsMap(response);
532513
}
533514
}

0 commit comments

Comments
 (0)