diff --git a/docs/changelog/136713.yaml b/docs/changelog/136713.yaml new file mode 100644 index 0000000000000..45dedd222f07e --- /dev/null +++ b/docs/changelog/136713.yaml @@ -0,0 +1,5 @@ +pr: 136713 +summary: Transition EIS auth polling to persistent task on a single node +area: Machine Learning +type: enhancement +issues: [] diff --git a/server/src/main/java/org/elasticsearch/inference/Model.java b/server/src/main/java/org/elasticsearch/inference/Model.java index 87744fbd09574..369783463cd86 100644 --- a/server/src/main/java/org/elasticsearch/inference/Model.java +++ b/server/src/main/java/org/elasticsearch/inference/Model.java @@ -9,9 +9,14 @@ package org.elasticsearch.inference; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; + +import java.io.IOException; import java.util.Objects; -public class Model { +public class Model implements Writeable { public static String documentId(String modelId) { return "model_" + modelId; } @@ -42,6 +47,11 @@ public Model(ModelConfigurations configurations) { this(configurations, new ModelSecrets()); } + public Model(StreamInput in) throws IOException { + this.configurations = new ModelConfigurations(in); + this.secrets = new ModelSecrets(in); + } + public String getInferenceEntityId() { return configurations.getInferenceEntityId(); } @@ -111,4 +121,10 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(configurations, secrets); } + + @Override + public void writeTo(StreamOutput out) throws IOException { + configurations.writeTo(out); + secrets.writeTo(out); + } } diff --git a/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv new file mode 100644 index 0000000000000..d8a9923c803c9 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_api_eis_authorization_persistent_task.csv @@ -0,0 +1 @@ +9215000 diff --git a/server/src/main/resources/transport/upper_bounds/9.3.csv b/server/src/main/resources/transport/upper_bounds/9.3.csv index 781de9c6e1a78..6b0edb76f268f 100644 --- a/server/src/main/resources/transport/upper_bounds/9.3.csv +++ b/server/src/main/resources/transport/upper_bounds/9.3.csv @@ -1 +1 @@ -transform_preview_as_index_request,9214000 +inference_api_eis_authorization_persistent_task,9215000 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java new file mode 100644 index 0000000000000..aa613cda60399 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsAction.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +/** + * Internal action to store inference endpoints and return the results of the store operation. This should only be used internally and not + * exposed via a REST API. + * For the exposed REST API action see {@link PutInferenceModelAction}. + */ +public class StoreInferenceEndpointsAction extends ActionType { + + public static final StoreInferenceEndpointsAction INSTANCE = new StoreInferenceEndpointsAction(); + public static final String NAME = "cluster:internal/xpack/inference/create_endpoints"; + + public StoreInferenceEndpointsAction() { + super(NAME); + } + + public static class Request extends AcknowledgedRequest { + private final List models; + + public Request(List models, TimeValue timeout) { + super(timeout, DEFAULT_ACK_TIMEOUT); + this.models = Objects.requireNonNull(models); + } + + public Request(StreamInput in) throws IOException { + super(in); + models = in.readCollectionAsImmutableList(Model::new); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeCollection(models); + } + + public List getModels() { + return models; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(models, request.models); + } + + @Override + public int hashCode() { + return Objects.hashCode(models); + } + } + + public static class Response extends ActionResponse { + private final List results; + + public Response(List results) { + this.results = results; + } + + public Response(StreamInput in) throws IOException { + results = in.readCollectionAsImmutableList(ModelStoreResponse::new); + } + + public List getResults() { + return results; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(results); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Response response = (Response) o; + return Objects.equals(results, response.results); + } + + @Override + public int hashCode() { + return Objects.hash(results); + } + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java new file mode 100644 index 0000000000000..30cbdfdfb96cd --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponse.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; + +import java.io.IOException; +import java.util.Objects; + +/** + * Response for storing a model in the model registry using the bulk API. + */ +public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) implements Writeable { + + public ModelStoreResponse(StreamInput in) throws IOException { + this(in.readString(), RestStatus.readFrom(in), in.readException()); + } + + public boolean failed() { + return failureCause != null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(inferenceId); + RestStatus.writeTo(out, status); + out.writeException(failureCause); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + ModelStoreResponse that = (ModelStoreResponse) o; + return status == that.status && Objects.equals(inferenceId, that.inferenceId) + // Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal + && Objects.equals( + failureCause == null ? null : failureCause.getMessage(), + that.failureCause == null ? null : that.failureCause.getMessage() + ) + && Objects.equals( + failureCause == null ? null : failureCause.getClass(), + that.failureCause == null ? null : that.failureCause.getClass() + ); + } + + @Override + public int hashCode() { + return Objects.hash( + inferenceId, + status, + // Exception does not have hashCode() or equals() so assume errors are equal iff class and message are equal + failureCause == null ? null : failureCause.getMessage(), + failureCause == null ? null : failureCause.getClass() + ); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java new file mode 100644 index 0000000000000..05a9227ae278e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/ModelTests.java @@ -0,0 +1,241 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +public class ModelTests extends AbstractBWCWireSerializationTestCase { + public static Model randomModel() { + return new Model( + new ModelConfigurations( + randomAlphaOfLength(6), + randomFrom(TaskType.values()), + randomAlphaOfLength(6), + new TestServiceSettings( + randomAlphaOfLength(10), + randomIntBetween(1, 1024), + randomFrom(SimilarityMeasure.values()), + randomFrom(DenseVectorFieldMapper.ElementType.values()) + ), + EmptyTaskSettings.INSTANCE, + randomBoolean() ? ChunkingSettingsTests.createRandomChunkingSettings() : null + ), + new ModelSecrets(EmptySecretSettings.INSTANCE) + ); + } + + public record TestServiceSettings( + String model, + Integer dimensions, + @Nullable SimilarityMeasure similarity, + @Nullable DenseVectorFieldMapper.ElementType elementType + ) implements ServiceSettings { + + static final String NAME = "test_text_embedding_service_settings"; + + public TestServiceSettings(StreamInput in) throws IOException { + this( + in.readString(), + in.readInt(), + in.readOptionalEnum(SimilarityMeasure.class), + in.readOptionalEnum(DenseVectorFieldMapper.ElementType.class) + ); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("model", model); + builder.field("dimensions", dimensions); + if (similarity != null) { + builder.field("similarity", similarity); + } + if (elementType != null) { + builder.field("element_type", elementType); + } + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + out.writeInt(dimensions); + out.writeOptionalEnum(similarity); + out.writeOptionalEnum(elementType); + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return this; + } + + @Override + public SimilarityMeasure similarity() { + return similarity != null ? similarity : SimilarityMeasure.COSINE; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return elementType != null ? elementType : DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public String modelId() { + return model; + } + } + + public record SimpleSecretSettings(String field) implements SecretSettings { + public static final String NAME = "simple_secret_settings"; + private static final String FIELD_KEY = "field"; + + public SimpleSecretSettings { + Objects.requireNonNull(field); + } + + public SimpleSecretSettings(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersion.current(); // fine for these tests but will not work for cluster upgrade tests + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(field); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_KEY, field); + builder.endObject(); + return builder; + } + + @Override + public SecretSettings newSecretSettings(Map newSecrets) { + if (newSecrets == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + var value = newSecrets.get(FIELD_KEY); + if (value == null) { + validationException.addValidationError("Missing required secret setting: " + FIELD_KEY); + throw validationException; + } else if (value instanceof String == false) { + validationException.addValidationError("Expected secret setting [" + FIELD_KEY + "] to be of type String"); + throw validationException; + } + return new SimpleSecretSettings((String) value); + } + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(ServiceSettings.class, TestServiceSettings.NAME, TestServiceSettings::new), + new NamedWriteableRegistry.Entry(SecretSettings.class, SimpleSecretSettings.NAME, SimpleSecretSettings::new) + ); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + var namedWriteables = new ArrayList(); + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new)); + namedWriteables.addAll(getNamedWriteables()); + namedWriteables.addAll(XPackClientPlugin.getChunkingSettingsNamedWriteables()); + + return new NamedWriteableRegistry(namedWriteables); + } + + @Override + protected Model mutateInstanceForVersion(Model instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return Model::new; + } + + @Override + protected Model createTestInstance() { + return randomModel(); + } + + @Override + protected Model mutateInstance(Model instance) throws IOException { + int choice = randomIntBetween(0, 1); + switch (choice) { + case 0 -> { + var originalConfig = instance.getConfigurations(); + ModelConfigurations mutatedConfig = new ModelConfigurations( + originalConfig.getInferenceEntityId() + "_mutated", + originalConfig.getTaskType(), + originalConfig.getService(), + originalConfig.getServiceSettings(), + originalConfig.getTaskSettings(), + originalConfig.getChunkingSettings() + ); + return new Model(mutatedConfig, instance.getSecrets()); + } + case 1 -> { + return new Model(instance.getConfigurations(), new ModelSecrets(new SimpleSecretSettings(randomAlphaOfLength(10)))); + } + default -> throw new IllegalStateException("Unexpected value: " + choice); + } + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java new file mode 100644 index 0000000000000..3673296c29ce7 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionRequestTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.core.XPackClientPlugin; +import org.elasticsearch.xpack.core.inference.ModelTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.ArrayList; + +public class StoreInferenceEndpointsActionRequestTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected StoreInferenceEndpointsAction.Request mutateInstanceForVersion( + StoreInferenceEndpointsAction.Request instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return StoreInferenceEndpointsAction.Request::new; + } + + @Override + protected StoreInferenceEndpointsAction.Request createTestInstance() { + return new StoreInferenceEndpointsAction.Request(randomList(5, ModelTests::randomModel), randomTimeValue()); + } + + @Override + protected StoreInferenceEndpointsAction.Request mutateInstance(StoreInferenceEndpointsAction.Request instance) throws IOException { + var newModels = new ArrayList<>(instance.getModels()); + newModels.add(ModelTests.randomModel()); + return new StoreInferenceEndpointsAction.Request(newModels, instance.masterNodeTimeout()); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + var namedWriteables = new ArrayList(); + namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, EmptyTaskSettings.NAME, EmptyTaskSettings::new)); + namedWriteables.add(new NamedWriteableRegistry.Entry(SecretSettings.class, EmptySecretSettings.NAME, EmptySecretSettings::new)); + namedWriteables.addAll(ModelTests.getNamedWriteables()); + namedWriteables.addAll(XPackClientPlugin.getChunkingSettingsNamedWriteables()); + + return new NamedWriteableRegistry(namedWriteables); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java new file mode 100644 index 0000000000000..c7a692f0c5e9e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/StoreInferenceEndpointsActionResponseTests.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponseTests; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.ArrayList; + +public class StoreInferenceEndpointsActionResponseTests extends AbstractBWCWireSerializationTestCase< + StoreInferenceEndpointsAction.Response> { + + @Override + protected StoreInferenceEndpointsAction.Response mutateInstanceForVersion( + StoreInferenceEndpointsAction.Response instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return StoreInferenceEndpointsAction.Response::new; + } + + @Override + protected StoreInferenceEndpointsAction.Response createTestInstance() { + return new StoreInferenceEndpointsAction.Response(randomList(5, ModelStoreResponseTests::randomModelStoreResponse)); + } + + @Override + protected StoreInferenceEndpointsAction.Response mutateInstance(StoreInferenceEndpointsAction.Response instance) throws IOException { + var newResults = new ArrayList<>(instance.getResults()); + newResults.add(ModelStoreResponseTests.randomModelStoreResponse()); + return new StoreInferenceEndpointsAction.Response(newResults); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java new file mode 100644 index 0000000000000..e7160fc360bb3 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/ModelStoreResponseTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class ModelStoreResponseTests extends AbstractBWCWireSerializationTestCase { + + public static ModelStoreResponse randomModelStoreResponse() { + return new ModelStoreResponse( + randomAlphaOfLength(10), + randomFrom(RestStatus.values()), + randomBoolean() ? null : new IllegalStateException("Test exception") + ); + } + + public void testFailed() { + { + var successResponse = new ModelStoreResponse("model_1", RestStatus.OK, null); + assertFalse(successResponse.failed()); + } + { + var failedResponse = new ModelStoreResponse( + "model_2", + RestStatus.INTERNAL_SERVER_ERROR, + new IllegalStateException("Test failure") + ); + assertTrue(failedResponse.failed()); + } + { + var failedResponse = new ModelStoreResponse("model_2", RestStatus.OK, new IllegalStateException("Test failure")); + assertTrue(failedResponse.failed()); + } + } + + @Override + protected ModelStoreResponse mutateInstanceForVersion(ModelStoreResponse instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return ModelStoreResponse::new; + } + + @Override + protected ModelStoreResponse createTestInstance() { + return randomModelStoreResponse(); + } + + @Override + protected ModelStoreResponse mutateInstance(ModelStoreResponse instance) throws IOException { + int choice = randomIntBetween(0, 2); + return switch (choice) { + case 0 -> { + String newInferenceId = instance.inferenceId() + "_mutated"; + yield new ModelStoreResponse(newInferenceId, instance.status(), instance.failureCause()); + } + case 1 -> new ModelStoreResponse( + instance.inferenceId(), + randomValueOtherThan(instance.status(), () -> randomFrom(RestStatus.values())), + instance.failureCause() + ); + case 2 -> { + Exception newFailureCause = instance.failureCause() == null ? new IllegalStateException("Mutated exception") : null; + yield new ModelStoreResponse(instance.inferenceId(), instance.status(), newFailureCause); + } + default -> throw new IllegalStateException("Unexpected value: " + choice); + }; + } +} diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java index 63d326016a5ec..bc403a78bc2a8 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/BaseMockEISAuthServerTest.java @@ -17,11 +17,15 @@ import org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.junit.Before; import org.junit.ClassRule; import org.junit.Rule; import org.junit.rules.RuleChain; import org.junit.rules.TestRule; +import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.getModel; + public class BaseMockEISAuthServerTest extends ESRestTestCase { protected static final MockElasticInferenceServiceAuthorizationServer mockEISServer = @@ -71,4 +75,20 @@ protected Settings restClientSettings() { String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray())); return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build(); } + + @Override + protected boolean preserveClusterUponCompletion() { + // Keep the cluster around so the EIS preconfigured endpoints still exist between tests. Otherwise, the inference indices will + // be removed when the cluster is wiped which causes the tests after the first one to fail. + return true; + } + + @Before + public void ensureEisPreconfiguredEndpointsExist() throws Exception { + // Ensure that the authorization logic has completed prior to running each test so we have the correct EIS preconfigured endpoints + // available + // Technically this only needs to be done before the suite runs but the underlying client is created in @Before and not statically + // for the suite + assertBusy(() -> getModel(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2)); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 2c833186df0f0..c7373e751eb7f 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -347,7 +347,7 @@ protected Map deployE5TrainedModels() throws IOException { } @SuppressWarnings("unchecked") - protected Map getModel(String modelId) throws IOException { + static Map getModel(String modelId) throws IOException { var endpoint = Strings.format("_inference/%s?error_trace", modelId); return ((List>) getInternalAsMap(endpoint).get("endpoints")).get(0); } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java new file mode 100644 index 0000000000000..8450ceab04848 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorIT.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequestBuilder; +import org.elasticsearch.action.admin.cluster.node.tasks.list.ListTasksRequestBuilder; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.AdminClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.reindex.ReindexPlugin; +import org.elasticsearch.tasks.TaskInfo; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; + +public class AuthorizationTaskExecutorIT extends ESSingleNodeTestCase { + public static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; + + public static final String EMPTY_AUTH_RESPONSE = """ + { + "models": [ + ] + } + """; + + public static final String AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE = """ + { + "models": [ + { + "model_name": "rainbow-sprinkles", + "task_types": ["chat"] + } + ] + } + """; + + private static final MockWebServer webServer = new MockWebServer(); + private static String gatewayUrl; + + private ModelRegistry modelRegistry; + private AuthorizationTaskExecutor authorizationTaskExecutor; + + @BeforeClass + public static void initClass() throws IOException { + webServer.start(); + gatewayUrl = getUrl(webServer); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + } + + @Before + public void createComponents() { + modelRegistry = node().injector().getInstance(ModelRegistry.class); + authorizationTaskExecutor = node().injector().getInstance(AuthorizationTaskExecutor.class); + } + + @After + public void shutdown() { + // Delete all the eis preconfigured endpoints + var listener = new PlainActionFuture(); + modelRegistry.deleteModels(InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + } + + @AfterClass + public static void cleanUpClass() { + webServer.close(); + } + + @Override + protected Settings nodeSettings() { + return Settings.builder() + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + // Ensure that the polling logic only occurs once so we can deterministically control when an authorization response is + // received + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) + .build(); + } + + @Override + protected Collection> getPlugins() { + return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); + } + + public void testCreatesEisChatCompletionEndpoint() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + restartPollingTaskAndWaitForAuthResponse(); + + assertChatCompletionEndpointExists(); + } + + private void assertNoAuthorizedEisEndpoints() throws Exception { + waitForTask(AUTH_TASK_ACTION, admin()); + + assertBusy(() -> { + var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); + assertNotNull(newPoller); + newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); + }); + + var eisEndpoints = getEisEndpoints(); + assertThat(eisEndpoints, empty()); + + for (String eisPreconfiguredEndpoints : InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS) { + assertFalse(modelRegistry.containsPreconfiguredInferenceEndpointId(eisPreconfiguredEndpoints)); + } + } + + public static TaskInfo waitForTask(String taskAction, AdminClient adminClient) throws Exception { + var taskRef = new AtomicReference(); + var builder = new ListTasksRequestBuilder(adminClient.cluster()); + + assertBusy(() -> { + var response = builder.get(); + var authPollerTask = response.getTasks().stream().filter(task -> task.action().equals(taskAction)).findFirst(); + assertTrue(authPollerTask.isPresent()); + taskRef.set(authPollerTask.get()); + }); + + return taskRef.get(); + } + + private List getEisEndpoints() { + var listener = new PlainActionFuture>(); + modelRegistry.getAllModels(false, listener); + + var endpoints = listener.actionGet(TimeValue.THIRTY_SECONDS); + return endpoints.stream().filter(m -> m.service().equals(ElasticInferenceService.NAME)).toList(); + } + + private void restartPollingTaskAndWaitForAuthResponse() throws Exception { + cancelAuthorizationTask(admin()); + + // wait for the new task to be recreated and an authorization response to be processed + assertBusy(() -> { + var newPoller = authorizationTaskExecutor.getCurrentPollerTask(); + assertNotNull(newPoller); + newPoller.waitForAuthorizationToComplete(TimeValue.THIRTY_SECONDS); + }); + } + + public static void cancelAuthorizationTask(AdminClient adminClient) throws Exception { + var pollerTask = waitForTask(AUTH_TASK_ACTION, adminClient); + var builder = new CancelTasksRequestBuilder(adminClient.cluster()); + + assertBusy(() -> { + var cancelTaskResponse = builder.setActions(AUTH_TASK_ACTION).get(); + assertThat(cancelTaskResponse.getTasks().size(), is(1)); + assertThat(cancelTaskResponse.getTasks().get(0).action(), is(AUTH_TASK_ACTION)); + }); + + var newPollerTask = waitForTask(AUTH_TASK_ACTION, adminClient); + assertThat(newPollerTask.taskId(), is(not(pollerTask.taskId()))); + } + + public void testCreatesEisChatCompletion_DoesNotRemoveEndpointWhenNoLongerAuthorized() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + restartPollingTaskAndWaitForAuthResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that the model is no longer authorized + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + restartPollingTaskAndWaitForAuthResponse(); + + assertChatCompletionEndpointExists(); + } + + private void assertChatCompletionEndpointExists() { + var eisEndpoints = getEisEndpoints(); + assertThat(eisEndpoints.size(), is(1)); + + var rainbowSprinklesModel = eisEndpoints.get(0); + assertChatCompletionUnparsedModel(rainbowSprinklesModel); + assertTrue( + modelRegistry.containsPreconfiguredInferenceEndpointId(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) + ); + } + + private void assertChatCompletionUnparsedModel(UnparsedModel rainbowSprinklesModel) { + assertThat(rainbowSprinklesModel.taskType(), is(TaskType.CHAT_COMPLETION)); + assertThat(rainbowSprinklesModel.service(), is(ElasticInferenceService.NAME)); + assertThat(rainbowSprinklesModel.inferenceEntityId(), is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + } + + public void testCreatesChatCompletion_AndThenCreatesTextEmbedding() throws Exception { + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + restartPollingTaskAndWaitForAuthResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that the model is no longer authorized + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + restartPollingTaskAndWaitForAuthResponse(); + + assertChatCompletionEndpointExists(); + + // Simulate that a text embedding model is now authorized + var authorizedTextEmbeddingResponse = """ + { + "models": [ + { + "model_name": "jina-embeddings-v3", + "task_types": ["embed/text/dense"] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(authorizedTextEmbeddingResponse)); + restartPollingTaskAndWaitForAuthResponse(); + + var eisEndpoints = getEisEndpoints().stream().collect(Collectors.toMap(UnparsedModel::inferenceEntityId, Function.identity())); + assertThat(eisEndpoints.size(), is(2)); + + assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + assertChatCompletionUnparsedModel(eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1)); + + assertTrue(eisEndpoints.containsKey(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID)); + + var textEmbeddingEndpoint = eisEndpoints.get(InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID); + assertThat(textEmbeddingEndpoint.taskType(), is(TaskType.TEXT_EMBEDDING)); + assertThat(textEmbeddingEndpoint.service(), is(ElasticInferenceService.NAME)); + } + + public void testRestartsTaskAfterAbort() throws Exception { + // Ensure the task is created and we get an initial authorization response + assertNoAuthorizedEisEndpoints(); + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + // Abort the task and ensure it is restarted + restartPollingTaskAndWaitForAuthResponse(); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java new file mode 100644 index 0000000000000..cb92c70d27442 --- /dev/null +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/AuthorizationTaskExecutorMultipleNodesIT.java @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.integration; + +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.license.LicenseSettings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.EMPTY_AUTH_RESPONSE; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.cancelAuthorizationTask; +import static org.elasticsearch.xpack.inference.integration.AuthorizationTaskExecutorIT.waitForTask; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.not; + +/** + * These tests handle testing task relocation and cancellation. + * If the task is running on a node that is shutdown, it should be relocated to another node. + * If the task is cancelled it should be restarted automatically. + */ +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0) +public class AuthorizationTaskExecutorMultipleNodesIT extends ESIntegTestCase { + + private static final int NUM_DATA_NODES = 2; + private static final int NUM_MASTER_NODES = 2; + private static final String AUTH_TASK_ACTION = AuthorizationPoller.TASK_NAME + "[c]"; + private static final MockWebServer webServer = new MockWebServer(); + private static String gatewayUrl; + + @BeforeClass + public static void initClass() throws IOException { + webServer.start(); + gatewayUrl = getUrl(webServer); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(EMPTY_AUTH_RESPONSE)); + } + + @Before + public void startNodes() { + // Ensure we have multiple master and data nodes so we have somewhere to place the inference indices and so that we can safely + // shut down the node that is running the authorization task. If there is only one master and it is running the task, + // we'll get an error that we can't shut down the only eligible master node + internalCluster().startMasterOnlyNodes(NUM_MASTER_NODES); + internalCluster().ensureAtLeastNumDataNodes(NUM_DATA_NODES); + ensureStableCluster(NUM_MASTER_NODES + NUM_DATA_NODES); + } + + @AfterClass + public static void cleanUpClass() { + webServer.close(); + } + + @Override + protected Collection> nodePlugins() { + return List.of(LocalStateInferencePlugin.class); + } + + @Override + protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal, otherSettings)) + .put(LicenseSettings.SELF_GENERATED_LICENSE_TYPE.getKey(), "trial") + .put(ElasticInferenceServiceSettings.ELASTIC_INFERENCE_SERVICE_URL.getKey(), gatewayUrl) + .put(ElasticInferenceServiceSettings.PERIODIC_AUTHORIZATION_ENABLED.getKey(), false) + .build(); + } + + public void testCancellingAuthorizationTaskRestartsIt() throws Exception { + cancelAuthorizationTask(admin()); + } + + public void testAuthorizationTaskGetsRelocatedToAnotherNode_WhenTheNodeThatIsRunningItShutsDown() throws Exception { + var nodeNameMapping = getNodeNames(internalCluster().getNodeNames()); + + var pollerTask = waitForTask(AUTH_TASK_ACTION, admin()); + + var endpoints = getAllEndpoints(); + assertTrue( + "expected no authorized EIS endpoints", + endpoints.getEndpoints().stream().noneMatch(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) + ); + + // queue a response that authorizes one model + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(AUTHORIZED_RAINBOW_SPRINKLES_RESPONSE)); + + assertTrue("expected the node to shutdown properly", internalCluster().stopNode(nodeNameMapping.get(pollerTask.node()))); + + assertBusy(() -> { + var relocatedPollerTask = waitForTask(AUTH_TASK_ACTION, admin()); + assertThat(relocatedPollerTask.node(), not(is(pollerTask.node()))); + }); + + assertBusy(() -> { + var allEndpoints = getAllEndpoints(); + + var eisEndpoints = allEndpoints.getEndpoints() + .stream() + .filter(endpoint -> endpoint.getService().equals(ElasticInferenceService.NAME)) + .toList(); + assertThat(eisEndpoints.size(), is(1)); + + var rainbowSprinklesEndpoint = eisEndpoints.get(0); + assertThat(rainbowSprinklesEndpoint.getService(), is(ElasticInferenceService.NAME)); + assertThat( + rainbowSprinklesEndpoint.getInferenceEntityId(), + is(InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1) + ); + assertThat(rainbowSprinklesEndpoint.getTaskType(), is(TaskType.CHAT_COMPLETION)); + }); + } + + private record NodeNameMapping(Map nodeNamesMap) { + public String get(String rawNodeName) { + var nodeName = nodeNamesMap.get(rawNodeName); + if (nodeName == null) { + throw new IllegalArgumentException("No node name found for raw node name: " + rawNodeName); + } + + return nodeName; + } + } + + /** + * The node names created by the integration test framework take the form of "node_#", but the task api gives a raw node name + * like 02PT2SBzRxC3cG-9mKCigQ, so we need to map between them to be able to act on a node that the task is currently running on. + */ + private static NodeNameMapping getNodeNames(String[] nodes) { + var nodeNamesMap = new HashMap(); + for (var node : nodes) { + var nodeTasks = admin().cluster().prepareListTasks(node).get(); + assertThat(nodeTasks.getTasks().size(), greaterThanOrEqualTo(1)); + nodeNamesMap.put(nodeTasks.getTasks().getFirst().node(), node); + } + + return new NodeNameMapping(nodeNamesMap); + } + + private GetInferenceModelAction.Response getAllEndpoints() throws Exception { + var getAllEndpointsRequest = new GetInferenceModelAction.Request("*", TaskType.ANY, true); + + var allEndpointsRef = new AtomicReference(); + assertBusy(() -> { + try { + allEndpointsRef.set( + internalCluster().masterClient().execute(GetInferenceModelAction.INSTANCE, getAllEndpointsRequest).actionGet() + ); + } catch (Exception e) { + // We probably got an all shards failed exception because the indices aren't ready yet, we'll just try again + logger.warn("Failed to retrieve endpoints", e); + fail("Failed to retrieve endpoints"); + } + }); + + return allEndpointsRef.get(); + } +} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java deleted file mode 100644 index 8c4a2b1b2504c..0000000000000 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ /dev/null @@ -1,357 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.integration; - -import org.elasticsearch.ResourceNotFoundException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.reindex.ReindexPlugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; -import org.elasticsearch.test.http.MockWebServer; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; -import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; -import org.junit.After; -import org.junit.Before; - -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.concurrent.TimeUnit; - -import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; -import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.mockito.Mockito.mock; - -@ESTestCase.WithoutEntitlements // due to dependency issue ES-12435 -public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase { - private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); - - private ModelRegistry modelRegistry; - private final MockWebServer webServer = new MockWebServer(); - private ThreadPool threadPool; - private String gatewayUrl; - - @Before - public void createComponents() throws Exception { - threadPool = createThreadPool(inferenceUtilityExecutors()); - webServer.start(); - gatewayUrl = getUrl(webServer); - modelRegistry = node().injector().getInstance(ModelRegistry.class); - } - - @After - public void shutdown() { - terminate(threadPool); - webServer.close(); - } - - @Override - protected boolean resetNodeAfterTest() { - return true; - } - - @Override - protected Collection> getPlugins() { - return pluginList(ReindexPlugin.class, LocalStateInferencePlugin.class); - } - - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationReturnsEmpty() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnAuthForIt() throws Exception { - { - String responseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elastic-rerank-v1")); - assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".jina-embeddings-v3")); - assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - - var getModelListener = new PlainActionFuture(); - // persists the default endpoints - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - - var inferenceEntity = getModelListener.actionGet(TIMEOUT); - assertThat(inferenceEntity.inferenceEntityId(), is(".rainbow-sprinkles-elastic")); - assertThat(inferenceEntity.taskType(), is(TaskType.CHAT_COMPLETION)); - } - } - { - String noAuthorizationResponseJson = """ - { - "models": [ - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(noAuthorizationResponseJson)); - - try (var service = createElasticInferenceService()) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - containsInAnyOrder( - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) - ); - - var getModelListener = new PlainActionFuture(); - modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); - var exception = expectThrows(ResourceNotFoundException.class, () -> getModelListener.actionGet(TIMEOUT)); - assertThat(exception.getMessage(), is("Inference endpoint not found [.rainbow-sprinkles-elastic]")); - } - } - } - - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - - private ElasticInferenceService createElasticInferenceService() { - var httpManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, httpManager); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(gatewayUrl), - modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(gatewayUrl, threadPool), - mockClusterServiceEmpty() - ); - } -} diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java index 4f53559dbba02..708edc4279148 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java @@ -49,6 +49,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsTests; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; @@ -601,7 +602,7 @@ public void testGetUnparsedModelMap_ThrowsResourceNotFound_WhenNoHitsReturned() } public void testStoreModels_ReturnsEmptyList_WhenGivenNoModelsToStore() { - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); @@ -621,12 +622,12 @@ public void testStoreModels_StoresSingleInferenceEndpoint() { new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(1)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertMinimalServiceSettings(modelRegistry, model); @@ -660,13 +661,13 @@ public void testStoreModels_StoresMultipleInferenceEndpoints() { new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(2)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); - assertThat(response.get(1), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId1, RestStatus.CREATED, null))); + assertThat(response.get(1), Matchers.is(new ModelStoreResponse(inferenceId2, RestStatus.CREATED, null))); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model2, secrets); @@ -717,12 +718,12 @@ public void testStoreModels_StoresOneModel_FailsToStoreSecond_WhenVersionConflic new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(2)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertThat(response.get(1).inferenceId(), Matchers.is(model2.getInferenceEntityId())); assertThat(response.get(1).status(), Matchers.is(RestStatus.CONFLICT)); assertTrue(response.get(1).failed()); @@ -759,12 +760,12 @@ public void testStoreModels_StoresOneModel_RemovesSecondDuplicateModelFromList_D new TestModel.TestSecretSettings(secrets) ); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model1, model2), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); assertThat(response.size(), Matchers.is(1)); - assertThat(response.get(0), Matchers.is(new ModelRegistry.ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); + assertThat(response.get(0), Matchers.is(new ModelStoreResponse(inferenceId, RestStatus.CREATED, null))); assertModelAndMinimalSettingsWithSecrets(modelRegistry, model1, secrets); assertIndicesContainExpectedDocsCount(model1, 2); @@ -784,7 +785,7 @@ public void testStoreModels_FailsToStoreModel_WhenInferenceIndexDocumentAlreadyE storeCorruptedModel(model, false); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); @@ -838,7 +839,7 @@ public void testStoreModels_OnFailure_RemovesPartialWritesOfInferenceEndpoint() storeCorruptedModel(model1, false); storeCorruptedModel(model2, true); - PlainActionFuture> storeListener = new PlainActionFuture<>(); + PlainActionFuture> storeListener = new PlainActionFuture<>(); modelRegistry.storeModels(List.of(model1, model2, model3), storeListener, TimeValue.THIRTY_SECONDS); var response = storeListener.actionGet(TimeValue.THIRTY_SECONDS); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 151f42fbfb568..9e1f17643ac73 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -9,16 +9,19 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.support.MappedActionFilter; +import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.NamedDiff; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.IndexScopedSettings; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.SettingsFilter; +import org.elasticsearch.common.settings.SettingsModule; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.TimeValue; @@ -34,10 +37,12 @@ import org.elasticsearch.license.LicensedFeature; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.node.PluginComponentBinding; +import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.ClusterPlugin; import org.elasticsearch.plugins.ExtensiblePlugin; import org.elasticsearch.plugins.MapperPlugin; +import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.SearchPlugin; import org.elasticsearch.plugins.SystemIndexPlugin; @@ -68,6 +73,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceActionProxy; import org.elasticsearch.xpack.core.inference.action.PutCCMConfigurationAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import org.elasticsearch.xpack.core.ssl.SSLService; @@ -83,6 +89,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutCCMConfigurationAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportStoreEndpointsAction; import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; @@ -139,6 +146,8 @@ import org.elasticsearch.xpack.inference.services.deepseek.DeepSeekService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller; +import org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationTaskExecutor; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMFeature; import org.elasticsearch.xpack.inference.services.elastic.ccm.CCMIndex; @@ -171,6 +180,7 @@ import java.util.Set; import java.util.function.Predicate; import java.util.function.Supplier; +import java.util.stream.Stream; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; @@ -183,7 +193,8 @@ public class InferencePlugin extends Plugin MapperPlugin, SearchPlugin, InternalSearchPlugin, - ClusterPlugin { + ClusterPlugin, + PersistentTaskPlugin { /** * When this setting is true the verification check that @@ -230,7 +241,7 @@ public class InferencePlugin extends Plugin private final Settings settings; private final SetOnce httpFactory = new SetOnce<>(); private final SetOnce amazonBedrockFactory = new SetOnce<>(); - private final SetOnce elasicInferenceServiceFactory = new SetOnce<>(); + private final SetOnce elasticInferenceServiceFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); // This is mainly so that the rest handlers can access the ThreadPool in a way that avoids potential null pointers from it // not being initialized yet @@ -240,6 +251,7 @@ public class InferencePlugin extends Plugin private final SetOnce modelRegistry = new SetOnce<>(); private final SetOnce ccmFeature = new SetOnce<>(); private List inferenceServiceExtensions; + private final SetOnce authorizationTaskExecutorRef = new SetOnce<>(); public InferencePlugin(Settings settings) { this.settings = settings; @@ -260,6 +272,7 @@ public List getActions() { new ActionHandler(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class), new ActionHandler(GetRerankerWindowSizeAction.INSTANCE, TransportGetRerankerWindowSizeAction.class), new ActionHandler(ClearInferenceEndpointCacheAction.INSTANCE, ClearInferenceEndpointCacheAction.class), + new ActionHandler(StoreInferenceEndpointsAction.INSTANCE, TransportStoreEndpointsAction.class), new ActionHandler(GetCCMConfigurationAction.INSTANCE, TransportGetCCMConfigurationAction.class), new ActionHandler(PutCCMConfigurationAction.INSTANCE, TransportPutCCMConfigurationAction.class), new ActionHandler(DeleteCCMConfigurationAction.INSTANCE, TransportDeleteCCMConfigurationAction.class) @@ -337,23 +350,34 @@ public Collection createComponents(PluginServices services) { elasticInferenceServiceHttpClientManager, services.clusterService() ); - elasicInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); + elasticInferenceServiceFactory.set(elasticInferenceServiceRequestSenderFactory); var authorizationHandler = new ElasticInferenceServiceAuthorizationRequestHandler( inferenceServiceSettings.getElasticInferenceServiceUrl(), services.threadPool() ); + var authTaskExecutor = AuthorizationTaskExecutor.create( + services.clusterService(), + new AuthorizationPoller.Parameters( + serviceComponents.get(), + authorizationHandler, + elasticInferenceServiceFactory.get().createSender(), + inferenceServiceSettings, + modelRegistry.get(), + services.client() + ) + ); + authorizationTaskExecutorRef.set(authTaskExecutor); + var sageMakerSchemas = new SageMakerSchemas(); var sageMakerConfigurations = new LazyInitializable<>(new SageMakerConfiguration(sageMakerSchemas)); inferenceServices.add( () -> List.of( context -> new ElasticInferenceService( - elasicInferenceServiceFactory.get(), + elasticInferenceServiceFactory.get(), serviceComponents.get(), inferenceServiceSettings, - modelRegistry.get(), - authorizationHandler, context ), context -> new SageMakerService( @@ -407,7 +431,7 @@ public Collection createComponents(PluginServices services) { ); components.add(inferenceStatsBinding); components.add(authorizationHandler); - components.add(new PluginComponentBinding<>(Sender.class, elasicInferenceServiceFactory.get().createSender())); + components.add(new PluginComponentBinding<>(Sender.class, elasticInferenceServiceFactory.get().createSender())); components.add( new InferenceEndpointRegistry( services.clusterService(), @@ -418,6 +442,8 @@ public Collection createComponents(PluginServices services) { services.featureService() ) ); + + components.add(authTaskExecutor); components.addAll(createCCMComponents(services)); return components; @@ -429,6 +455,17 @@ private Collection createCCMComponents(PluginServices services) { return List.of(new CCMService(ccmPersistentStorageService), ccmFeature.get(), ccmPersistentStorageService); } + @Override + public List> getPersistentTasksExecutor( + ClusterService clusterService, + ThreadPool threadPool, + Client client, + SettingsModule settingsModule, + IndexNameExpressionResolver expressionResolver + ) { + return List.of(authorizationTaskExecutorRef.get()); + } + @Override public void loadExtensions(ExtensionLoader loader) { inferenceServiceExtensions = loader.loadExtensions(InferenceServiceExtension.class); @@ -462,54 +499,52 @@ public List getInferenceServiceFactories() { @Override public List getNamedWriteables() { - var entries = new ArrayList<>(InferenceNamedWriteablesProvider.getNamedWriteables()); - entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new)); - entries.add(new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new)); - entries.add(new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new)); - entries.add(new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new)); - entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom)); - entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceMatchQueryBuilder.NAME, - InterceptedInferenceMatchQueryBuilder::new - ) - ); - entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceKnnVectorQueryBuilder.NAME, - InterceptedInferenceKnnVectorQueryBuilder::new - ) - ); - entries.add( - new NamedWriteableRegistry.Entry( - QueryBuilder.class, - InterceptedInferenceSparseVectorQueryBuilder.NAME, - InterceptedInferenceSparseVectorQueryBuilder::new - ) - ); - return entries; + return Stream.of( + List.of( + new NamedWriteableRegistry.Entry(RankBuilder.class, TextSimilarityRankBuilder.NAME, TextSimilarityRankBuilder::new), + new NamedWriteableRegistry.Entry(RankBuilder.class, RandomRankBuilder.NAME, RandomRankBuilder::new), + new NamedWriteableRegistry.Entry(RankDoc.class, TextSimilarityRankDoc.NAME, TextSimilarityRankDoc::new), + new NamedWriteableRegistry.Entry(Metadata.ProjectCustom.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::new), + new NamedWriteableRegistry.Entry(NamedDiff.class, ModelRegistryMetadata.TYPE, ModelRegistryMetadata::readDiffFrom), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceMatchQueryBuilder.NAME, + InterceptedInferenceMatchQueryBuilder::new + ), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceKnnVectorQueryBuilder.NAME, + InterceptedInferenceKnnVectorQueryBuilder::new + ), + new NamedWriteableRegistry.Entry( + QueryBuilder.class, + InterceptedInferenceSparseVectorQueryBuilder.NAME, + InterceptedInferenceSparseVectorQueryBuilder::new + ) + ), + InferenceNamedWriteablesProvider.getNamedWriteables(), + AuthorizationTaskExecutor.getNamedWriteables() + ).flatMap(List::stream).toList(); + } @Override public List getNamedXContent() { - List namedXContent = new ArrayList<>(); - namedXContent.add( - new NamedXContentRegistry.Entry( - Metadata.ProjectCustom.class, - new ParseField(ModelRegistryMetadata.TYPE), - ModelRegistryMetadata::fromXContent - ) - ); - namedXContent.add( - new NamedXContentRegistry.Entry( - Metadata.ProjectCustom.class, - new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), - ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent - ) - ); - return namedXContent; + return Stream.of( + List.of( + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ModelRegistryMetadata.TYPE), + ModelRegistryMetadata::fromXContent + ), + new NamedXContentRegistry.Entry( + Metadata.ProjectCustom.class, + new ParseField(ClearInferenceEndpointCacheAction.InvalidateCacheMetadata.NAME), + ClearInferenceEndpointCacheAction.InvalidateCacheMetadata::fromXContent + ) + ), + AuthorizationTaskExecutor.getNamedXContentParsers() + ).flatMap(List::stream).toList(); } @Override @@ -644,7 +679,7 @@ public Map getMetadataMappers() { // Overridable for tests protected Supplier getModelRegistry() { - return () -> modelRegistry.get(); + return modelRegistry::get; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 512ad5a445b18..23ef7481166d9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -221,7 +221,7 @@ private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterStat } private boolean isInferenceIdReserved(String inferenceEndpointId) { - return modelRegistry.containsDefaultConfigId(inferenceEndpointId); + return modelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId); } private static String buildErrorString(String inferenceEndpointId, Set pipelines, Set indexes) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java index 609a1e4df62d8..4415da2c1b99a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java @@ -265,7 +265,7 @@ private Map createStatsKeysWithEndpointCountsForDefa // may only happen for external services. Set modelIds = endpoints.stream() .filter(endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(endpoint.getTaskType())) - .filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId())) + .filter(endpoint -> modelRegistry.containsPreconfiguredInferenceEndpointId(endpoint.getInferenceEntityId())) .filter(endpoint -> endpoint.getServiceSettings().modelId() != null) .map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId())) .collect(Collectors.toSet()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index b472beebb66c5..ded66cb8c44ad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -13,7 +13,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.TransportMasterNodeAction; -import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; @@ -85,8 +84,7 @@ public TransportPutInferenceModelAction( ModelRegistry modelRegistry, InferenceServiceRegistry serviceRegistry, Settings settings, - ProjectResolver projectResolver, - Client client + ProjectResolver projectResolver ) { super( PutInferenceModelAction.NAME, @@ -114,7 +112,7 @@ protected void masterOperation( ClusterState state, ActionListener listener ) throws Exception { - if (modelRegistry.containsDefaultConfigId(request.getInferenceEntityId())) { + if (modelRegistry.containsPreconfiguredInferenceEndpointId(request.getInferenceEntityId())) { listener.onFailure( new ElasticsearchStatusException( "[{}] is a reserved inference ID. Cannot create a new inference endpoint with a reserved ID.", diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java new file mode 100644 index 0000000000000..96905892c5c4f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportStoreEndpointsAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.action.support.master.TransportMasterNodeAction; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.block.ClusterBlockException; +import org.elasticsearch.cluster.block.ClusterBlockLevel; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; + +import java.util.List; +import java.util.Objects; + +/** + * Handles the internal action for creating multiple inference endpoints. This should not be used by external REST APIs. + */ +public class TransportStoreEndpointsAction extends TransportMasterNodeAction< + StoreInferenceEndpointsAction.Request, + StoreInferenceEndpointsAction.Response> { + + private final ModelRegistry modelRegistry; + + @Inject + public TransportStoreEndpointsAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + ModelRegistry modelRegistry + ) { + super( + StoreInferenceEndpointsAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + StoreInferenceEndpointsAction.Request::new, + StoreInferenceEndpointsAction.Response::new, + EsExecutors.DIRECT_EXECUTOR_SERVICE + ); + + this.modelRegistry = Objects.requireNonNull(modelRegistry); + } + + @Override + protected void masterOperation( + Task task, + StoreInferenceEndpointsAction.Request request, + ClusterState state, + ActionListener masterListener + ) { + SubscribableListener.>newForked( + listener -> modelRegistry.storeModels(request.getModels(), listener, request.masterNodeTimeout()) + ).andThenApply(StoreInferenceEndpointsAction.Response::new).addListener(masterListener); + } + + @Override + protected ClusterBlockException checkBlock(StoreInferenceEndpointsAction.Request request, ClusterState state) { + return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 390c32bb773f8..739d8d460a8b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -121,7 +121,7 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOffsetsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService.DEFAULT_ELSER_ID; /** @@ -171,7 +171,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie * This enables automatic selection of EIS for better performance while maintaining compatibility with on-prem deployments. */ private static String getPreferredElserInferenceId(ModelRegistry modelRegistry) { - if (modelRegistry != null && modelRegistry.containsDefaultConfigId(DEFAULT_EIS_ELSER_INFERENCE_ID)) { + if (modelRegistry != null && modelRegistry.containsPreconfiguredInferenceEndpointId(DEFAULT_EIS_ELSER_INFERENCE_ID)) { return DEFAULT_EIS_ELSER_INFERENCE_ID; } return DEFAULT_FALLBACK_ELSER_INFERENCE_ID; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java index 66d8db95fd267..cf731b22807e2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistry.java @@ -73,9 +73,11 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; +import org.elasticsearch.xpack.core.inference.results.ModelStoreResponse; import org.elasticsearch.xpack.inference.InferenceIndex; import org.elasticsearch.xpack.inference.InferenceSecretsIndex; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; import java.io.IOException; import java.util.ArrayList; @@ -87,10 +89,12 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.stream.Collectors; @@ -147,10 +151,11 @@ public static UnparsedModel unparsedModelFromMap(ModelConfigMap modelConfigMap) private final MasterServiceTaskQueue metadataTaskQueue; private final AtomicBoolean upgradeMetadataInProgress = new AtomicBoolean(false); private final Set preventDeletionLock = Collections.newSetFromMap(new ConcurrentHashMap<>()); - - private volatile Metadata lastMetadata; + private final ClusterService clusterService; + private final AtomicReference lastMetadata = new AtomicReference<>(); public ModelRegistry(ClusterService clusterService, Client client) { + this.clusterService = Objects.requireNonNull(clusterService); this.client = new OriginSettingClient(client, ClientHelper.INFERENCE_ORIGIN); this.defaultConfigIds = new ConcurrentHashMap<>(); var executor = new SimpleBatchedAckListenerTaskExecutor() { @@ -166,13 +171,24 @@ public Tuple executeTask(MetadataTask tas } /** - * Returns true if the provided inference entity id is the same as one of the default - * endpoints ids. + * Returns true if the model registry contains (whether it has persisted it or not) the provided inference entity id. + * EIS preconfigured endpoints are also considered. * @param inferenceEntityId the id to search for * @return true if we find a match and false if not */ - public boolean containsDefaultConfigId(String inferenceEntityId) { - return defaultConfigIds.containsKey(inferenceEntityId); + public boolean containsPreconfiguredInferenceEndpointId(String inferenceEntityId) { + if (defaultConfigIds.containsKey(inferenceEntityId)) { + return true; + } + + if (lastMetadata.get() != null) { + var project = lastMetadata.get().getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); + var eisPreconfiguredEndpoints = state.getServiceInferenceIds(ElasticInferenceService.NAME); + return eisPreconfiguredEndpoints.contains(inferenceEntityId); + } + + return false; } /** @@ -225,16 +241,15 @@ public void clearDefaultIds() { * @throws ResourceNotFoundException if the specified id is guaranteed to not exist in the cluster. */ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) throws ResourceNotFoundException { - synchronized (this) { - if (lastMetadata == null) { - throw new IllegalStateException("initial cluster state not set yet"); - } + if (lastMetadata.get() == null) { + throw new IllegalStateException("initial cluster state not set yet"); } + var config = defaultConfigIds.get(inferenceEntityId); if (config != null) { return config.settings(); } - var project = lastMetadata.getProject(ProjectId.DEFAULT); + var project = lastMetadata.get().getProject(ProjectId.DEFAULT); var state = ModelRegistryMetadata.fromState(project); var existing = state.getMinimalServiceSettings(inferenceEntityId); if (state.isUpgraded() && existing == null) { @@ -243,6 +258,19 @@ public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId return existing; } + public Set getInferenceIds() { + Set metadataInferenceIds = Set.of(); + if (lastMetadata.get() != null) { + var project = lastMetadata.get().getProject(ProjectId.DEFAULT); + var state = ModelRegistryMetadata.fromState(project); + metadataInferenceIds = state.getInferenceIds(); + } + + var ids = new HashSet<>(metadataInferenceIds); + ids.addAll(Set.copyOf(defaultConfigIds.keySet())); + return ids; + } + /** * Get a model with its secret settings * @param inferenceEntityId Model to get @@ -684,12 +712,6 @@ private void storeModel(Model model, boolean updateClusterState, ActionListener< }), timeout); } - public record ModelStoreResponse(String inferenceId, RestStatus status, @Nullable Exception failureCause) { - public boolean failed() { - return failureCause != null; - } - } - public void storeModels(List models, ActionListener> listener, TimeValue timeout) { storeModels(models, true, listener, timeout); } @@ -934,6 +956,14 @@ private void updateClusterState(List models, ActionListener inferenceEntityIds, ActionListener listener) { if (inferenceEntityIds.isEmpty()) { listener.onResponse(true); @@ -1125,11 +1155,9 @@ static List taskTypeMatchedDefaults( @Override public void clusterChanged(ClusterChangedEvent event) { - if (lastMetadata == null || event.metadataChanged()) { + if (lastMetadata.get() == null || event.metadataChanged()) { // keep track of the last applied cluster state - synchronized (this) { - lastMetadata = event.state().metadata(); - } + lastMetadata.set(event.state().metadata()); } if (event.localNodeMaster() == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java index 4bf23103af5a1..359be95d8a4b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadata.java @@ -30,6 +30,7 @@ import java.util.Collection; import java.util.Collections; import java.util.EnumSet; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -150,24 +151,48 @@ public ModelRegistryMetadata withUpgradedModels(Map modelMap; + private final Map> serviceToInferenceEndpointIds; private final Set tombstones; public ModelRegistryMetadata(ImmutableOpenMap modelMap) { - this.isUpgraded = true; - this.modelMap = modelMap; - this.tombstones = null; + this(modelMap, null, true); } public ModelRegistryMetadata(ImmutableOpenMap modelMap, Set tombstone) { - this.isUpgraded = false; - this.modelMap = modelMap; - this.tombstones = Collections.unmodifiableSet(tombstone); + this(modelMap, Collections.unmodifiableSet(tombstone), false); } public ModelRegistryMetadata(StreamInput in) throws IOException { this.isUpgraded = in.readBoolean(); this.modelMap = in.readImmutableOpenMap(StreamInput::readString, MinimalServiceSettings::new); this.tombstones = isUpgraded ? null : in.readCollectionAsSet(StreamInput::readString); + this.serviceToInferenceEndpointIds = buildServiceToInferenceEndpointIdsMap(modelMap); + } + + private ModelRegistryMetadata(ImmutableOpenMap modelMap, Set tombstones, boolean isUpgraded) { + this.isUpgraded = isUpgraded; + this.modelMap = modelMap; + this.tombstones = tombstones; + this.serviceToInferenceEndpointIds = buildServiceToInferenceEndpointIdsMap(modelMap); + } + + private static Map> buildServiceToInferenceEndpointIdsMap( + ImmutableOpenMap modelMap + ) { + var serviceToInferenceIds = new HashMap>(); + for (var entry : modelMap.entrySet()) { + var settings = entry.getValue(); + var serviceName = settings.service(); + + var existingSet = serviceToInferenceIds.get(serviceName); + if (existingSet == null) { + existingSet = new HashSet<>(); + } + + existingSet.add(entry.getKey()); + serviceToInferenceIds.put(serviceName, existingSet); + } + return serviceToInferenceIds; } @Override @@ -221,10 +246,25 @@ public ImmutableOpenMap getModelMap() { return modelMap; } + /** + * Returns all inference entity IDs for a given service. + */ + public Set getServiceInferenceIds(String service) { + if (serviceToInferenceEndpointIds.containsKey(service) == false) { + return Set.of(); + } + + return Set.copyOf(serviceToInferenceEndpointIds.get(service)); + } + public MinimalServiceSettings getMinimalServiceSettings(String inferenceEntityId) { return modelMap.get(inferenceEntityId); } + public Set getInferenceIds() { + return Set.copyOf(modelMap.keySet()); + } + @Override public Diff diff(Metadata.ProjectCustom before) { return new ModelRegistryMetadataDiff((ModelRegistryMetadata) before, this); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 5d476955a7ad6..8a6d2626ad521 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -16,17 +16,13 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; @@ -48,22 +44,16 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.elastic.action.ElasticInferenceServiceActionCreator; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationHandler; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; -import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.telemetry.TraceContext; import java.util.EnumSet; @@ -95,7 +85,7 @@ public class ElasticInferenceService extends SenderService { // A batch size of 16 provides optimal throughput and stability, especially on lower-tier instance types. public static final Integer SPARSE_TEXT_EMBEDDING_MAX_BATCH_SIZE = 16; - private static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( + public static final EnumSet IMPLEMENTED_TASK_TYPES = EnumSet.of( TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.RERANK, @@ -108,22 +98,6 @@ public class ElasticInferenceService extends SenderService { // This mirrors the memory constraints observed with sparse embeddings private static final Integer DENSE_TEXT_EMBEDDINGS_MAX_BATCH_SIZE = 16; - // rainbow-sprinkles - static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; - static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); - - // elser-2 - static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; - public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId("elser-2"); - - // multilingual-text-embed - static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "jina-embeddings-v3"; - static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; - - // rerank-v1 - static final String DEFAULT_RERANK_MODEL_ID_V1 = "elastic-rerank-v1"; - static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; - /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -133,129 +107,27 @@ public class ElasticInferenceService extends SenderService { TaskType.TEXT_EMBEDDING ); - public static String defaultEndpointId(String modelId) { - return Strings.format(".%s-elastic", modelId); - } - private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; - private final ElasticInferenceServiceAuthorizationHandler authorizationHandler; public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, InferenceServiceExtension.InferenceServiceFactoryContext context ) { - this( - factory, - serviceComponents, - elasticInferenceServiceSettings, - modelRegistry, - authorizationRequestHandler, - context.clusterService() - ); + this(factory, serviceComponents, elasticInferenceServiceSettings, context.clusterService()); } public ElasticInferenceService( HttpRequestSender.Factory factory, ServiceComponents serviceComponents, ElasticInferenceServiceSettings elasticInferenceServiceSettings, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, ClusterService clusterService ) { super(factory, serviceComponents, clusterService); this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( elasticInferenceServiceSettings.getElasticInferenceServiceUrl() ); - authorizationHandler = new ElasticInferenceServiceAuthorizationHandler( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - initDefaultEndpoints(elasticInferenceServiceComponents), - IMPLEMENTED_TASK_TYPES, - this, - getSender(), - elasticInferenceServiceSettings - ); - } - - private static Map initDefaultEndpoints( - ElasticInferenceServiceComponents elasticInferenceServiceComponents - ) { - return Map.of( - DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, - TaskType.CHAT_COMPLETION, - NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents - ), - MinimalServiceSettings.chatCompletion(NAME) - ), - DEFAULT_ELSER_2_MODEL_ID, - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - DEFAULT_ELSER_ENDPOINT_ID_V2, - TaskType.SPARSE_EMBEDDING, - NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(NAME) - ), - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - new DefaultModelConfig( - new ElasticInferenceServiceDenseTextEmbeddingsModel( - DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, - TaskType.TEXT_EMBEDDING, - NAME, - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, - defaultDenseTextEmbeddingsSimilarity(), - null, - null - ), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.textEmbedding( - NAME, - DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ) - ), - DEFAULT_RERANK_MODEL_ID_V1, - new DefaultModelConfig( - new ElasticInferenceServiceRerankModel( - DEFAULT_RERANK_ENDPOINT_ID_V1, - TaskType.RERANK, - NAME, - new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - elasticInferenceServiceComponents - ), - MinimalServiceSettings.rerank(NAME) - ) - ); - } - - @Override - public void onNodeStarted() { - authorizationHandler.init(); } @Override @@ -270,32 +142,11 @@ protected void validateRerankParameters(Boolean returnDocuments, Integer topN, V } } - /** - * Only use this in tests. - * - * Waits the specified amount of time for the authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForFirstAuthorizationToComplete(TimeValue waitTime) { - authorizationHandler.waitForAuthorizationToComplete(waitTime); - } - @Override public Set supportedStreamingTasks() { return EnumSet.of(TaskType.CHAT_COMPLETION); } - @Override - public List defaultConfigIds() { - return authorizationHandler.defaultConfigIds(); - } - - @Override - public void defaultConfigs(ActionListener> defaultsListener) { - authorizationHandler.defaultConfigs(defaultsListener); - } - @Override protected void doUnifiedCompletionInfer( Model model, @@ -472,7 +323,9 @@ public InferenceServiceConfiguration getConfiguration() { @Override public EnumSet supportedTaskTypes() { - return authorizationHandler.supportedTaskTypes(); + throw new UnsupportedOperationException( + "The EIS supported task types change depending on authorization, requests should be made directly to EIS instead" + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java index 34a8086119150..9f5f6d1b75dfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceModel.java @@ -15,7 +15,7 @@ import java.util.Objects; -public abstract class ElasticInferenceServiceModel extends RateLimitGroupingModel { +public class ElasticInferenceServiceModel extends RateLimitGroupingModel { private final ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings; @@ -53,4 +53,18 @@ public RateLimitSettings rateLimitSettings() { public ElasticInferenceServiceComponents elasticInferenceServiceComponents() { return elasticInferenceServiceComponents; } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + if (super.equals(o) == false) return false; + ElasticInferenceServiceModel that = (ElasticInferenceServiceModel) o; + return Objects.equals(rateLimitServiceSettings, that.rateLimitServiceSettings) + && Objects.equals(elasticInferenceServiceComponents, that.elasticInferenceServiceComponents); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), rateLimitServiceSettings, elasticInferenceServiceComponents); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java index 0d8bef246b35d..fcc9808c7b8ff 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSettings.java @@ -28,7 +28,7 @@ public class ElasticInferenceServiceSettings { @Deprecated static final Setting EIS_GATEWAY_URL = Setting.simpleString("xpack.inference.eis.gateway.url", Setting.Property.NodeScope); - static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( + public static final Setting ELASTIC_INFERENCE_SERVICE_URL = Setting.simpleString( "xpack.inference.elastic.url", Setting.Property.NodeScope ); @@ -37,7 +37,7 @@ public class ElasticInferenceServiceSettings { * This setting is for testing only. It controls whether authorization is only performed once at bootup. If set to true, an * authorization request will be made repeatedly on an interval. */ - static final Setting PERIODIC_AUTHORIZATION_ENABLED = Setting.boolSetting( + public static final Setting PERIODIC_AUTHORIZATION_ENABLED = Setting.boolSetting( "xpack.inference.elastic.periodic_authorization_enabled", true, Setting.Property.NodeScope diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java new file mode 100644 index 0000000000000..f74eac700465e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/InternalPreconfiguredEndpoints.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic; + +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.Map; +import java.util.Set; + +import static java.util.stream.Collectors.toMap; + +/** + * Represents the preconfigured endpoints that are included in Elasticsearch. EIS will support dynamic preconfigured endpoints which means + * it can provide new preconfigured endpoints that do not exist in the source here. + */ +public class InternalPreconfiguredEndpoints { + + // rainbow-sprinkles + public static final String DEFAULT_CHAT_COMPLETION_MODEL_ID_V1 = "rainbow-sprinkles"; + public static final String DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1 = ".rainbow-sprinkles-elastic"; + + // elser-2 + public static final String DEFAULT_ELSER_2_MODEL_ID = "elser_model_2"; + public static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = ".elser-2-elastic"; + + // multilingual-text-embed + public static final Integer DENSE_TEXT_EMBEDDINGS_DIMENSIONS = 1024; + public static final String DEFAULT_MULTILINGUAL_EMBED_MODEL_ID = "jina-embeddings-v3"; + public static final String DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID = ".jina-embeddings-v3"; + + // rerank-v1 + public static final String DEFAULT_RERANK_MODEL_ID_V1 = "elastic-rerank-v1"; + public static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = ".elastic-rerank-v1"; + + public record MinimalModel( + ModelConfigurations configurations, + ElasticInferenceServiceRateLimitServiceSettings rateLimitServiceSettings + ) {} + + private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SERVICE_SETTINGS = + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_EMBEDDINGS_SERVICE_SETTINGS = + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); + private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS = + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + null + ); + private static final ElasticInferenceServiceRerankServiceSettings RERANK_SERVICE_SETTINGS = + new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1); + + private static final Map MODEL_NAME_TO_MINIMAL_MODEL = Map.of( + DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, + new MinimalModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + COMPLETION_SERVICE_SETTINGS + ), + DEFAULT_ELSER_2_MODEL_ID, + new MinimalModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_EMBEDDINGS_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + SPARSE_EMBEDDINGS_SERVICE_SETTINGS + ), + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + new MinimalModel( + new ModelConfigurations( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + DENSE_TEXT_EMBEDDINGS_SERVICE_SETTINGS + ), + DEFAULT_RERANK_MODEL_ID_V1, + new MinimalModel( + new ModelConfigurations( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + RERANK_SERVICE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + RERANK_SERVICE_SETTINGS + ) + ); + + private static final Map INFERENCE_ID_TO_MINIMAL_MODEL = MODEL_NAME_TO_MINIMAL_MODEL.entrySet() + .stream() + .collect(toMap(e -> e.getValue().configurations().getInferenceEntityId(), Map.Entry::getValue)); + + public static final Set EIS_PRECONFIGURED_ENDPOINT_IDS = Set.copyOf(INFERENCE_ID_TO_MINIMAL_MODEL.keySet()); + + public static SimilarityMeasure defaultDenseTextEmbeddingsSimilarity() { + return SimilarityMeasure.COSINE; + } + + public static MinimalModel getWithModelName(String modelName) { + return MODEL_NAME_TO_MINIMAL_MODEL.get(modelName); + } + + public static MinimalModel getWithInferenceId(String inferenceId) { + return INFERENCE_ID_TO_MINIMAL_MODEL.get(inferenceId); + } + + private InternalPreconfiguredEndpoints() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java new file mode 100644 index 0000000000000..16646cbde4e89 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPoller.java @@ -0,0 +1,288 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; + +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.IMPLEMENTED_TASK_TYPES; + +public class AuthorizationPoller extends AllocatedPersistentTask { + + public static final String TASK_NAME = "eis-authorization-poller"; + + private static final Logger logger = LogManager.getLogger(AuthorizationPoller.class); + + private final ServiceComponents serviceComponents; + private final ModelRegistry modelRegistry; + private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; + private final Sender sender; + private final Runnable callback; + private final AtomicReference lastAuthTask = new AtomicReference<>(null); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; + private final AtomicBoolean initialized = new AtomicBoolean(false); + private final ElasticInferenceServiceComponents elasticInferenceServiceComponents; + private final Client client; + private final CountDownLatch receivedFirstAuthResponseLatch = new CountDownLatch(1); + + public record TaskFields(long id, String type, String action, String description, TaskId parentTask, Map headers) {} + + public record Parameters( + ServiceComponents serviceComponents, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + Client client + ) {} + + public static AuthorizationPoller create(TaskFields taskFields, Parameters parameters) { + return new AuthorizationPoller(Objects.requireNonNull(taskFields), Objects.requireNonNull(parameters)); + } + + private AuthorizationPoller(TaskFields taskFields, Parameters parameters) { + this( + taskFields, + parameters.serviceComponents, + parameters.authorizationRequestHandler, + parameters.sender, + parameters.elasticInferenceServiceSettings, + parameters.modelRegistry, + parameters.client, + null + ); + } + + // default for testing + AuthorizationPoller( + TaskFields taskFields, + ServiceComponents serviceComponents, + ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, + Sender sender, + ElasticInferenceServiceSettings elasticInferenceServiceSettings, + ModelRegistry modelRegistry, + Client client, + // this is a hack to facilitate testing + Runnable callback + ) { + super(taskFields.id, taskFields.type, taskFields.action, taskFields.description, taskFields.parentTask, taskFields.headers); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); + this.sender = Objects.requireNonNull(sender); + this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); + this.elasticInferenceServiceComponents = new ElasticInferenceServiceComponents( + elasticInferenceServiceSettings.getElasticInferenceServiceUrl() + ); + this.modelRegistry = Objects.requireNonNull(modelRegistry); + this.client = new OriginSettingClient(Objects.requireNonNull(client), ClientHelper.INFERENCE_ORIGIN); + this.callback = callback; + } + + public void start() { + if (initialized.compareAndSet(false, true)) { + logger.debug("Initializing EIS authorization logic"); + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); + } + } + + /** + * This should only be used for testing to wait for the first authorization response to be received. + */ + public void waitForAuthorizationToComplete(TimeValue waitTime) { + try { + if (receivedFirstAuthResponseLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("The wait time has expired for first authorization response to be received."); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Waiting for first authorization response to complete was interrupted"); + } + } + + // Overriding so tests in the same package can access + @Override + protected void init( + PersistentTasksService persistentTasksService, + TaskManager taskManager, + String persistentTaskId, + long allocationId + ) { + super.init(persistentTasksService, taskManager, persistentTaskId, allocationId); + } + + @Override + protected void onCancelled() { + shutdown(); + markAsCompleted(); + } + + private void shutdownAndMarkTaskAsFailed(Exception e) { + shutdown(); + markAsFailed(e); + } + + // default for testing + void shutdown() { + shutdown.set(true); + + var authTask = lastAuthTask.get(); + if (authTask != null) { + authTask.cancel(); + } + } + + // default for testing + boolean isShutdown() { + return shutdown.get(); + } + + private void scheduleAuthorizationRequest() { + try { + if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { + return; + } + + // this call has to be on the individual thread otherwise we get an exception + var random = Randomness.get(); + var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); + var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); + + logger.debug( + () -> Strings.format( + "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", + elasticInferenceServiceSettings.getAuthRequestInterval().millis(), + jitter + ) + ); + logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); + + lastAuthTask.set( + serviceComponents.threadPool() + .schedule( + this::scheduleAndSendAuthorizationRequest, + waitTime, + serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) + ) + ); + } catch (Exception e) { + logger.warn("Failed scheduling authorization request", e); + // Shutdown and complete the task so it will be restarted + shutdownAndMarkTaskAsFailed(e); + } + } + + private void scheduleAndSendAuthorizationRequest() { + if (shutdown.get()) { + return; + } + + scheduleAuthorizationRequest(); + sendAuthorizationRequest(); + } + + // default for testing + void sendAuthorizationRequest() { + if (modelRegistry.isReady() == false) { + return; + } + + var finalListener = ActionListener.running(() -> { + if (callback != null) { + callback.run(); + } + receivedFirstAuthResponseLatch.countDown(); + }).delegateResponse((delegate, e) -> { + logger.atWarn().withThrowable(e).log("Failed processing EIS preconfigured endpoints"); + delegate.onResponse(null); + }); + + SubscribableListener.newForked( + authModelListener -> authorizationHandler.getAuthorization(authModelListener, sender) + ) + .andThenApply(this::getNewInferenceEndpointsToStore) + .andThen((storeListener, newInferenceIds) -> storePreconfiguredModels(newInferenceIds, storeListener)) + .addListener(finalListener); + } + + private Set getNewInferenceEndpointsToStore(ElasticInferenceServiceAuthorizationModel authModel) { + var scopedAuthModel = authModel.newLimitedToTaskTypes(EnumSet.copyOf(IMPLEMENTED_TASK_TYPES)); + + var authorizedModelIds = scopedAuthModel.getAuthorizedModelIds(); + var existingInferenceIds = modelRegistry.getInferenceIds(); + + var newInferenceIds = authorizedModelIds.stream() + .map(InternalPreconfiguredEndpoints::getWithModelName) + .filter(Objects::nonNull) + .map(model -> model.configurations().getInferenceEntityId()) + .collect(Collectors.toSet()); + + newInferenceIds.removeAll(existingInferenceIds); + return newInferenceIds; + } + + private void storePreconfiguredModels(Set newInferenceIds, ActionListener listener) { + if (newInferenceIds.isEmpty()) { + listener.onResponse(null); + return; + } + + logger.info("Storing new EIS preconfigured inference endpoints with inference IDs {}", newInferenceIds); + var modelsToAdd = PreconfiguredEndpointModelAdapter.getModels(newInferenceIds, elasticInferenceServiceComponents); + var storeRequest = new StoreInferenceEndpointsAction.Request(modelsToAdd, TimeValue.THIRTY_SECONDS); + + ActionListener logResultsListener = ActionListener.wrap(responses -> { + for (var response : responses.getResults()) { + if (response.failed()) { + logger.atWarn() + .withThrowable(response.failureCause()) + .log("Failed to store new EIS preconfigured inference endpoint with inference ID [{}]", response.inferenceId()); + } else { + logger.atInfo() + .log("Successfully stored EIS preconfigured inference endpoint with inference ID [{}]", response.inferenceId()); + } + } + }, e -> logger.atWarn().withThrowable(e).log("Failed to store new EIS preconfigured inference endpoints [{}]", newInferenceIds)); + + client.execute( + StoreInferenceEndpointsAction.INSTANCE, + storeRequest, + ActionListener.runAfter(logResultsListener, () -> listener.onResponse(null)) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java new file mode 100644 index 0000000000000..bca830eb7b948 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutor.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ResourceAlreadyExistsException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.core.FixForMultiProject; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.AllocatedPersistentTask; +import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.persistent.PersistentTaskState; +import org.elasticsearch.persistent.PersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksExecutor; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.transport.RemoteTransportException; +import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; +import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; + +public class AuthorizationTaskExecutor extends PersistentTasksExecutor implements ClusterStateListener { + + private static final Logger logger = LogManager.getLogger(AuthorizationTaskExecutor.class); + + private final ClusterService clusterService; + private final PersistentTasksService persistentTasksService; + private final AuthorizationPoller.Parameters pollerParameters; + private final AtomicReference currentTask = new AtomicReference<>(); + + public static AuthorizationTaskExecutor create(ClusterService clusterService, AuthorizationPoller.Parameters parameters) { + Objects.requireNonNull(clusterService); + Objects.requireNonNull(parameters); + + var executor = new AuthorizationTaskExecutor( + clusterService, + new PersistentTasksService(clusterService, parameters.serviceComponents().threadPool(), parameters.client()), + parameters + ); + executor.init(); + return executor; + } + + // default for testing + AuthorizationTaskExecutor( + ClusterService clusterService, + PersistentTasksService persistentTasksService, + AuthorizationPoller.Parameters pollerParameters + ) { + super(TASK_NAME, pollerParameters.serviceComponents().threadPool().executor(UTILITY_THREAD_POOL_NAME)); + this.clusterService = Objects.requireNonNull(clusterService); + this.persistentTasksService = Objects.requireNonNull(persistentTasksService); + this.pollerParameters = Objects.requireNonNull(pollerParameters); + } + + // default for testing + void init() { + // If the EIS url is not configured, then we won't be able to interact with the service, so don't start the task. + if (Strings.isNullOrEmpty(pollerParameters.elasticInferenceServiceSettings().getElasticInferenceServiceUrl()) == false) { + clusterService.addListener(this); + } + } + + /** + * This method should only be used for testing purposes to get the current running task. + */ + public AuthorizationPoller getCurrentPollerTask() { + return currentTask.get(); + } + + @Override + protected void nodeOperation(AllocatedPersistentTask task, AuthorizationTaskParams params, PersistentTaskState state) { + var authPoller = (AuthorizationPoller) task; + currentTask.set(authPoller); + authPoller.start(); + } + + @FixForMultiProject( + description = "A single cluster can have multiple projects, " + + "we'll need to either make a call per project/org or use a bulk authorization api that EIS provides" + ) + @Override + public Scope scope() { + return Scope.CLUSTER; + } + + @Override + protected AuthorizationPoller createTask( + long id, + String type, + String action, + TaskId parentTaskId, + PersistentTasksCustomMetadata.PersistentTask taskInProgress, + Map headers + ) { + return AuthorizationPoller.create( + new AuthorizationPoller.TaskFields(id, type, action, getDescription(taskInProgress), parentTaskId, headers), + pollerParameters + ); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + if (authorizationTaskExists(event)) { + return; + } + + persistentTasksService.sendClusterStartRequest( + TASK_NAME, + TASK_NAME, + new AuthorizationTaskParams(), + TimeValue.THIRTY_SECONDS, + ActionListener.wrap(persistentTask -> logger.debug("Created authorization poller task"), exception -> { + var thrownException = exception instanceof RemoteTransportException ? exception.getCause() : exception; + if (thrownException instanceof ResourceAlreadyExistsException == false) { + logger.error("Failed to create authorization poller task", exception); + } + }) + ); + } + + private static boolean authorizationTaskExists(ClusterChangedEvent event) { + return ClusterPersistentTasksCustomMetadata.getTaskWithId(event.state(), TASK_NAME) != null; + } + + public static List getNamedXContentParsers() { + return List.of( + new NamedXContentRegistry.Entry( + PersistentTaskParams.class, + new ParseField(AuthorizationPoller.TASK_NAME), + AuthorizationTaskParams::fromXContent + ) + ); + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(PersistentTaskParams.class, AuthorizationPoller.TASK_NAME, AuthorizationTaskParams::new) + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java new file mode 100644 index 0000000000000..7b2791169e872 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParams.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.persistent.PersistentTaskParams; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +import static org.elasticsearch.xpack.inference.services.elastic.authorization.AuthorizationPoller.TASK_NAME; + +/** + * Empty parameters for the authorization persistent task. + */ +public class AuthorizationTaskParams implements PersistentTaskParams { + public static final AuthorizationTaskParams INSTANCE = new AuthorizationTaskParams(); + + private static final ObjectParser PARSER = new ObjectParser<>(TASK_NAME, true, () -> INSTANCE); + private static final TransportVersion INFERENCE_API_EIS_AUTHORIZATION_PERSISTENT_TASK = TransportVersion.fromName( + "inference_api_eis_authorization_persistent_task" + ); + + AuthorizationTaskParams() {} + + AuthorizationTaskParams(StreamInput in) {} + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + builder.startObject(); + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return TASK_NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return INFERENCE_API_EIS_AUTHORIZATION_PERSISTENT_TASK; + } + + @Override + public void writeTo(StreamOutput out) {} + + public static AuthorizationTaskParams fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object o) { + return this == o || (o != null && getClass() == o.getClass()); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java deleted file mode 100644 index f83542e7fe740..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandler.java +++ /dev/null @@ -1,336 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.Randomness; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.threadpool.Scheduler; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.ServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; - -import java.io.Closeable; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.EnumSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Set; -import java.util.TreeSet; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; - -import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; - -public class ElasticInferenceServiceAuthorizationHandler implements Closeable { - private static final Logger logger = LogManager.getLogger(ElasticInferenceServiceAuthorizationHandler.class); - - private record AuthorizedContent( - ElasticInferenceServiceAuthorizationModel taskTypesAndModels, - List configIds, - List defaultModelConfigs - ) { - static AuthorizedContent empty() { - return new AuthorizedContent(ElasticInferenceServiceAuthorizationModel.newDisabledService(), List.of(), List.of()); - } - } - - private final ServiceComponents serviceComponents; - private final AtomicReference authorizedContent = new AtomicReference<>(AuthorizedContent.empty()); - private final ModelRegistry modelRegistry; - private final ElasticInferenceServiceAuthorizationRequestHandler authorizationHandler; - private final Map defaultModelsConfigs; - private final CountDownLatch firstAuthorizationCompletedLatch = new CountDownLatch(1); - private final EnumSet implementedTaskTypes; - private final InferenceService inferenceService; - private final Sender sender; - private final Runnable callback; - private final AtomicReference lastAuthTask = new AtomicReference<>(null); - private final AtomicBoolean shutdown = new AtomicBoolean(false); - private final ElasticInferenceServiceSettings elasticInferenceServiceSettings; - - public ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings - ) { - this( - serviceComponents, - modelRegistry, - authorizationRequestHandler, - defaultModelsConfigs, - implementedTaskTypes, - Objects.requireNonNull(inferenceService), - sender, - elasticInferenceServiceSettings, - null - ); - } - - // default for testing - ElasticInferenceServiceAuthorizationHandler( - ServiceComponents serviceComponents, - ModelRegistry modelRegistry, - ElasticInferenceServiceAuthorizationRequestHandler authorizationRequestHandler, - Map defaultModelsConfigs, - EnumSet implementedTaskTypes, - InferenceService inferenceService, - Sender sender, - ElasticInferenceServiceSettings elasticInferenceServiceSettings, - // this is a hack to facilitate testing - Runnable callback - ) { - this.serviceComponents = Objects.requireNonNull(serviceComponents); - this.modelRegistry = Objects.requireNonNull(modelRegistry); - this.authorizationHandler = Objects.requireNonNull(authorizationRequestHandler); - this.defaultModelsConfigs = Objects.requireNonNull(defaultModelsConfigs); - this.implementedTaskTypes = Objects.requireNonNull(implementedTaskTypes); - // allow the service to be null for testing - this.inferenceService = inferenceService; - this.sender = Objects.requireNonNull(sender); - this.elasticInferenceServiceSettings = Objects.requireNonNull(elasticInferenceServiceSettings); - this.callback = callback; - } - - public void init() { - logger.debug("Initializing authorization logic"); - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME).execute(this::scheduleAndSendAuthorizationRequest); - } - - /** - * Waits the specified amount of time for the first authorization call to complete. This is mainly to make testing easier. - * @param waitTime the max time to wait - * @throws IllegalStateException if the wait time is exceeded or the call receives an {@link InterruptedException} - */ - public void waitForAuthorizationToComplete(TimeValue waitTime) { - try { - if (firstAuthorizationCompletedLatch.await(waitTime.getSeconds(), TimeUnit.SECONDS) == false) { - throw new IllegalStateException("The wait time has expired for authorization to complete."); - } - } catch (InterruptedException e) { - throw new IllegalStateException("Waiting for authorization to complete was interrupted"); - } - } - - public synchronized Set supportedStreamingTasks() { - var authorizedStreamingTaskTypes = EnumSet.of(TaskType.CHAT_COMPLETION); - authorizedStreamingTaskTypes.retainAll(authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes()); - - return authorizedStreamingTaskTypes; - } - - public synchronized List defaultConfigIds() { - return authorizedContent.get().configIds; - } - - public synchronized void defaultConfigs(ActionListener> defaultsListener) { - var models = authorizedContent.get().defaultModelConfigs.stream().map(DefaultModelConfig::model).toList(); - defaultsListener.onResponse(models); - } - - public synchronized EnumSet supportedTaskTypes() { - return authorizedContent.get().taskTypesAndModels.getAuthorizedTaskTypes(); - } - - public synchronized boolean hideFromConfigurationApi() { - return authorizedContent.get().taskTypesAndModels.isAuthorized() == false; - } - - @Override - public void close() throws IOException { - shutdown.set(true); - if (lastAuthTask.get() != null) { - lastAuthTask.get().cancel(); - } - } - - private void scheduleAuthorizationRequest() { - try { - if (elasticInferenceServiceSettings.isPeriodicAuthorizationEnabled() == false) { - return; - } - - // this call has to be on the individual thread otherwise we get an exception - var random = Randomness.get(); - var jitter = (long) (elasticInferenceServiceSettings.getMaxAuthorizationRequestJitter().millis() * random.nextDouble()); - var waitTime = TimeValue.timeValueMillis(elasticInferenceServiceSettings.getAuthRequestInterval().millis() + jitter); - - logger.debug( - () -> Strings.format( - "Scheduling the next authorization call with request interval: %s ms, jitter: %d ms", - elasticInferenceServiceSettings.getAuthRequestInterval().millis(), - jitter - ) - ); - logger.debug(() -> Strings.format("Next authorization call in %d minutes", waitTime.getMinutes())); - - lastAuthTask.set( - serviceComponents.threadPool() - .schedule( - this::scheduleAndSendAuthorizationRequest, - waitTime, - serviceComponents.threadPool().executor(UTILITY_THREAD_POOL_NAME) - ) - ); - } catch (Exception e) { - logger.warn("Failed scheduling authorization request", e); - } - } - - private void scheduleAndSendAuthorizationRequest() { - if (shutdown.get()) { - return; - } - - scheduleAuthorizationRequest(); - sendAuthorizationRequest(); - } - - private void sendAuthorizationRequest() { - try { - ActionListener listener = ActionListener.wrap((model) -> { - setAuthorizedContent(model); - if (callback != null) { - callback.run(); - } - }, e -> { - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - }); - - authorizationHandler.getAuthorization(listener, sender); - } catch (Exception e) { - logger.warn("Failure while sending the request to retrieve authorization", e); - // we don't need to do anything if there was a failure, everything is disabled by default - firstAuthorizationCompletedLatch.countDown(); - } - } - - private synchronized void setAuthorizedContent(ElasticInferenceServiceAuthorizationModel auth) { - logger.debug(() -> Strings.format("Received authorization response, %s", auth)); - - var authorizedTaskTypesAndModels = auth.newLimitedToTaskTypes(EnumSet.copyOf(implementedTaskTypes)); - logger.debug(() -> Strings.format("Authorization entity limited to service task types, %s", authorizedTaskTypesAndModels)); - - // recalculate which default config ids and models are authorized now - var authorizedDefaultModelIds = getAuthorizedDefaultModelIds(authorizedTaskTypesAndModels); - - var authorizedDefaultConfigIds = getAuthorizedDefaultConfigIds(authorizedDefaultModelIds, authorizedTaskTypesAndModels); - var authorizedDefaultModelObjects = getAuthorizedDefaultModelsObjects(authorizedDefaultModelIds); - authorizedContent.set( - new AuthorizedContent(authorizedTaskTypesAndModels, authorizedDefaultConfigIds, authorizedDefaultModelObjects) - ); - - authorizedContent.get().configIds().forEach(modelRegistry::putDefaultIdIfAbsent); - handleRevokedDefaultConfigs(authorizedDefaultModelIds); - } - - private Set getAuthorizedDefaultModelIds(ElasticInferenceServiceAuthorizationModel auth) { - var authorizedModels = auth.getAuthorizedModelIds(); - var authorizedDefaultModelIds = new TreeSet<>(defaultModelsConfigs.keySet()); - authorizedDefaultModelIds.retainAll(authorizedModels); - - return authorizedDefaultModelIds; - } - - private List getAuthorizedDefaultConfigIds( - Set authorizedDefaultModelIds, - ElasticInferenceServiceAuthorizationModel auth - ) { - var authorizedConfigIds = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - if (auth.getAuthorizedTaskTypes().contains(modelConfig.model().getTaskType()) == false) { - logger.warn( - org.elasticsearch.common.Strings.format( - "The authorization response included the default model: %s, " - + "but did not authorize the assumed task type of the model: %s. Enabling model.", - id, - modelConfig.model().getTaskType() - ) - ); - } - authorizedConfigIds.add( - new InferenceService.DefaultConfigId( - modelConfig.model().getInferenceEntityId(), - modelConfig.settings(), - inferenceService - ) - ); - } - } - - authorizedConfigIds.sort(Comparator.comparing(InferenceService.DefaultConfigId::inferenceId)); - return authorizedConfigIds; - } - - private List getAuthorizedDefaultModelsObjects(Set authorizedDefaultModelIds) { - var authorizedModels = new ArrayList(); - for (var id : authorizedDefaultModelIds) { - var modelConfig = defaultModelsConfigs.get(id); - if (modelConfig != null) { - authorizedModels.add(modelConfig); - } - } - - authorizedModels.sort(Comparator.comparing(modelConfig -> modelConfig.model().getInferenceEntityId())); - return authorizedModels; - } - - private void handleRevokedDefaultConfigs(Set authorizedDefaultModelIds) { - // if a model was initially returned in the authorization response but is absent, then we'll assume authorization was revoked - var unauthorizedDefaultModelIds = new HashSet<>(defaultModelsConfigs.keySet()); - unauthorizedDefaultModelIds.removeAll(authorizedDefaultModelIds); - - // get all the default inference endpoint ids for the unauthorized model ids - var unauthorizedDefaultInferenceEndpointIds = unauthorizedDefaultModelIds.stream() - .map(defaultModelsConfigs::get) // get all the model configs - .filter(Objects::nonNull) // limit to only non-null - .map(modelConfig -> modelConfig.model().getInferenceEntityId()) // get the inference ids - .collect(Collectors.toSet()); - - var deleteInferenceEndpointsListener = ActionListener.wrap(result -> { - logger.debug(Strings.format("Successfully revoked access to default inference endpoint IDs: %s", unauthorizedDefaultModelIds)); - firstAuthorizationCompletedLatch.countDown(); - }, e -> { - logger.warn( - Strings.format("Failed to revoke access to default inference endpoint IDs: %s, error: %s", unauthorizedDefaultModelIds, e) - ); - firstAuthorizationCompletedLatch.countDown(); - }); - - logger.debug( - () -> Strings.format( - "Synchronizing default inference endpoints, attempting to remove ids: %s", - unauthorizedDefaultInferenceEndpointIds - ) - ); - modelRegistry.removeDefaultConfigs(unauthorizedDefaultInferenceEndpointIds, deleteInferenceEndpointsListener); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java new file mode 100644 index 0000000000000..ab23da7cab5b2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapter.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; + +import java.util.List; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.EIS_PRECONFIGURED_ENDPOINT_IDS; + +public class PreconfiguredEndpointModelAdapter { + public static List getModels(Set inferenceIds, ElasticInferenceServiceComponents elasticInferenceServiceComponents) { + return inferenceIds.stream() + .sorted() + .filter(EIS_PRECONFIGURED_ENDPOINT_IDS::contains) + .map(id -> createModel(InternalPreconfiguredEndpoints.getWithInferenceId(id), elasticInferenceServiceComponents)) + .toList(); + } + + public static Model createModel( + InternalPreconfiguredEndpoints.MinimalModel minimalModel, + ElasticInferenceServiceComponents elasticInferenceServiceComponents + ) { + return new ElasticInferenceServiceModel( + minimalModel.configurations(), + new ModelSecrets(EmptySecretSettings.INSTANCE), + minimalModel.rateLimitServiceSettings(), + elasticInferenceServiceComponents + ); + } + + private PreconfiguredEndpointModelAdapter() {} +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java index 27952f23f37f8..d53c9d5eebbc9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointActionTests.java @@ -85,7 +85,7 @@ public void testFailsToDelete_ADefaultEndpoint_WithoutPassingForceQueryParameter listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of())); return Void.TYPE; }).when(mockModelRegistry).getModel(anyString(), any()); - when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(anyString())).thenReturn(true); var listener = new PlainActionFuture(); @@ -109,7 +109,7 @@ public void testDeletesDefaultEndpoint_WhenForceIsTrue() { listener.onResponse(new UnparsedModel("model_id", TaskType.COMPLETION, "service", Map.of(), Map.of())); return Void.TYPE; }).when(mockModelRegistry).getModel(anyString(), any()); - when(mockModelRegistry.containsDefaultConfigId(anyString())).thenReturn(true); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(anyString())).thenReturn(true); doAnswer(invocationOnMock -> { ActionListener listener = invocationOnMock.getArgument(1); listener.onResponse(true); @@ -145,7 +145,7 @@ public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { var taskType = randomFrom(TaskType.values()); var mockService = mock(InferenceService.class); mockUnparsableModel(inferenceEndpointId, serviceName, taskType, mockService); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); action.masterOperation( @@ -160,7 +160,7 @@ public void testFailsToDeleteUnparsableEndpoint_WhenForceIsFalse() { verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService); } @@ -240,7 +240,7 @@ public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() { var serviceName = randomAlphanumericOfLength(10); var taskType = randomFrom(TaskType.values()); mockNoService(inferenceEndpointId, serviceName, taskType); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); @@ -255,7 +255,7 @@ public void testFailsToDeleteEndpointWithNoService_WhenForceIsFalse() { assertThat(exception.getMessage(), containsString("No service found for this inference endpoint")); verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry); } @@ -275,7 +275,7 @@ public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse var mockService = mock(InferenceService.class); var mockModel = mock(Model.class); mockStopDeploymentFails(inferenceEndpointId, serviceName, taskType, mockService, mockModel); - when(mockModelRegistry.containsDefaultConfigId(inferenceEndpointId)).thenReturn(false); + when(mockModelRegistry.containsPreconfiguredInferenceEndpointId(inferenceEndpointId)).thenReturn(false); var listener = new PlainActionFuture(); action.masterOperation( @@ -289,7 +289,7 @@ public void testFailsToDeleteEndpointIfModelDeploymentStopFails_WhenForceIsFalse assertThat(exception.getMessage(), containsString("Failed to stop model deployment")); verify(mockModelRegistry).getModel(eq(inferenceEndpointId), any()); verify(mockInferenceServiceRegistry).getService(eq(serviceName)); - verify(mockModelRegistry).containsDefaultConfigId(eq(inferenceEndpointId)); + verify(mockModelRegistry).containsPreconfiguredInferenceEndpointId(eq(inferenceEndpointId)); verify(mockService).parsePersistedConfig(eq(inferenceEndpointId), eq(taskType), any()); verify(mockService).stop(eq(mockModel), any()); verifyNoMoreInteractions(mockModelRegistry, mockInferenceServiceRegistry, mockService, mockModel); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java index d56b3fd8037c7..6d25b37649772 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java @@ -509,7 +509,7 @@ private XContentSource executeAction() throws ExecutionException, InterruptedExc private void givenDefaultEndpoints(String... ids) { for (String id : ids) { - when(modelRegistry.containsDefaultConfigId(id)).thenReturn(true); + when(modelRegistry.containsPreconfiguredInferenceEndpointId(id)).thenReturn(true); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java index 9af21386e93d3..5a59ab56efa88 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryMetadataTests.java @@ -24,6 +24,7 @@ import java.util.Set; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -314,4 +315,80 @@ public void testWithAddedModels_ReturnsNewMetadataInstance_ForNewInferenceId_Wit ) ); } + + public void testGetServiceInferenceIds_ReturnsCorrectIdsForKnownService() { + var serviceA = "service_a"; + var endpointId1 = "endpointId1"; + var endpointId2 = "endpointId2"; + + var settings1 = MinimalServiceSettings.chatCompletion(serviceA); + var settings2 = MinimalServiceSettings.sparseEmbedding(serviceA); + var models = Map.of(endpointId1, settings1, endpointId2, settings2); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(Set.of(endpointId1, endpointId2))); + } + + public void testGetServiceInferenceIds_AcceptsNullKeys() { + var serviceA = "service_a"; + var endpointId1 = "endpointId1"; + var endpointId2 = "endpointId2"; + var nullEndpoint1 = "nullEndpoint1"; + var nullEndpoint2 = "nullEndpoint2"; + + var settings1 = MinimalServiceSettings.chatCompletion(serviceA); + var settings2 = MinimalServiceSettings.sparseEmbedding(serviceA); + // I'm not sure why minimal service settings would have a null service name, but testing it anyway + var nullServiceNameSettings1 = MinimalServiceSettings.sparseEmbedding(null); + var nullServiceNameSettings2 = MinimalServiceSettings.sparseEmbedding(null); + var models = Map.of( + endpointId1, + settings1, + endpointId2, + settings2, + nullEndpoint1, + nullServiceNameSettings1, + nullEndpoint2, + nullServiceNameSettings2 + ); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(Set.of(endpointId1, endpointId2))); + assertThat(metadata.getServiceInferenceIds(null), is(Set.of(nullEndpoint1, nullEndpoint2))); + } + + public void testGetServiceInferenceIds_ReturnsEmptySetForUnknownService() { + var serviceA = "service_a"; + var serviceB = "service_b"; + var endpointId = "endpointId1"; + + var settings = MinimalServiceSettings.chatCompletion(serviceA); + var models = Map.of(endpointId, settings); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceB); + assertThat(serviceEndpoints, is(empty())); + } + + public void testGetServiceInferenceIds_ReturnsEmptySetForEmptyModelMap() { + var serviceA = "service_a"; + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.of()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + assertThat(serviceEndpoints, is(empty())); + } + + public void testGetServiceInferenceIds_ReturnedSetIsImmutable_WhenAttemptingToModifyIt() { + var serviceA = "service_a"; + var endpointId = "endpointId1"; + + var settings = MinimalServiceSettings.chatCompletion(serviceA); + var models = Map.of(endpointId, settings); + var metadata = new ModelRegistryMetadata(ImmutableOpenMap.builder(models).build()); + + var serviceEndpoints = metadata.getServiceInferenceIds(serviceA); + expectThrows(UnsupportedOperationException.class, () -> serviceEndpoints.add("newId")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java index 44f0dcc1d8962..a54eb379a054c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/registry/ModelRegistryTests.java @@ -90,15 +90,15 @@ public void testIdMatchedDefault() { assertFalse(matched.isPresent()); } - public void testContainsDefaultConfigId() { + public void testContainsPreconfiguredInferenceEndpointId() { registry.addDefaultIds( new InferenceService.DefaultConfigId("foo", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) ); registry.addDefaultIds( new InferenceService.DefaultConfigId("bar", MinimalServiceSettings.sparseEmbedding("my_service"), mock(InferenceService.class)) ); - assertTrue(registry.containsDefaultConfigId("foo")); - assertFalse(registry.containsDefaultConfigId("baz")); + assertTrue(registry.containsPreconfiguredInferenceEndpointId("foo")); + assertFalse(registry.containsPreconfiguredInferenceEndpointId("baz")); } public void testTaskTypeMatchedDefaults() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 8a23057195f4e..4b17cab04471a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -17,16 +17,13 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; -import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnifiedCompletionRequest; @@ -49,12 +46,10 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.services.InferenceEventsAssertion; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModel; import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationModelTests; -import org.elasticsearch.xpack.inference.services.elastic.authorization.ElasticInferenceServiceAuthorizationRequestHandler; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsModelTests; @@ -98,7 +93,6 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.isA; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -110,7 +104,6 @@ public class ElasticInferenceServiceTests extends ESSingleNodeTestCase { private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); - private ModelRegistry modelRegistry; private ThreadPool threadPool; private HttpClientManager clientManager; @@ -123,7 +116,6 @@ protected Collection> getPlugins() { @Before public void init() throws Exception { webServer.start(); - modelRegistry = node().injector().getInstance(ModelRegistry.class); threadPool = createThreadPool(inferenceUtilityExecutors()); clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); } @@ -921,8 +913,6 @@ public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOExceptio public void testHideFromConfigurationApi_ThrowsUnsupported_WithNoAvailableModels() throws Exception { try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } @@ -942,119 +932,86 @@ public void testHideFromConfigurationApi_ThrowsUnsupported_WithAvailableModels() ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::hideFromConfigurationApi); } } public void testCreateConfiguration() throws Exception { - try ( - var service = createServiceWithMockSender( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "model-1", - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ) - ) - ) - ) - ) - ) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": ["sparse_embedding", "chat_completion", "text_embedding"], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + boolean humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( + EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION, TaskType.TEXT_EMBEDDING) + ); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_WithoutSupportedTaskTypes() throws Exception { - try (var service = createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel.newDisabledService())) { - ensureAuthorizationCallFinished(service); - - String content = XContentHelper.stripWhitespace(""" - { - "service": "elastic", - "name": "Elastic", - "task_types": [], - "configurations": { - "model_id": { - "description": "The name of the model to use for the inference task.", - "label": "Model ID", - "required": true, - "sensitive": false, - "updatable": false, - "type": "str", - "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] - }, - "max_input_tokens": { - "description": "Allows you to specify the maximum number of tokens per input.", - "label": "Maximum Input Tokens", - "required": false, - "sensitive": false, - "updatable": false, - "type": "int", - "supported_task_types": ["text_embedding", "sparse_embedding"] - } + String content = XContentHelper.stripWhitespace(""" + { + "service": "elastic", + "name": "Elastic", + "task_types": [], + "configurations": { + "model_id": { + "description": "The name of the model to use for the inference task.", + "label": "Model ID", + "required": true, + "sensitive": false, + "updatable": false, + "type": "str", + "supported_task_types": ["text_embedding", "sparse_embedding" , "rerank", "chat_completion"] + }, + "max_input_tokens": { + "description": "Allows you to specify the maximum number of tokens per input.", + "label": "Maximum Input Tokens", + "required": false, + "sensitive": false, + "updatable": false, + "type": "int", + "supported_task_types": ["text_embedding", "sparse_embedding"] } } - """); - InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( - new BytesArray(content), - XContentType.JSON - ); - boolean humanReadable = true; - BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); - InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration( - EnumSet.noneOf(TaskType.class) - ); - assertToXContentEquivalent( - originalBytes, - toXContent(serviceConfiguration, XContentType.JSON, humanReadable), - XContentType.JSON - ); - } + } + """); + InferenceServiceConfiguration configuration = InferenceServiceConfiguration.fromXContentBytes( + new BytesArray(content), + XContentType.JSON + ); + var humanReadable = true; + BytesReference originalBytes = toShuffledXContent(configuration, XContentType.JSON, ToXContent.EMPTY_PARAMS, humanReadable); + InferenceServiceConfiguration serviceConfiguration = ElasticInferenceService.createConfiguration(EnumSet.noneOf(TaskType.class)); + assertToXContentEquivalent(originalBytes, toXContent(serviceConfiguration, XContentType.JSON, humanReadable), XContentType.JSON); } public void testGetConfiguration_ThrowsUnsupported() throws Exception { @@ -1073,30 +1030,13 @@ public void testGetConfiguration_ThrowsUnsupported() throws Exception { ) ) ) { - ensureAuthorizationCallFinished(service); - expectThrows(UnsupportedOperationException.class, service::getConfiguration); } } - public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWithAValidModel() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse", "chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedStreamingTasks_ReturnsChatCompletion() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - + try (var service = createService(senderFactory)) { assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); assertFalse(service.canStream(TaskType.ANY)); assertTrue(service.defaultConfigIds().isEmpty()); @@ -1107,79 +1047,10 @@ public void testSupportedStreamingTasks_ReturnsChatCompletion_WhenAuthRespondsWi } } - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes_IgnoresUnimplementedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["embed"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - } - } - - public void testSupportedTaskTypes_Returns_TheAuthorizedTaskTypes() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "model-b", - "task_types": ["chat"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - } - } - - public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChatCompletion() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "model-a", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testDefaultConfigs_ReturnsEmptyLists() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); + try (var service = createService(senderFactory)) { assertTrue(service.defaultConfigIds().isEmpty()); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); @@ -1187,120 +1058,10 @@ public void testSupportedStreamingTasks_ReturnsEmpty_WhenAuthRespondsWithoutChat } } - public void testDefaultConfigs_Returns_DefaultChatCompletion_V1_WhenTaskTypeIsIncorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["embed/text/sparse"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - + public void testSupportedTaskTypes_Returns_Unsupported() throws Exception { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - } - - public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() throws Exception { - String responseJson = """ - { - "models": [ - { - "model_name": "rainbow-sprinkles", - "task_types": ["chat"] - }, - { - "model_name": "elser_model_2", - "task_types": ["embed/text/sparse"] - }, - { - "model_name": "jina-embeddings-v3", - "task_types": ["embed/text/dense"] - }, - { - "model_name": "elastic-rerank-v1", - "task_types": ["rerank/text/text-similarity"] - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - try (var service = createServiceWithAuthHandler(senderFactory, getUrl(webServer))) { - ensureAuthorizationCallFinished(service); - assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertFalse(service.canStream(TaskType.ANY)); - assertThat( - service.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".elastic-rerank-v1", - MinimalServiceSettings.rerank(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".elser-2-elastic", - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), - service - ), - new InferenceService.DefaultConfigId( - ".jina-embeddings-v3", - MinimalServiceSettings.textEmbedding( - ElasticInferenceService.NAME, - ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS, - ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(), - DenseVectorFieldMapper.ElementType.FLOAT - ), - service - ), - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - service - ) - ) - ) - ); - assertThat( - service.supportedTaskTypes(), - is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING)) - ); - - PlainActionFuture> listener = new PlainActionFuture<>(); - service.defaultConfigs(listener); - var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(4)); - assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elastic-rerank-v1")); - assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".elser-2-elastic")); - assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".jina-embeddings-v3")); - assertThat(models.get(3).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + try (var service = createService(senderFactory)) { + expectThrows(UnsupportedOperationException.class, service::supportedTaskTypes); } } @@ -1392,23 +1153,11 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp } } - private void ensureAuthorizationCallFinished(ElasticInferenceService service) { - service.onNodeStarted(); - service.waitForFirstAuthorizationToComplete(TIMEOUT); - } - private ElasticInferenceService createServiceWithMockSender() { return createServiceWithMockSender(ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth()); } private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServiceAuthorizationModel auth) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(auth); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - var sender = createMockSender(); var factory = mock(HttpRequestSender.Factory.class); @@ -1417,52 +1166,19 @@ private ElasticInferenceService createServiceWithMockSender(ElasticInferenceServ factory, createWithEmptySettings(threadPool), new ElasticInferenceServiceSettings(Settings.EMPTY), - modelRegistry, - mockAuthHandler, mockClusterServiceEmpty() ); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), null); + return createService(senderFactory, null); } private ElasticInferenceService createService(HttpRequestSender.Factory senderFactory, String elasticInferenceServiceURL) { - return createService(senderFactory, ElasticInferenceServiceAuthorizationModelTests.createEnabledAuth(), elasticInferenceServiceURL); - } - - private ElasticInferenceService createService( - HttpRequestSender.Factory senderFactory, - ElasticInferenceServiceAuthorizationModel auth, - String elasticInferenceServiceURL - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(auth); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - return new ElasticInferenceService( - senderFactory, - createWithEmptySettings(threadPool), - ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), - modelRegistry, - mockAuthHandler, - mockClusterServiceEmpty() - ); - } - - private ElasticInferenceService createServiceWithAuthHandler( - HttpRequestSender.Factory senderFactory, - String elasticInferenceServiceURL - ) { return new ElasticInferenceService( senderFactory, createWithEmptySettings(threadPool), ElasticInferenceServiceSettingsTests.create(elasticInferenceServiceURL), - modelRegistry, - new ElasticInferenceServiceAuthorizationRequestHandler(elasticInferenceServiceURL, threadPool), mockClusterServiceEmpty() ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java new file mode 100644 index 0000000000000..d0d3a67b2d9d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationPollerTests.java @@ -0,0 +1,396 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.tasks.TaskManager; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.action.StoreInferenceEndpointsAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints; +import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.EnumSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AuthorizationPollerTests extends ESTestCase { + private DeterministicTaskQueue taskQueue; + + @Before + public void init() throws Exception { + taskQueue = new DeterministicTaskQueue(); + } + + public void testDoesNotSendAuthorizationRequest_WhenModelRegistryIsNotReady() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(false); + + var authorizationRequestHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + authorizationRequestHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mock(Client.class), + null + ); + + poller.sendAuthorizationRequest(); + + verify(authorizationRequestHandler, never()).getAuthorization(any(), any()); + } + + public void testSendsAuthorizationRequest_WhenModelRegistryIsReady() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + null + ); + + var requestArgCaptor = ArgumentCaptor.forClass(StoreInferenceEndpointsAction.Request.class); + + poller.sendAuthorizationRequest(); + verify(mockClient).execute(eq(StoreInferenceEndpointsAction.INSTANCE), requestArgCaptor.capture(), any()); + var capturedRequest = requestArgCaptor.getValue(); + assertThat( + capturedRequest.getModels(), + is( + List.of( + PreconfiguredEndpointModelAdapter.createModel( + InternalPreconfiguredEndpoints.getWithInferenceId(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2), + new ElasticInferenceServiceComponents("") + ) + ) + ) + ); + } + + public void testSendsAuthorizationRequest_ButDoesNotStoreAnyModels_WhenTheirInferenceIdAlreadyExists() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of(InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2, "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + null + ); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } + + public void testDoesNotAttemptToStoreModelIds_ThatDoNotExistInThePreconfiguredMapping() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // This is a model id that does not exist in the preconfigured endpoints map so it will not be stored + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + null + ); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } + + public void testDoesNotAttemptToStoreModelIds_ThatHaveATaskTypeThatTheEISIntegration_DoesNotSupport() { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID, + // EIS does not yet support completions so this model will be ignored + EnumSet.of(TaskType.COMPLETION) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + when(mockClient.threadPool()).thenReturn(taskQueue.getThreadPool()); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + null + ); + + poller.sendAuthorizationRequest(); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } + + public void testSendsTwoAuthorizationRequests() throws InterruptedException { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // this is an unknown model id so it won't trigger storing an inference endpoint because + // it doesn't map to a known one + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + + var callbackCount = new AtomicInteger(0); + var latch = new CountDownLatch(2); + final var pollerRef = new AtomicReference(); + + Runnable callback = () -> { + var count = callbackCount.incrementAndGet(); + latch.countDown(); + + // we only want to run the tasks twice, so advance the time on the queue + // which flags the scheduled authorization request to be ready to run + if (count == 1) { + taskQueue.advanceTime(); + } else { + pollerRef.get().shutdown(); + } + }; + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mockRegistry, + mockClient, + callback + ); + pollerRef.set(poller); + poller.start(); + taskQueue.runAllRunnableTasks(); + latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); + + assertThat(callbackCount.get(), is(2)); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } + + public void testCallsShutdownAndMarksTaskAsCompleted_WhenSchedulingFails() throws InterruptedException { + var mockRegistry = mock(ModelRegistry.class); + when(mockRegistry.isReady()).thenReturn(true); + when(mockRegistry.getInferenceIds()).thenReturn(Set.of("id1", "id2")); + + var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse( + ElasticInferenceServiceAuthorizationModel.of( + new ElasticInferenceServiceAuthorizationResponseEntity( + List.of( + new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( + // this is an unknown model id so it won't trigger storing an inference endpoint because + // it doesn't map to a known one + "abc", + EnumSet.of(TaskType.SPARSE_EMBEDDING) + ) + ) + ) + ) + ); + return Void.TYPE; + }).when(mockAuthHandler).getAuthorization(any(), any()); + + var mockClient = mock(Client.class); + + var callbackCount = new AtomicInteger(0); + var latch = new CountDownLatch(1); + + Runnable callback = () -> { + callbackCount.incrementAndGet(); + latch.countDown(); + }; + + var exception = new IllegalStateException("failing"); + // Simulate scheduling failure by having the settings throw an exception when queried + // Throwing an exception should cause the poller to shutdown and mark itself as completed + var settingsMock = mock(ElasticInferenceServiceSettings.class); + when(settingsMock.isPeriodicAuthorizationEnabled()).thenThrow(exception); + + var poller = new AuthorizationPoller( + new AuthorizationPoller.TaskFields(0, "abc", "abc", "abc", new TaskId("abc", 0), Map.of()), + createWithEmptySettings(taskQueue.getThreadPool()), + mockAuthHandler, + mock(Sender.class), + settingsMock, + mockRegistry, + mockClient, + callback + ); + + var persistentTaskId = "id"; + var allocationId = 0L; + + var mockPersistentTasksService = mock(PersistentTasksService.class); + poller.init(mockPersistentTasksService, mock(TaskManager.class), persistentTaskId, allocationId); + poller.start(); + taskQueue.runAllRunnableTasks(); + latch.await(TimeValue.THIRTY_SECONDS.getSeconds(), TimeUnit.SECONDS); + + assertThat(callbackCount.get(), is(1)); + assertTrue(poller.isShutdown()); + verify(mockPersistentTasksService, times(1)).sendCompletionRequest( + eq(persistentTaskId), + eq(allocationId), + eq(exception), + eq(null), + any(), + any() + ); + verify(mockClient, never()).execute(eq(StoreInferenceEndpointsAction.INSTANCE), any(), any()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java new file mode 100644 index 0000000000000..15b586d62890d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskExecutorTests.java @@ -0,0 +1,218 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.node.DiscoveryNodes; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.persistent.ClusterPersistentTasksCustomMetadata; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; +import org.junit.After; +import org.junit.Before; +import org.mockito.Mockito; + +import static org.elasticsearch.cluster.metadata.Metadata.EMPTY_METADATA; +import static org.elasticsearch.persistent.PersistentTasksExecutor.NO_NODE_FOUND; +import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityExecutors; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class AuthorizationTaskExecutorTests extends ESTestCase { + + private ThreadPool threadPool; + private ClusterService clusterService; + private PersistentTasksService persistentTasksService; + private String localNodeId; + + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool(inferenceUtilityExecutors()); + clusterService = createClusterService(threadPool); + persistentTasksService = mock(PersistentTasksService.class); + localNodeId = clusterService.localNode().getId(); + } + + @After + public void tearDown() throws Exception { + super.tearDown(); + clusterService.close(); + terminate(threadPool); + } + + public void testCreatesTask_WhenItDoesNotExistOnClusterStateChange() { + var eisUrl = "abc"; + + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener1 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener1); + listener1.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + + Mockito.clearInvocations(persistentTasksService); + // Ensure that if the task is gone, it will be recreated. + var listener2 = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener2); + listener2.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, times(1)).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(new AuthorizationTaskParams()), + any(), + any() + ); + } + + private ClusterState initialState() { + DiscoveryNodes.Builder nodes = DiscoveryNodes.builder() + .add(DiscoveryNodeUtils.create(localNodeId)) + .localNodeId(localNodeId) + .masterNodeId(localNodeId); + + return ClusterState.builder(ClusterName.DEFAULT).nodes(nodes).metadata(EMPTY_METADATA).build(); + } + + public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEisUrlIsEmpty() { + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create("", TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testDoesNotRegisterClusterStateListener_DoesNotCreateTask_WhenTheEisUrlIsNull() { + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + var listener = new PlainActionFuture(); + clusterService.getClusterApplierService().onNewClusterState("initialization", this::initialState, listener); + listener.actionGet(TimeValue.THIRTY_SECONDS); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } + + public void testDoesNotCreateTask_OnClusterStateChange_WhenItAlreadyExists() { + var initialState = initialState(); + var event = new ClusterChangedEvent( + "testClusterChanged", + ClusterState.builder(initialState) + .metadata( + Metadata.builder(initialState.metadata()) + .putCustom( + ClusterPersistentTasksCustomMetadata.TYPE, + ClusterPersistentTasksCustomMetadata.builder() + .addTask( + AuthorizationPoller.TASK_NAME, + AuthorizationPoller.TASK_NAME, + AuthorizationTaskParams.INSTANCE, + NO_NODE_FOUND + ) + .build() + ) + ) + .build(), + ClusterState.EMPTY_STATE + ); + + var eisUrl = "abc"; + var executor = new AuthorizationTaskExecutor( + clusterService, + persistentTasksService, + new AuthorizationPoller.Parameters( + createWithEmptySettings(threadPool), + mock(ElasticInferenceServiceAuthorizationRequestHandler.class), + mock(Sender.class), + ElasticInferenceServiceSettingsTests.create(eisUrl, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), + mock(ModelRegistry.class), + mock(Client.class) + ) + ); + executor.init(); + + executor.clusterChanged(event); + verify(persistentTasksService, never()).sendClusterStartRequest( + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationPoller.TASK_NAME), + eq(AuthorizationTaskParams.INSTANCE), + any(), + any() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java new file mode 100644 index 0000000000000..e57bcb6c99a49 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/AuthorizationTaskParamsTests.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +public class AuthorizationTaskParamsTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected AuthorizationTaskParams mutateInstanceForVersion(AuthorizationTaskParams instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AuthorizationTaskParams::new; + } + + @Override + protected AuthorizationTaskParams createTestInstance() { + return new AuthorizationTaskParams(); + } + + @Override + protected AuthorizationTaskParams mutateInstance(AuthorizationTaskParams instance) throws IOException { + // need to return null here because the instances will always be identical + return null; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java deleted file mode 100644 index fd7bf5c4c56c4..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ /dev/null @@ -1,283 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.elastic.authorization; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; -import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.EmptySecretSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InferenceService; -import org.elasticsearch.inference.MinimalServiceSettings; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; -import org.elasticsearch.xpack.inference.LocalStateInferencePlugin; -import org.elasticsearch.xpack.inference.Utils; -import org.elasticsearch.xpack.inference.external.http.sender.Sender; -import org.elasticsearch.xpack.inference.registry.ModelRegistry; -import org.elasticsearch.xpack.inference.services.elastic.DefaultModelConfig; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSettingsTests; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; -import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; -import org.elasticsearch.xpack.inference.services.elastic.response.ElasticInferenceServiceAuthorizationResponseEntity; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; -import org.junit.Before; - -import java.io.IOException; -import java.util.Collection; -import java.util.EnumSet; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; - -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; -import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.defaultEndpointId; -import static org.hamcrest.CoreMatchers.is; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; - -public class ElasticInferenceServiceAuthorizationHandlerTests extends ESSingleNodeTestCase { - private DeterministicTaskQueue taskQueue; - private ModelRegistry modelRegistry; - - @Override - protected Collection> getPlugins() { - return List.of(LocalStateInferencePlugin.class); - } - - @Before - public void init() throws Exception { - taskQueue = new DeterministicTaskQueue(); - modelRegistry = getInstanceFromNode(ModelRegistry.class); - } - - public void testSecondAuthResultRevokesAuthorization() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response contains a streaming task so we're expecting to support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of(new ElasticInferenceServiceAuthorizationResponseEntity(List.of())) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - - // this should be after we've received both authorization responses, the second response will revoke authorization - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - assertThat(handler.defaultConfigIds(), is(List.of())); - assertThat(handler.supportedTaskTypes(), is(EnumSet.noneOf(TaskType.class))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.size(), is(0)); - } - - public void testSendsAnAuthorizationRequestTwice() throws Exception { - var callbackCount = new AtomicInteger(0); - // we're only interested in two authorization calls which is why I'm using a value of 2 here - var latch = new CountDownLatch(2); - final AtomicReference handlerRef = new AtomicReference<>(); - - Runnable callback = () -> { - // the first authorization response does not contain a streaming task so we're expecting to not support streaming here - if (callbackCount.incrementAndGet() == 1) { - assertThat(handlerRef.get().supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class))); - } - latch.countDown(); - - // we only want to run the tasks twice, so advance the time on the queue - // which flags the scheduled authorization request to be ready to run - if (callbackCount.get() == 1) { - taskQueue.advanceTime(); - } else { - try { - handlerRef.get().close(); - } catch (IOException e) { - // ignore - } - } - }; - - var requestHandler = mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel("abc", EnumSet.of(TaskType.SPARSE_EMBEDDING)) - ) - ) - ), - ElasticInferenceServiceAuthorizationModel.of( - new ElasticInferenceServiceAuthorizationResponseEntity( - List.of( - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "abc", - EnumSet.of(TaskType.SPARSE_EMBEDDING) - ), - new ElasticInferenceServiceAuthorizationResponseEntity.AuthorizedModel( - "rainbow-sprinkles", - EnumSet.of(TaskType.CHAT_COMPLETION) - ) - ) - ) - ) - ); - - handlerRef.set( - new ElasticInferenceServiceAuthorizationHandler( - createWithEmptySettings(taskQueue.getThreadPool()), - modelRegistry, - requestHandler, - initDefaultEndpoints(), - EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION), - null, - mock(Sender.class), - ElasticInferenceServiceSettingsTests.create(null, TimeValue.timeValueMillis(1), TimeValue.timeValueMillis(1), true), - callback - ) - ); - - var handler = handlerRef.get(); - handler.init(); - taskQueue.runAllRunnableTasks(); - latch.await(Utils.TIMEOUT.getSeconds(), TimeUnit.SECONDS); - // this should be after we've received both authorization responses - - assertThat(handler.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION))); - assertThat( - handler.defaultConfigIds(), - is( - List.of( - new InferenceService.DefaultConfigId( - ".rainbow-sprinkles-elastic", - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), - null - ) - ) - ) - ); - assertThat(handler.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.CHAT_COMPLETION))); - - PlainActionFuture> listener = new PlainActionFuture<>(); - handler.defaultConfigs(listener); - - var configs = listener.actionGet(); - assertThat(configs.get(0).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); - } - - private static ElasticInferenceServiceAuthorizationRequestHandler mockAuthorizationRequestHandler( - ElasticInferenceServiceAuthorizationModel firstAuthResponse, - ElasticInferenceServiceAuthorizationModel secondAuthResponse - ) { - var mockAuthHandler = mock(ElasticInferenceServiceAuthorizationRequestHandler.class); - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(firstAuthResponse); - return Void.TYPE; - }).doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(0); - listener.onResponse(secondAuthResponse); - return Void.TYPE; - }).when(mockAuthHandler).getAuthorization(any(), any()); - - return mockAuthHandler; - } - - private static Map initDefaultEndpoints() { - return Map.of( - "rainbow-sprinkles", - new DefaultModelConfig( - new ElasticInferenceServiceCompletionModel( - defaultEndpointId("rainbow-sprinkles"), - TaskType.CHAT_COMPLETION, - "test", - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE - ), - MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME) - ), - "elser-2", - new DefaultModelConfig( - new ElasticInferenceServiceSparseEmbeddingsModel( - defaultEndpointId("elser-2"), - TaskType.SPARSE_EMBEDDING, - "test", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null), - EmptyTaskSettings.INSTANCE, - EmptySecretSettings.INSTANCE, - ElasticInferenceServiceComponents.EMPTY_INSTANCE, - ChunkingSettingsBuilder.DEFAULT_SETTINGS - ), - MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME) - ) - ); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java new file mode 100644 index 0000000000000..e718c83c3f965 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/PreconfiguredEndpointModelAdapterTests.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.elastic.authorization; + +import org.elasticsearch.inference.EmptySecretSettings; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.chunking.ChunkingSettingsBuilder; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; +import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceModel; +import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.densetextembeddings.ElasticInferenceServiceDenseTextEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_CHAT_COMPLETION_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_2_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_ELSER_ENDPOINT_ID_V2; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_MULTILINGUAL_EMBED_MODEL_ID; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_ENDPOINT_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DEFAULT_RERANK_MODEL_ID_V1; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.DENSE_TEXT_EMBEDDINGS_DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.elastic.InternalPreconfiguredEndpoints.defaultDenseTextEmbeddingsSimilarity; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; + +public class PreconfiguredEndpointModelAdapterTests extends ESTestCase { + + private static final ElasticInferenceServiceSparseEmbeddingsServiceSettings SPARSE_SETTINGS = + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null); + private static final ElasticInferenceServiceCompletionServiceSettings COMPLETION_SETTINGS = + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1); + private static final ElasticInferenceServiceDenseTextEmbeddingsServiceSettings DENSE_SETTINGS = + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( + DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, + defaultDenseTextEmbeddingsSimilarity(), + DENSE_TEXT_EMBEDDINGS_DIMENSIONS, + null + ); + private static final ElasticInferenceServiceRerankServiceSettings RERANK_SETTINGS = new ElasticInferenceServiceRerankServiceSettings( + DEFAULT_RERANK_MODEL_ID_V1 + ); + private static final ElasticInferenceServiceComponents EIS_COMPONENTS = new ElasticInferenceServiceComponents(""); + + public void testGetModelsWithValidId() { + var endpointIds = Set.of( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + DEFAULT_ELSER_ENDPOINT_ID_V2, + DEFAULT_RERANK_ENDPOINT_ID_V1, + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID + ); + var models = PreconfiguredEndpointModelAdapter.getModels(endpointIds, EIS_COMPONENTS); + + assertThat(models, hasSize(endpointIds.size())); + assertThat( + models, + containsInAnyOrder( + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + SPARSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + COMPLETION_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_MULTILINGUAL_EMBED_ENDPOINT_ID, + TaskType.TEXT_EMBEDDING, + ElasticInferenceService.NAME, + DENSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + DENSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + ElasticInferenceService.NAME, + RERANK_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + RERANK_SETTINGS, + EIS_COMPONENTS + ) + ) + ); + } + + public void testGetModelsWithValidAndInvalidIds() { + var models = PreconfiguredEndpointModelAdapter.getModels( + Set.of(DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, "some-invalid-id", DEFAULT_ELSER_ENDPOINT_ID_V2), + EIS_COMPONENTS + ); + + assertThat(models, hasSize(2)); + assertThat( + models, + containsInAnyOrder( + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_ELSER_ENDPOINT_ID_V2, + TaskType.SPARSE_EMBEDDING, + ElasticInferenceService.NAME, + SPARSE_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + SPARSE_SETTINGS, + EIS_COMPONENTS + ), + new ElasticInferenceServiceModel( + new ModelConfigurations( + DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, + TaskType.CHAT_COMPLETION, + ElasticInferenceService.NAME, + COMPLETION_SETTINGS, + ChunkingSettingsBuilder.DEFAULT_SETTINGS + ), + new ModelSecrets(EmptySecretSettings.INSTANCE), + COMPLETION_SETTINGS, + EIS_COMPONENTS + ) + ) + ); + } + + public void testGetModelsWithOnlyInvalidId() { + assertThat(PreconfiguredEndpointModelAdapter.getModels(Collections.singleton("nonexistent-id"), EIS_COMPONENTS), is(List.of())); + } +} diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 78f2b6523fc78..0957ec55e882a 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -174,7 +174,6 @@ public class Constants { "cluster:admin/xpack/enrich/get", "cluster:admin/xpack/enrich/put", "cluster:admin/xpack/enrich/reindex", - "cluster:internal/xpack/inference/clear_inference_endpoint_cache", "cluster:admin/xpack/inference/ccm/delete", "cluster:admin/xpack/inference/ccm/put", "cluster:admin/xpack/inference/delete", @@ -329,6 +328,8 @@ public class Constants { "cluster:admin/xpack/watcher/watch/put", "cluster:internal/remote_cluster/nodes", "cluster:internal/xpack/inference", + "cluster:internal/xpack/inference/clear_inference_endpoint_cache", + "cluster:internal/xpack/inference/create_endpoints", "cluster:internal/xpack/inference/rerankwindowsize/get", "cluster:internal/xpack/inference/unified", "cluster:internal/xpack/ml/auditor/reset",