Skip to content

Commit 4d3613b

Browse files
Addressing feedback
1 parent eb6def2 commit 4d3613b

File tree

15 files changed

+71
-98
lines changed

15 files changed

+71
-98
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,9 @@ public StoreInferenceEndpointsAction() {
3535
}
3636

3737
public static class Request extends AcknowledgedRequest<Request> {
38-
private final List<? extends Model> models;
38+
private final List<Model> models;
3939

40-
public Request(List<? extends Model> models, TimeValue timeout) {
40+
public Request(List<Model> models, TimeValue timeout) {
4141
super(timeout, DEFAULT_ACK_TIMEOUT);
4242
this.models = Objects.requireNonNull(models);
4343
}
@@ -53,7 +53,7 @@ public void writeTo(StreamOutput out) throws IOException {
5353
out.writeCollection(models);
5454
}
5555

56-
public List<? extends Model> getModels() {
56+
public List<Model> getModels() {
5757
return models;
5858
}
5959

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,8 @@ public class MockElasticInferenceServiceAuthorizationServer implements TestRule
2525
private final MockWebServer webServer = new MockWebServer();
2626

2727
public void enqueueAuthorizeAllModelsResponse() {
28-
var authResponse = getEisAuthorizationResponseWithMultipleEndpoints("ignored");
29-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponse.responseJson()));
28+
var authResponseBody = getEisAuthorizationResponseWithMultipleEndpoints("ignored").responseJson();
29+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authResponseBody));
3030
}
3131

3232
public String getUrl() {

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase {
6666

6767
private static final MockWebServer webServer = new MockWebServer();
6868
private static String gatewayUrl;
69-
private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse;
69+
private static String chatCompletionResponseBody;
7070

7171
private ModelRegistry modelRegistry;
7272
private AuthorizationTaskExecutor authorizationTaskExecutor;
@@ -76,7 +76,7 @@ public static void initClass() throws IOException {
7676
webServer.start();
7777
gatewayUrl = getUrl(webServer);
7878
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
79-
chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl);
79+
chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson();
8080
}
8181

8282
@Before
@@ -122,7 +122,7 @@ protected Collection<Class<? extends Plugin>> getPlugins() {
122122
public void testCreatesEisChatCompletionEndpoint() throws Exception {
123123
assertNoAuthorizedEisEndpoints();
124124

125-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson()));
125+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
126126
restartPollingTaskAndWaitForAuthResponse();
127127

128128
assertChatCompletionEndpointExists();
@@ -227,7 +227,7 @@ static void cancelAuthorizationTask(AdminClient adminClient) throws Exception {
227227
public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception {
228228
assertNoAuthorizedEisEndpoints();
229229

230-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson()));
230+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
231231
restartPollingTaskAndWaitForAuthResponse();
232232

233233
assertChatCompletionEndpointExists();
@@ -261,7 +261,7 @@ static void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesMode
261261
public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception {
262262
assertNoAuthorizedEisEndpoints();
263263

264-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson()));
264+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
265265
restartPollingTaskAndWaitForAuthResponse();
266266

267267
assertChatCompletionEndpointExists();
@@ -273,8 +273,9 @@ public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Excep
273273
assertChatCompletionEndpointExists();
274274

275275
// Simulate that a text embedding model is now authorized
276-
var jinaEmbedResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl);
277-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponse.responseJson()));
276+
var jinaEmbedResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisJinaEmbedAuthorizationResponse(gatewayUrl)
277+
.responseJson();
278+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(jinaEmbedResponseBody));
278279

279280
restartPollingTaskAndWaitForAuthResponse();
280281

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings;
2121
import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller;
2222
import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMSettings;
23-
import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntityTests;
2423
import org.junit.AfterClass;
2524
import org.junit.Before;
2625
import org.junit.BeforeClass;
@@ -55,14 +54,14 @@ public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase {
5554
private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]";
5655
private static final MockWebServer webServer = new MockWebServer();
5756
private static String gatewayUrl;
58-
private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse;
57+
private static String chatCompletionResponseBody;
5958

6059
@BeforeClass
6160
public static void initClass() throws IOException {
6261
webServer.start();
6362
gatewayUrl = getUrl(webServer);
6463
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EIS_EMPTY_RESPONSE));
65-
chatCompletionResponse = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl);
64+
chatCompletionResponseBody = getEisRainbowSprinklesAuthorizationResponse(gatewayUrl).responseJson();
6665
}
6766

6867
@Before
@@ -113,7 +112,7 @@ public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRun
113112
);
114113

115114
// queue a response that authorizes one model
116-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson()));
115+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
117116

118117
assertTrue("expected the node to shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node())));
119118

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/CCMServiceIT.java

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ public class CCMServiceIT extends CCMSingleNodeIT {
4949

5050
private static final MockWebServer webServer = new MockWebServer();
5151
private static String gatewayUrl;
52-
private static ElasticInferenceServiceAuthorizationResponseEntityTests.EisAuthorizationResponse chatCompletionResponse;
52+
private static String chatCompletionResponseBody;
5353

5454
private AuthorizationTaskExecutor authorizationTaskExecutor;
5555
private ModelRegistry modelRegistry;
@@ -79,9 +79,9 @@ public static void initClass() throws IOException {
7979

8080
webServer.start();
8181
gatewayUrl = getUrl(webServer);
82-
chatCompletionResponse = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse(
82+
chatCompletionResponseBody = ElasticInferenceServiceAuthorizationResponseEntityTests.getEisRainbowSprinklesAuthorizationResponse(
8383
gatewayUrl
84-
);
84+
).responseJson();
8585
}
8686

8787
@Before
@@ -144,7 +144,7 @@ public void testCreatesEisChatCompletionEndpoint() throws Exception {
144144
var eisEndpoints = getEisEndpoints(modelRegistry);
145145
assertThat(eisEndpoints, empty());
146146

147-
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponse.responseJson()));
147+
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(chatCompletionResponseBody));
148148
var listener = new TestPlainActionFuture<Void>();
149149
ccmService.get().storeConfiguration(new CCMModel(new SecureString("secret".toCharArray())), listener);
150150
listener.actionGet(TimeValue.THIRTY_SECONDS);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportGetInferenceServicesAction.java

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -127,6 +127,7 @@ private void getServiceConfigurationsForServicesAndEis(
127127
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(() -> getEisAuthorization(authModelListener, eisSender));
128128
}).<List<InferenceServiceConfiguration>>andThen((configurationListener, authorizationModel) -> {
129129
var serviceConfigs = getServiceConfigurationsForServices(availableServices);
130+
serviceConfigs.sort(Comparator.comparing(InferenceServiceConfiguration::getService));
130131

131132
if (authorizationModel.isAuthorized() == false) {
132133
configurationListener.onResponse(serviceConfigs);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -712,12 +712,12 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener<
712712
}), timeout);
713713
}
714714

715-
public void storeModels(List<? extends Model> models, ActionListener<List<ModelStoreResponse>> listener, TimeValue timeout) {
715+
public void storeModels(List<Model> models, ActionListener<List<ModelStoreResponse>> listener, TimeValue timeout) {
716716
storeModels(models, true, listener, timeout);
717717
}
718718

719719
private void storeModels(
720-
List<? extends Model> models,
720+
List<Model> models,
721721
boolean updateClusterState,
722722
ActionListener<List<ModelStoreResponse>> listener,
723723
TimeValue timeout
@@ -745,7 +745,7 @@ private void storeModels(
745745
}
746746

747747
private ActionListener<BulkResponse> getStoreMultipleModelsListener(
748-
List<? extends Model> models,
748+
List<Model> models,
749749
boolean updateClusterState,
750750
ActionListener<List<ModelStoreResponse>> listener,
751751
TimeValue timeout
@@ -818,12 +818,12 @@ private ActionListener<BulkResponse> getStoreMultipleModelsListener(
818818

819819
private record StoreResponseWithIndexInfo(ModelStoreResponse modelStoreResponse, boolean modifiedIndex) {}
820820

821-
private record ResponseInfo(List<StoreResponseWithIndexInfo> responses, List<? extends Model> successfullyStoredModels) {}
821+
private record ResponseInfo(List<StoreResponseWithIndexInfo> responses, List<Model> successfullyStoredModels) {}
822822

823823
private static ResponseInfo getResponseInfo(
824824
BulkResponse bulkResponse,
825825
Map<String, String> docIdToInferenceId,
826-
Map<String, ? extends Model> inferenceIdToModel
826+
Map<String, Model> inferenceIdToModel
827827
) {
828828
var responses = new ArrayList<StoreResponseWithIndexInfo>();
829829
var successfullyStoredModels = new ArrayList<Model>();
@@ -909,15 +909,15 @@ private static ModelStoreResponse createModelStoreResponse(BulkItemResponse item
909909
}
910910
}
911911

912-
private static Model getModelFromMap(@Nullable String inferenceId, Map<String, ? extends Model> inferenceIdToModel) {
912+
private static Model getModelFromMap(@Nullable String inferenceId, Map<String, Model> inferenceIdToModel) {
913913
if (inferenceId != null) {
914914
return inferenceIdToModel.get(inferenceId);
915915
}
916916

917917
return null;
918918
}
919919

920-
private void updateClusterState(List<? extends Model> models, ActionListener<AcknowledgedResponse> listener, TimeValue timeout) {
920+
private void updateClusterState(List<Model> models, ActionListener<AcknowledgedResponse> listener, TimeValue timeout) {
921921
var inferenceIdsSet = models.stream().map(Model::getInferenceEntityId).collect(Collectors.toSet());
922922
var storeListener = listener.delegateResponse((delegate, exc) -> {
923923
logger.warn(format("Failed to add minimal service settings to cluster state for inference endpoints %s", inferenceIdsSet), exc);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ private void sendRequest(ActionListener<Void> listener) {
314314
.addListener(listener);
315315
}
316316

317-
private List<? extends Model> getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) {
317+
private List<Model> getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) {
318318
logger.debug("Received authorization response, {}", authModel);
319319

320320
var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES));
@@ -328,7 +328,7 @@ private List<? extends Model> getNewInferenceEndpointsToStore(ElasticInferenceSe
328328
return scopedAuthModel.getEndpoints(newEndpointIds);
329329
}
330330

331-
private void storePreconfiguredModels(List<? extends Model> newEndpoints, ActionListener<Void> listener) {
331+
private void storePreconfiguredModels(List<Model> newEndpoints, ActionListener<Void> listener) {
332332
if (newEndpoints.isEmpty()) {
333333
listener.onResponse(null);
334334
return;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationModel.java

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.Strings;
1313
import org.elasticsearch.inference.EmptySecretSettings;
1414
import org.elasticsearch.inference.EmptyTaskSettings;
15+
import org.elasticsearch.inference.Model;
1516
import org.elasticsearch.inference.SimilarityMeasure;
1617
import org.elasticsearch.inference.TaskType;
1718
import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder;
@@ -48,6 +49,7 @@ public class ElasticInferenceServiceAuthorizationModel {
4849
private static final String UNKNOWN_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unknown task type [{}], skipping";
4950
private static final String UNSUPPORTED_TASK_TYPE_LOG_MESSAGE = "Authorized endpoint id [{}] has unsupported task type [{}], skipping";
5051

52+
// public because it's used in tests outside the package
5153
public static ElasticInferenceServiceAuthorizationModel of(
5254
ElasticInferenceServiceAuthorizationResponseEntity responseEntity,
5355
String baseEisUrl
@@ -214,12 +216,6 @@ private static void validateFieldPresent(String field, Object fieldValue, TaskTy
214216
}
215217

216218
private static SimilarityMeasure getSimilarityMeasure(ElasticInferenceServiceAuthorizationResponseEntity.Configuration configuration) {
217-
validateFieldPresent(
218-
ElasticInferenceServiceAuthorizationResponseEntity.Configuration.SIMILARITY,
219-
configuration.similarity(),
220-
TaskType.TEXT_EMBEDDING
221-
);
222-
223219
return SimilarityMeasure.fromString(configuration.similarity());
224220
}
225221

@@ -292,8 +288,8 @@ public Set<String> getEndpointIds() {
292288
return Set.copyOf(authorizedEndpoints.keySet());
293289
}
294290

295-
public List<ElasticInferenceServiceModel> getEndpoints(Set<String> endpointIds) {
296-
return endpointIds.stream().map(authorizedEndpoints::get).filter(Objects::nonNull).toList();
291+
public List<Model> getEndpoints(Set<String> endpointIds) {
292+
return endpointIds.stream().<Model>map(authorizedEndpoints::get).filter(Objects::nonNull).toList();
297293
}
298294

299295
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -301,8 +301,8 @@ public String toString() {
301301

302302
private final List<AuthorizedEndpoint> authorizedEndpoints;
303303

304-
public ElasticInferenceServiceAuthorizationResponseEntity(List<AuthorizedEndpoint> authorizedModels) {
305-
this.authorizedEndpoints = Objects.requireNonNull(authorizedModels);
304+
public ElasticInferenceServiceAuthorizationResponseEntity(List<AuthorizedEndpoint> authorizedEndpoints) {
305+
this.authorizedEndpoints = Objects.requireNonNull(authorizedEndpoints);
306306
}
307307

308308
/**

0 commit comments

Comments
 (0)