diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5ffc32339d8cf..2d6c47eefffaf 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -358,6 +358,7 @@ static TransportVersion def(int id) { public static final TransportVersion RESOLVE_INDEX_MODE_FILTER = def(9_149_0_00); public static final TransportVersion SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS = def(9_150_0_00); public static final TransportVersion ESQL_LOOKUP_JOIN_PRE_JOIN_FILTER = def(9_151_0_00); + public static final TransportVersion INFERENCE_API_DISABLE_EIS_RATE_LIMITING = def(9_152_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index e3fff14bf95d7..55b59e3fd1d9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -449,9 +449,11 @@ public synchronized TimeValue executeEnqueuedTask() { } private TimeValue executeEnqueuedTaskInternal() { - var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); - if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { - return timeBeforeAvailableToken; + if (rateLimitSettings.isEnabled()) { + var timeBeforeAvailableToken = rateLimiter.timeToReserve(1); + if (shouldExecuteImmediately(timeBeforeAvailableToken) == false) { + return timeBeforeAvailableToken; + } } var task = queue.poll(); @@ -463,9 +465,11 @@ private TimeValue executeEnqueuedTaskInternal() { return NO_TASKS_AVAILABLE; } - // We should never have to wait because we checked above - var reserveRes = rateLimiter.reserve(1); - assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; + if (rateLimitSettings.isEnabled()) { + // We should never have to wait because we checked above + var reserveRes = rateLimiter.reserve(1); + assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; + } task.getRequestManager() .execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 62b01c779db33..69ae769e36dc4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -193,7 +193,7 @@ private static Map initDefaultEndpoints( DEFAULT_CHAT_COMPLETION_ENDPOINT_ID_V1, TaskType.CHAT_COMPLETION, NAME, - new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1, null), + new ElasticInferenceServiceCompletionServiceSettings(DEFAULT_CHAT_COMPLETION_MODEL_ID_V1), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents @@ -206,7 +206,7 @@ private static Map initDefaultEndpoints( DEFAULT_ELSER_ENDPOINT_ID_V2, TaskType.SPARSE_EMBEDDING, NAME, - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null, null), + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(DEFAULT_ELSER_2_MODEL_ID, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents, @@ -224,8 +224,7 @@ private static Map initDefaultEndpoints( DEFAULT_MULTILINGUAL_EMBED_MODEL_ID, defaultDenseTextEmbeddingsSimilarity(), null, - null, - ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS + null ), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, @@ -245,7 +244,7 @@ private static Map initDefaultEndpoints( DEFAULT_RERANK_ENDPOINT_ID_V1, TaskType.RERANK, NAME, - new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null), + new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, elasticInferenceServiceComponents @@ -622,8 +621,7 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { modelId, similarityToUse, embeddingSize, - maxInputTokens, - serviceSettings.rateLimitSettings() + maxInputTokens ); return new ElasticInferenceServiceDenseTextEmbeddingsModel(embeddingsModel, updateServiceSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java index 5125ade21339d..969bf06d47fe0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModel.java @@ -35,8 +35,7 @@ public static ElasticInferenceServiceCompletionModel of( ) { var originalModelServiceSettings = model.getServiceSettings(); var overriddenServiceSettings = new ElasticInferenceServiceCompletionServiceSettings( - Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), - originalModelServiceSettings.rateLimitSettings() + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()) ); return new ElasticInferenceServiceCompletionModel(model, overriddenServiceSettings); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java index f062a57d03f82..58da188fa2bb5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettings.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -35,38 +36,41 @@ public class ElasticInferenceServiceCompletionServiceSettings extends FilteredXC public static final String NAME = "elastic_inference_service_completion_service_settings"; - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(720L); - public static ElasticInferenceServiceCompletionServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + + RateLimitSettings.rejectRateLimitFieldForRequestContext( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, + ModelConfigurations.SERVICE_SETTINGS, ElasticInferenceService.NAME, - context + TaskType.CHAT_COMPLETION, + context, + validationException ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ElasticInferenceServiceCompletionServiceSettings(modelId, rateLimitSettings); + return new ElasticInferenceServiceCompletionServiceSettings(modelId); } private final String modelId; private final RateLimitSettings rateLimitSettings; - public ElasticInferenceServiceCompletionServiceSettings(String modelId, RateLimitSettings rateLimitSettings) { + public ElasticInferenceServiceCompletionServiceSettings(String modelId) { this.modelId = Objects.requireNonNull(modelId); - this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; } public ElasticInferenceServiceCompletionServiceSettings(StreamInput in) throws IOException { this.modelId = in.readString(); - this.rateLimitSettings = new RateLimitSettings(in); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; + if (in.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + new RateLimitSettings(in); + } } @Override @@ -110,7 +114,9 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); - rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + rateLimitSettings.writeTo(out); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java index e8eeee5a34dd4..15de05e004490 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.java @@ -17,6 +17,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -43,8 +44,6 @@ public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends F public static final String NAME = "elastic_inference_service_dense_embeddings_service_settings"; - public static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000); - private final String modelId; private final SimilarityMeasure similarity; private final Integer dimensions; @@ -54,77 +53,41 @@ public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettings extends F public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromMap( Map map, ConfigurationParseContext context - ) { - return switch (context) { - case REQUEST -> fromRequestMap(map, context); - case PERSISTENT -> fromPersistentMap(map, context); - }; - } - - private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromRequestMap( - Map map, - ConfigurationParseContext context ) { ValidationException validationException = new ValidationException(); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( - map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, - ElasticInferenceService.NAME, - context - ); - SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); Integer dims = removeAsType(map, DIMENSIONS, Integer.class); Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); - } - - private static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings fromPersistentMap( - Map map, - ConfigurationParseContext context - ) { - ValidationException validationException = new ValidationException(); - - String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + RateLimitSettings.rejectRateLimitFieldForRequestContext( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, + ModelConfigurations.SERVICE_SETTINGS, ElasticInferenceService.NAME, - context + TaskType.TEXT_EMBEDDING, + context, + validationException ); - SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); - Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens, rateLimitSettings); + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dims, maxInputTokens); } public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( String modelId, @Nullable SimilarityMeasure similarity, @Nullable Integer dimensions, - @Nullable Integer maxInputTokens, - RateLimitSettings rateLimitSettings + @Nullable Integer maxInputTokens ) { this.modelId = modelId; this.similarity = similarity; this.dimensions = dimensions; this.maxInputTokens = maxInputTokens; - this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; } public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) throws IOException { @@ -132,7 +95,11 @@ public ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(StreamInput in) this.similarity = in.readOptionalEnum(SimilarityMeasure.class); this.dimensions = in.readOptionalVInt(); this.maxInputTokens = in.readOptionalVInt(); - this.rateLimitSettings = new RateLimitSettings(in); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; + + if (in.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + new RateLimitSettings(in); + } } @Override @@ -221,7 +188,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); out.writeOptionalVInt(dimensions); out.writeOptionalVInt(maxInputTokens); - rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + rateLimitSettings.writeTo(out); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java index eff22c2771930..e6193487f57b2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettings.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -35,35 +36,42 @@ public class ElasticInferenceServiceRerankServiceSettings extends FilteredXConte public static final String NAME = "elastic_rerank_service_settings"; - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(500); - public static ElasticInferenceServiceRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + + RateLimitSettings.rejectRateLimitFieldForRequestContext( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, + ModelConfigurations.SERVICE_SETTINGS, ElasticInferenceService.NAME, - context + TaskType.RERANK, + context, + validationException ); - return new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimitSettings); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new ElasticInferenceServiceRerankServiceSettings(modelId); } private final String modelId; private final RateLimitSettings rateLimitSettings; - public ElasticInferenceServiceRerankServiceSettings(String modelId, RateLimitSettings rateLimitSettings) { + public ElasticInferenceServiceRerankServiceSettings(String modelId) { this.modelId = Objects.requireNonNull(modelId); - this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; } public ElasticInferenceServiceRerankServiceSettings(StreamInput in) throws IOException { this.modelId = in.readString(); - this.rateLimitSettings = new RateLimitSettings(in); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; + if (in.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + new RateLimitSettings(in); + } } @Override @@ -115,7 +123,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); - rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + rateLimitSettings.writeTo(out); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java index 10395c430969b..831e3822e1764 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/sparseembeddings/ElasticInferenceServiceSparseEmbeddingsServiceSettings.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService; @@ -38,8 +39,6 @@ public class ElasticInferenceServiceSparseEmbeddingsServiceSettings extends Filt public static final String NAME = "elastic_inference_service_sparse_embeddings_service_settings"; - private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(1_000); - public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap( Map map, ConfigurationParseContext context @@ -54,19 +53,20 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap( validationException ); - RateLimitSettings rateLimitSettings = RateLimitSettings.of( + RateLimitSettings.rejectRateLimitFieldForRequestContext( map, - DEFAULT_RATE_LIMIT_SETTINGS, - validationException, + ModelConfigurations.SERVICE_SETTINGS, ElasticInferenceService.NAME, - context + TaskType.SPARSE_EMBEDDING, + context, + validationException ); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, rateLimitSettings); + return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens); } private final String modelId; @@ -74,20 +74,19 @@ public static ElasticInferenceServiceSparseEmbeddingsServiceSettings fromMap( private final Integer maxInputTokens; private final RateLimitSettings rateLimitSettings; - public ElasticInferenceServiceSparseEmbeddingsServiceSettings( - String modelId, - @Nullable Integer maxInputTokens, - @Nullable RateLimitSettings rateLimitSettings - ) { + public ElasticInferenceServiceSparseEmbeddingsServiceSettings(String modelId, @Nullable Integer maxInputTokens) { this.modelId = Objects.requireNonNull(modelId); this.maxInputTokens = maxInputTokens; - this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; } public ElasticInferenceServiceSparseEmbeddingsServiceSettings(StreamInput in) throws IOException { this.modelId = in.readString(); this.maxInputTokens = in.readOptionalVInt(); - this.rateLimitSettings = new RateLimitSettings(in); + this.rateLimitSettings = RateLimitSettings.DISABLED_INSTANCE; + if (in.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + new RateLimitSettings(in); + } } @Override @@ -139,7 +138,9 @@ protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder buil public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeOptionalVInt(maxInputTokens); - rateLimitSettings.writeTo(out); + if (out.getTransportVersion().before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + rateLimitSettings.writeTo(out); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java index bc7e555120286..e7523ad79dcb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettings.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.inference.services.settings; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Strings; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; @@ -32,9 +34,7 @@ public class RateLimitSettings implements Writeable, ToXContentFragment { public static final String FIELD_NAME = "rate_limit"; public static final String REQUESTS_PER_MINUTE_FIELD = "requests_per_minute"; - - private final long requestsPerTimeUnit; - private final TimeUnit timeUnit; + public static final RateLimitSettings DISABLED_INSTANCE = new RateLimitSettings(1, TimeUnit.MINUTES, false); public static RateLimitSettings of( Map map, @@ -53,6 +53,26 @@ public static RateLimitSettings of( return requestsPerMinute == null ? defaultValue : new RateLimitSettings(requestsPerMinute); } + public static void rejectRateLimitFieldForRequestContext( + Map map, + String scope, + String service, + TaskType taskType, + ConfigurationParseContext context, + ValidationException validationException + ) { + if (ConfigurationParseContext.isRequestContext(context) && map.containsKey(FIELD_NAME)) { + validationException.addValidationError( + Strings.format( + "[%s] rate limit settings are not permitted for service [%s] and task type [%s]", + scope, + service, + taskType.toString() + ) + ); + } + } + public static Map toSettingsConfigurationWithDescription( String description, EnumSet supportedTaskTypes @@ -75,6 +95,10 @@ public static Map toSettingsConfiguration(EnumSet return RateLimitSettings.toSettingsConfigurationWithDescription("Minimize the number of rate limit errors.", supportedTaskTypes); } + private final long requestsPerTimeUnit; + private final TimeUnit timeUnit; + private final boolean enabled; + /** * Defines the settings in requests per minute * @param requestsPerMinute _ @@ -84,6 +108,8 @@ public RateLimitSettings(long requestsPerMinute) { } /** + * This should only be used for testing. + * * Defines the settings in requests per the time unit provided * @param requestsPerTimeUnit number of requests * @param timeUnit _ @@ -91,16 +117,27 @@ public RateLimitSettings(long requestsPerMinute) { * Note: The time unit is not serialized */ public RateLimitSettings(long requestsPerTimeUnit, TimeUnit timeUnit) { + this(requestsPerTimeUnit, timeUnit, true); + } + + // This should only be used for testing. + RateLimitSettings(long requestsPerTimeUnit, TimeUnit timeUnit, boolean enabled) { if (requestsPerTimeUnit <= 0) { throw new IllegalArgumentException("requests per minute must be positive"); } this.requestsPerTimeUnit = requestsPerTimeUnit; this.timeUnit = Objects.requireNonNull(timeUnit); + this.enabled = enabled; } public RateLimitSettings(StreamInput in) throws IOException { requestsPerTimeUnit = in.readVLong(); timeUnit = TimeUnit.MINUTES; + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + enabled = in.readBoolean(); + } else { + enabled = true; + } } public long requestsPerTimeUnit() { @@ -111,8 +148,16 @@ public TimeUnit timeUnit() { return timeUnit; } + public boolean isEnabled() { + return enabled; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + if (enabled == false) { + return builder; + } + builder.startObject(FIELD_NAME); builder.field(REQUESTS_PER_MINUTE_FIELD, requestsPerTimeUnit); builder.endObject(); @@ -122,6 +167,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public void writeTo(StreamOutput out) throws IOException { out.writeVLong(requestsPerTimeUnit); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + out.writeBoolean(enabled); + } } @Override @@ -129,11 +177,13 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; RateLimitSettings that = (RateLimitSettings) o; - return Objects.equals(requestsPerTimeUnit, that.requestsPerTimeUnit) && Objects.equals(timeUnit, that.timeUnit); + return Objects.equals(requestsPerTimeUnit, that.requestsPerTimeUnit) + && Objects.equals(timeUnit, that.timeUnit) + && enabled == that.enabled; } @Override public int hashCode() { - return Objects.hash(requestsPerTimeUnit, timeUnit); + return Objects.hash(requestsPerTimeUnit, timeUnit, enabled); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index 500a375063bd6..163c4b84f1780 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.inference.common.RateLimiter; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.RetryingHttpSender; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.junit.After; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -52,6 +53,7 @@ import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -569,6 +571,46 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() { verifyNoInteractions(requestSender); } + public void testDoesNotAttemptToReserveTokens_WhenRateLimitSettingsDisabled() throws ExecutionException, InterruptedException, + TimeoutException { + var mockRateLimiter = mock(RateLimiter.class); + RequestExecutorService.RateLimiterCreator rateLimiterCreator = (a, b, c) -> mockRateLimiter; + + var requestSender = mock(RetryingHttpSender.class); + var settings = createRequestExecutorServiceSettings(1); + var service = new RequestExecutorService( + threadPool, + RequestExecutorService.DEFAULT_QUEUE_CREATOR, + null, + settings, + requestSender, + Clock.systemUTC(), + rateLimiterCreator + ); + var requestManager = RequestManagerTests.createMock(requestSender, "id", RateLimitSettings.DISABLED_INSTANCE); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.execute(requestManager, new EmbeddingsInput(List.of(), null), null, listener); + + var waitToShutdown = new CountDownLatch(1); + var waitToReturnFromSend = new CountDownLatch(1); + // There is a request already queued, and its execution path will initiate shutting down the service + doAnswer(invocation -> { + waitToShutdown.countDown(); + waitToReturnFromSend.await(TIMEOUT.getSeconds(), TimeUnit.SECONDS); + return Void.TYPE; + }).when(requestSender).send(any(), any(), any(), any(), any()); + + Future executorTermination = submitShutdownRequest(waitToShutdown, waitToReturnFromSend, service); + + service.start(); + executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS); + + // The request manager that we create has RateLimitSettings.DISABLED_INSTANCE, so we should never call the rate limiter + verify(mockRateLimiter, never()).timeToReserve(anyInt()); + verify(mockRateLimiter, never()).reserve(anyInt()); + } + public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_AndExecutesTask() { var mockRateLimiter = mock(RateLimiter.class); when(mockRateLimiter.reserve(anyInt())).thenReturn(TimeValue.timeValueDays(0)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java index 9c95fbfdfa996..22f0121e84514 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsModelTests.java @@ -26,7 +26,7 @@ public static ElasticInferenceServiceSparseEmbeddingsModel createModel(String ur "id", TaskType.SPARSE_EMBEDDING, "service", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null), + new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of(url), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java index 8b6f872b6ccba..fc7563f923135 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests.java @@ -7,25 +7,31 @@ package org.elasticsearch.xpack.inference.services.elastic; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.elasticsearch.ElserModels; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; import java.util.HashMap; import java.util.Map; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElserModelsTests.randomElserModel; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; -public class ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< +public class ElasticInferenceServiceSparseEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< ElasticInferenceServiceSparseEmbeddingsServiceSettings> { @Override @@ -53,23 +59,82 @@ public void testFromMap() { ConfigurationParseContext.REQUEST ); - assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null, null))); + assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null))); + } + + public void testFromMap_DoesNotRemoveRateLimitField_DoesNotThrowValidationException_PersistentContext() { + var modelId = "my-model-id"; + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + var serviceSettings = ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.PERSISTENT); + + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); + assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_DoesNotRemoveRateLimitField_DoesNotThrowValidationException_WhenRateLimitFieldDoesNotExist() { + var modelId = "my-model-id"; + var map = new HashMap(Map.of(ServiceFields.MODEL_ID, modelId)); + var serviceSettings = ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.PERSISTENT); + + assertThat(map, anEmptyMap()); + assertThat(serviceSettings, is(new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, null))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_DoesThrowValidationException_WhenRateLimitFieldDoesExist_RequestContext() { + var modelId = "my-model-id"; + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + var exception = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceSparseEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + + assertThat( + exception.getMessage(), + containsString( + "[service_settings] rate limit settings are not permitted for service [elastic] and task type [sparse_embedding]" + ) + ); + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); } public void testToXContent_WritesAllFields() throws IOException { var modelId = ElserModels.ELSER_V1_MODEL; var maxInputTokens = 10; - var serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens, null); + var serviceSettings = new ElasticInferenceServiceSparseEmbeddingsServiceSettings(modelId, maxInputTokens); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","max_input_tokens":%d,"rate_limit":{"requests_per_minute":1000}}""", modelId, maxInputTokens))); + {"model_id":"%s","max_input_tokens":%d}""", modelId, maxInputTokens))); } public static ElasticInferenceServiceSparseEmbeddingsServiceSettings createRandom() { - return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(randomElserModel(), randomNonNegativeInt(), null); + return new ElasticInferenceServiceSparseEmbeddingsServiceSettings(randomElserModel(), randomNonNegativeInt()); + } + + @Override + protected ElasticInferenceServiceSparseEmbeddingsServiceSettings mutateInstanceForVersion( + ElasticInferenceServiceSparseEmbeddingsServiceSettings instance, + TransportVersion version + ) { + return instance; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index cb130033df004..d660f395250dd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -198,6 +198,28 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa } } + public void testParseRequestConfig_ThrowsWhenRateLimitFieldExistsInServiceSettingsMap() throws IOException { + try (var service = createServiceWithMockSender()) { + Map serviceSettings = new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + ElserModels.ELSER_V2_MODEL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + + var config = getRequestConfigMap(serviceSettings, Map.of(), Map.of()); + + var failureListener = getModelListenerForException( + ValidationException.class, + "Validation Failed: 1: [service_settings] rate limit settings are not permitted for " + + "service [elastic] and task type [sparse_embedding];" + ); + service.parseRequestConfig("id", TaskType.SPARSE_EMBEDDING, config, failureListener); + } + } + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { try (var service = createServiceWithMockSender()) { var taskSettings = Map.of("extra_key", (Object) "value"); @@ -298,6 +320,39 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists } } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenRateLimitFieldExistsInServiceSettings() throws IOException { + try (var service = createServiceWithMockSender()) { + Map serviceSettingsMap = new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + ElserModels.ELSER_V2_MODEL, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, Map.of(), Map.of()); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(ElasticInferenceServiceSparseEmbeddingsModel.class)); + + var parsedModel = (ElasticInferenceServiceSparseEmbeddingsModel) model; + assertThat(parsedModel.getServiceSettings().modelId(), is(ElserModels.ELSER_V2_MODEL)); + assertThat(parsedModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(parsedModel.getSecretSettings(), is(EmptySecretSettings.INSTANCE)); + assertThat( + serviceSettingsMap, + is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100))) + ); + } + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { try (var service = createServiceWithMockSender()) { var taskSettings = Map.of("extra_key", (Object) "value"); @@ -687,7 +742,7 @@ public void testUnifiedCompletionInfer_PropagatesProductUseCaseHeader() throws I "id", TaskType.CHAT_COMPLETION, "elastic", - new ElasticInferenceServiceCompletionServiceSettings("my-model-id", new RateLimitSettings(100)), + new ElasticInferenceServiceCompletionServiceSettings("my-model-id"), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of(elasticInferenceServiceURL) @@ -1382,7 +1437,7 @@ private InferenceEventsAssertion testUnifiedStream(int responseCode, String resp "id", TaskType.COMPLETION, "elastic", - new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), + new ElasticInferenceServiceCompletionServiceSettings("model_id"), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of(elasticInferenceServiceURL) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java index 617ddef5a9910..e42430b6512f5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/authorization/ElasticInferenceServiceAuthorizationHandlerTests.java @@ -257,7 +257,7 @@ private static Map initDefaultEndpoints() { defaultEndpointId("rainbow-sprinkles"), TaskType.CHAT_COMPLETION, "test", - new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles", null), + new ElasticInferenceServiceCompletionServiceSettings("rainbow-sprinkles"), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.EMPTY_INSTANCE @@ -270,7 +270,7 @@ private static Map initDefaultEndpoints() { defaultEndpointId("elser-2"), TaskType.SPARSE_EMBEDDING, "test", - new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null, null), + new ElasticInferenceServiceSparseEmbeddingsServiceSettings("elser-2", null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.EMPTY_INSTANCE, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java index 51945776b4f9e..58750e7d8c456 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionModelTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.util.List; @@ -26,7 +25,7 @@ public void testOverridingModelId() { "id", TaskType.COMPLETION, "elastic", - new ElasticInferenceServiceCompletionServiceSettings("model_id", new RateLimitSettings(100)), + new ElasticInferenceServiceCompletionServiceSettings("model_id"), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of("url") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java index c530ff5c03482..91ad2d846f17b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/completion/ElasticInferenceServiceCompletionServiceSettingsTests.java @@ -7,17 +7,18 @@ package org.elasticsearch.xpack.inference.services.elastic.completion; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import java.io.IOException; import java.util.HashMap; @@ -25,8 +26,9 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; -public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractWireSerializingTestCase< +public class ElasticInferenceServiceCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< ElasticInferenceServiceCompletionServiceSettings> { @Override @@ -53,7 +55,49 @@ public void testFromMap() { ConfigurationParseContext.REQUEST ); - assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(720L)))); + assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_ThrowsValidationError_IfRateLimitFieldExists_ForRequestContext() { + var modelId = "my-model-id"; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + var exception = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceCompletionServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + + assertThat( + exception.getMessage(), + containsString("[service_settings] rate limit settings are not permitted for service [elastic] and task type [chat_completion]") + ); + } + + public void testFromMap_DoesNotThrowValidationError_IfRateLimitFieldExists_ForPersistentContext() { + var modelId = "my-model-id"; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + + var serviceSettings = ElasticInferenceServiceCompletionServiceSettings.fromMap(map, ConfigurationParseContext.PERSISTENT); + + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); + assertThat(serviceSettings, is(new ElasticInferenceServiceCompletionServiceSettings(modelId))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); } public void testFromMap_MissingModelId_ThrowsException() { @@ -67,17 +111,27 @@ public void testFromMap_MissingModelId_ThrowsException() { public void testToXContent_WritesAllFields() throws IOException { var modelId = "model_id"; - var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId, new RateLimitSettings(1000)); + var serviceSettings = new ElasticInferenceServiceCompletionServiceSettings(modelId); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","rate_limit":{"requests_per_minute":1000}}""", modelId))); + assertThat(xContentResult, is(XContentHelper.stripWhitespace(Strings.format(""" + { + "model_id":"%s" + }""", modelId)))); } public static ElasticInferenceServiceCompletionServiceSettings createRandom() { - return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4), RateLimitSettingsTests.createRandom()); + return new ElasticInferenceServiceCompletionServiceSettings(randomAlphaOfLength(4)); + } + + @Override + protected ElasticInferenceServiceCompletionServiceSettings mutateInstanceForVersion( + ElasticInferenceServiceCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java index fe0e4efc85a5b..be7f056754981 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsModelTests.java @@ -13,7 +13,6 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceComponents; -import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; public class ElasticInferenceServiceDenseTextEmbeddingsModelTests { @@ -22,13 +21,7 @@ public static ElasticInferenceServiceDenseTextEmbeddingsModel createModel(String "id", TaskType.TEXT_EMBEDDING, "elastic", - new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - modelId, - SimilarityMeasure.COSINE, - null, - null, - new RateLimitSettings(1000L) - ), + new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, SimilarityMeasure.COSINE, null, null), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of(url), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java index a9263d5624dca..61657b75869e9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/densetextembeddings/ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.services.elastic.densetextembeddings; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -22,9 +25,12 @@ import java.util.HashMap; import java.util.Map; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; -public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase< +public class ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< ElasticInferenceServiceDenseTextEmbeddingsServiceSettings> { @Override @@ -72,48 +78,127 @@ public void testFromMap_Request_WithAllSettings() { assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); } + public void testFromMap_WithAllSettings_DoesNotRemoveRateLimitField_DoesNotThrowValidationException_PersistentContext() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.COSINE; + var dimensions = 384; + var maxInputTokens = 512; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.PERSISTENT); + + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); + assertThat(serviceSettings.modelId(), is(modelId)); + assertThat(serviceSettings.similarity(), is(similarity)); + assertThat(serviceSettings.dimensions(), is(dimensions)); + assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_WithAllSettings_DoesNotRemoveRateLimitField_ThrowsValidationException_RequestContext() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.COSINE; + var dimensions = 384; + var maxInputTokens = 512; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + var exception = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); + assertThat( + exception.getMessage(), + containsString("[service_settings] rate limit settings are not permitted for service [elastic] and task type [text_embedding]") + ); + } + + public void testFromMap_WithAllSettings_DoesNotThrowValidationException_WhenRateLimitFieldDoesNotExist_RequestContext() { + var modelId = "my-dense-model-id"; + var similarity = SimilarityMeasure.COSINE; + var dimensions = 384; + var maxInputTokens = 512; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + ServiceFields.SIMILARITY, + similarity.toString(), + ServiceFields.DIMENSIONS, + dimensions, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens + ) + ); + var serviceSettings = ElasticInferenceServiceDenseTextEmbeddingsServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST); + + assertThat(map, anEmptyMap()); + assertThat(serviceSettings.modelId(), is(modelId)); + assertThat(serviceSettings.similarity(), is(similarity)); + assertThat(serviceSettings.dimensions(), is(dimensions)); + assertThat(serviceSettings.maxInputTokens(), is(maxInputTokens)); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + public void testToXContent_WritesAllFields() throws IOException { var modelId = "my-dense-model"; var similarity = SimilarityMeasure.DOT_PRODUCT; var dimensions = 1024; var maxInputTokens = 256; - var rateLimitSettings = new RateLimitSettings(5000); var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( modelId, similarity, dimensions, - maxInputTokens, - rateLimitSettings + maxInputTokens ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - String expectedResult = Strings.format( - """ - {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", - similarity, - dimensions, - maxInputTokens, - modelId, - rateLimitSettings.requestsPerTimeUnit() - ); + String expectedResult = Strings.format(""" + {"similarity":"%s","dimensions":%d,"max_input_tokens":%d,"model_id":"%s"}""", similarity, dimensions, maxInputTokens, modelId); assertThat(xContentResult, is(expectedResult)); } public void testToXContent_WritesOnlyNonNullFields() throws IOException { var modelId = "my-dense-model"; - var rateLimitSettings = new RateLimitSettings(2000); var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( modelId, null, // similarity null, // dimensions - null, // maxInputTokens - rateLimitSettings + null // maxInputTokens ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); @@ -121,20 +206,13 @@ public void testToXContent_WritesOnlyNonNullFields() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + {"model_id":"%s"}""", modelId))); } public void testToXContentFragmentOfExposedFields() throws IOException { var modelId = "my-dense-model"; - var rateLimitSettings = new RateLimitSettings(1500); - var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - modelId, - SimilarityMeasure.COSINE, - 512, - 128, - rateLimitSettings - ); + var serviceSettings = new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, SimilarityMeasure.COSINE, 512, 128); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); builder.startObject(); @@ -143,8 +221,8 @@ public void testToXContentFragmentOfExposedFields() throws IOException { String xContentResult = Strings.toString(builder); // Only model_id and rate_limit should be in exposed fields - assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + assertThat(xContentResult, is(XContentHelper.stripWhitespace(Strings.format(""" + {"model_id":"%s"}""", modelId)))); } public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRandom() { @@ -152,14 +230,15 @@ public static ElasticInferenceServiceDenseTextEmbeddingsServiceSettings createRa var similarity = SimilarityMeasure.COSINE; var dimensions = randomBoolean() ? randomIntBetween(1, 1024) : null; var maxInputTokens = randomBoolean() ? randomIntBetween(128, 256) : null; - var rateLimitSettings = randomBoolean() ? new RateLimitSettings(randomIntBetween(1, 10000)) : null; - return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings( - modelId, - similarity, - dimensions, - maxInputTokens, - rateLimitSettings - ); + return new ElasticInferenceServiceDenseTextEmbeddingsServiceSettings(modelId, similarity, dimensions, maxInputTokens); + } + + @Override + protected ElasticInferenceServiceDenseTextEmbeddingsServiceSettings mutateInstanceForVersion( + ElasticInferenceServiceDenseTextEmbeddingsServiceSettings instance, + TransportVersion version + ) { + return instance; } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java index f5da46915e13c..bfc47278a05b8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModelTests.java @@ -20,7 +20,7 @@ public static ElasticInferenceServiceRerankModel createModel(String url, String "id", TaskType.RERANK, "service", - new ElasticInferenceServiceRerankServiceSettings(modelId, null), + new ElasticInferenceServiceRerankServiceSettings(modelId), EmptyTaskSettings.INSTANCE, EmptySecretSettings.INSTANCE, ElasticInferenceServiceComponents.of(url) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java index 8066da9c43683..923ebb5475877 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankServiceSettingsTests.java @@ -7,12 +7,15 @@ package org.elasticsearch.xpack.inference.services.elastic.rerank; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -21,9 +24,12 @@ import java.util.HashMap; import java.util.Map; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; -public class ElasticInferenceServiceRerankServiceSettingsTests extends AbstractWireSerializingTestCase< +public class ElasticInferenceServiceRerankServiceSettingsTests extends AbstractBWCWireSerializationTestCase< ElasticInferenceServiceRerankServiceSettings> { @Override @@ -50,27 +56,90 @@ public void testFromMap() { ConfigurationParseContext.REQUEST ); - assertThat(serviceSettings, is(new ElasticInferenceServiceRerankServiceSettings(modelId, null))); + assertThat(serviceSettings, is(new ElasticInferenceServiceRerankServiceSettings(modelId))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_DoesNotRemoveRateLimitField_DoesNotThrowValidationException_PersistentContext() { + var modelId = "my-model-id"; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + + var serviceSettings = ElasticInferenceServiceRerankServiceSettings.fromMap(map, ConfigurationParseContext.PERSISTENT); + + assertThat(serviceSettings, is(new ElasticInferenceServiceRerankServiceSettings(modelId))); + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_DoesNotThrowValidationException_WhenRateLimitFieldDoesNotExist() { + var modelId = "my-model-id"; + + var map = new HashMap(Map.of(ServiceFields.MODEL_ID, modelId)); + + var serviceSettings = ElasticInferenceServiceRerankServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST); + + assertThat(serviceSettings, is(new ElasticInferenceServiceRerankServiceSettings(modelId))); + assertThat(map, anEmptyMap()); + assertThat(serviceSettings.rateLimitSettings(), sameInstance(RateLimitSettings.DISABLED_INSTANCE)); + } + + public void testFromMap_DoesThrowValidationException_WhenRateLimitFieldDoesExist_RequestContext() { + var modelId = "my-model-id"; + + var map = new HashMap( + Map.of( + ServiceFields.MODEL_ID, + modelId, + RateLimitSettings.FIELD_NAME, + new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)) + ) + ); + + var exception = expectThrows( + ValidationException.class, + () -> ElasticInferenceServiceRerankServiceSettings.fromMap(map, ConfigurationParseContext.REQUEST) + ); + + assertThat( + exception.getMessage(), + containsString("[service_settings] rate limit settings are not permitted for service [elastic] and task type [rerank]") + ); + assertThat(map, is(Map.of(RateLimitSettings.FIELD_NAME, Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100)))); } public void testToXContent_WritesAllFields() throws IOException { var modelId = ".rerank-v1"; - var rateLimitSettings = new RateLimitSettings(100L); - var serviceSettings = new ElasticInferenceServiceRerankServiceSettings(modelId, rateLimitSettings); + var serviceSettings = new ElasticInferenceServiceRerankServiceSettings(modelId); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); - assertThat(xContentResult, is(Strings.format(""" - {"model_id":"%s","rate_limit":{"requests_per_minute":%d}}""", modelId, rateLimitSettings.requestsPerTimeUnit()))); + assertThat(xContentResult, is(XContentHelper.stripWhitespace(Strings.format(""" + {"model_id":"%s"}""", modelId)))); } public static ElasticInferenceServiceRerankServiceSettings createRandom() { - return new ElasticInferenceServiceRerankServiceSettings(randomRerankModel(), null); + return new ElasticInferenceServiceRerankServiceSettings(randomRerankModel()); } private static String randomRerankModel() { return randomFrom(".rerank-v1", ".rerank-v2"); } + + @Override + protected ElasticInferenceServiceRerankServiceSettings mutateInstanceForVersion( + ElasticInferenceServiceRerankServiceSettings instance, + TransportVersion version + ) { + return instance; + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java index 4a808087b6363..d36fd3d2c4790 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/settings/RateLimitSettingsTests.java @@ -8,13 +8,17 @@ package org.elasticsearch.xpack.inference.services.settings; import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; @@ -22,9 +26,10 @@ import java.util.Map; import java.util.concurrent.TimeUnit; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; -public class RateLimitSettingsTests extends AbstractWireSerializingTestCase { +public class RateLimitSettingsTests extends AbstractBWCWireSerializationTestCase { public static RateLimitSettings createRandom() { return new RateLimitSettings(randomLongBetween(1, 1000000)); @@ -54,6 +59,7 @@ public void testOf() { var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(100))); + assertTrue(res.isEnabled()); assertTrue(validation.validationErrors().isEmpty()); } @@ -65,6 +71,7 @@ public void testOf_UsesDefaultValue_WhenRateLimit_IsAbsent() { var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(1))); + assertTrue(res.isEnabled()); assertTrue(validation.validationErrors().isEmpty()); } @@ -74,6 +81,7 @@ public void testOf_UsesDefaultValue_WhenRequestsPerMinute_IsAbsent() { var res = RateLimitSettings.of(settings, new RateLimitSettings(1), validation, "test", ConfigurationParseContext.PERSISTENT); assertThat(res, is(new RateLimitSettings(1))); + assertTrue(res.isEnabled()); assertTrue(validation.validationErrors().isEmpty()); } @@ -102,6 +110,69 @@ public void testToXContent() throws IOException { {"rate_limit":{"requests_per_minute":100}}""")); } + public void testToXContent_WhenDisabled() throws IOException { + var settings = new RateLimitSettings(1, TimeUnit.MINUTES, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + settings.toXContent(builder, null); + builder.endObject(); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(XContentHelper.stripWhitespace(""" + { + }"""))); + } + + public void testRejectRateLimitFieldForRequestContext_DoesNotAddError_WhenRateLimitFieldDoesNotExist() { + var mapWithoutRateLimit = new HashMap(Map.of("abc", 100)); + var validation = new ValidationException(); + RateLimitSettings.rejectRateLimitFieldForRequestContext( + mapWithoutRateLimit, + "scope", + "service", + TaskType.CHAT_COMPLETION, + ConfigurationParseContext.REQUEST, + validation + ); + assertTrue(validation.validationErrors().isEmpty()); + } + + public void testRejectRateLimitFieldForRequestContext_DoesNotAddError_WhenRateLimitFieldDoesExist_PersistentContext() { + var mapWithRateLimit = new HashMap( + Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100))) + ); + var validation = new ValidationException(); + RateLimitSettings.rejectRateLimitFieldForRequestContext( + mapWithRateLimit, + "scope", + "service", + TaskType.CHAT_COMPLETION, + ConfigurationParseContext.PERSISTENT, + validation + ); + assertTrue(validation.validationErrors().isEmpty()); + } + + public void testRejectRateLimitFieldForRequestContext_DoesAddError_WhenRateLimitFieldDoesExist() { + var mapWithRateLimit = new HashMap( + Map.of(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 100))) + ); + var validation = new ValidationException(); + RateLimitSettings.rejectRateLimitFieldForRequestContext( + mapWithRateLimit, + "scope", + "service", + TaskType.CHAT_COMPLETION, + ConfigurationParseContext.REQUEST, + validation + ); + assertThat( + validation.getMessage(), + containsString("[scope] rate limit settings are not permitted for service [service] and task type [chat_completion]") + ); + } + @Override protected Writeable.Reader instanceReader() { return RateLimitSettings::new; @@ -116,4 +187,13 @@ protected RateLimitSettings createTestInstance() { protected RateLimitSettings mutateInstance(RateLimitSettings instance) throws IOException { return randomValueOtherThan(instance, RateLimitSettingsTests::createRandom); } + + @Override + protected RateLimitSettings mutateInstanceForVersion(RateLimitSettings instance, TransportVersion version) { + if (version.before(TransportVersions.INFERENCE_API_DISABLE_EIS_RATE_LIMITING)) { + return new RateLimitSettings(instance.requestsPerTimeUnit(), instance.timeUnit(), true); + } else { + return instance; + } + } }