diff --git a/docs/changelog/135262.yaml b/docs/changelog/135262.yaml new file mode 100644 index 0000000000000..35886a9373edd --- /dev/null +++ b/docs/changelog/135262.yaml @@ -0,0 +1,5 @@ +pr: 135262 +summary: Add usage stats for `semantic_text` fields +area: "Vector Search" +type: enhancement +issues: [] diff --git a/server/src/main/resources/transport/definitions/referable/inference_telemetry_added_semantic_text_stats.csv b/server/src/main/resources/transport/definitions/referable/inference_telemetry_added_semantic_text_stats.csv new file mode 100644 index 0000000000000..c225f99e2e2c2 --- /dev/null +++ b/server/src/main/resources/transport/definitions/referable/inference_telemetry_added_semantic_text_stats.csv @@ -0,0 +1 @@ +9182000 diff --git a/server/src/main/resources/transport/upper_bounds/9.2.csv b/server/src/main/resources/transport/upper_bounds/9.2.csv index d501e4094e60d..0dde9b0f74618 100644 --- a/server/src/main/resources/transport/upper_bounds/9.2.csv +++ b/server/src/main/resources/transport/upper_bounds/9.2.csv @@ -1 +1 @@ -index_reshard_shardcount_small,9181000 +inference_telemetry_added_semantic_text_stats,9182000 diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java index df9cb03f246ae..6d25d45857a24 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java @@ -7,9 +7,12 @@ package org.elasticsearch.xpack.core.inference.usage; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.features.NodeFeature; import 
org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -19,28 +22,34 @@ public class ModelStats implements ToXContentObject, Writeable { + public static final NodeFeature SEMANTIC_TEXT_USAGE = new NodeFeature("inference.semantic_text_usage"); + + static final TransportVersion INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS = TransportVersion.fromName( + "inference_telemetry_added_semantic_text_stats" + ); + private final String service; private final TaskType taskType; private long count; + @Nullable + private final SemanticTextStats semanticTextStats; - public ModelStats(String service, TaskType taskType) { - this(service, taskType, 0L); - } - - public ModelStats(String service, TaskType taskType, long count) { + public ModelStats(String service, TaskType taskType, long count, @Nullable SemanticTextStats semanticTextStats) { this.service = service; this.taskType = taskType; this.count = count; - } - - public ModelStats(ModelStats stats) { - this(stats.service, stats.taskType, stats.count); + this.semanticTextStats = semanticTextStats; } public ModelStats(StreamInput in) throws IOException { this.service = in.readString(); this.taskType = in.readEnum(TaskType.class); this.count = in.readLong(); + if (in.getTransportVersion().supports(INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS)) { + this.semanticTextStats = in.readOptional(SemanticTextStats::new); + } else { + this.semanticTextStats = null; + } } public void add() { @@ -59,6 +68,11 @@ public long count() { return count; } + @Nullable + public SemanticTextStats semanticTextStats() { + return semanticTextStats; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -71,6 +85,9 @@ public void addXContentFragment(XContentBuilder builder, Params params) throws I builder.field("service", service); builder.field("task_type", taskType.name()); 
builder.field("count", count); + if (semanticTextStats != null) { + builder.field("semantic_text", semanticTextStats); + } } @Override @@ -78,6 +95,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(service); out.writeEnum(taskType); out.writeLong(count); + if (out.getTransportVersion().supports(INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS)) { + out.writeOptionalWriteable(semanticTextStats); + } } @Override @@ -85,11 +105,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ModelStats that = (ModelStats) o; - return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType; + return count == that.count + && Objects.equals(service, that.service) + && taskType == that.taskType + && Objects.equals(semanticTextStats, that.semanticTextStats); } @Override public int hashCode() { - return Objects.hash(service, taskType, count); + return Objects.hash(service, taskType, count, semanticTextStats); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStats.java new file mode 100644 index 0000000000000..71493484cad42 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStats.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.usage; + +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public class SemanticTextStats implements ToXContentObject, Writeable { + + private static final String FIELD_COUNT = "field_count"; + private static final String INDICES_COUNT = "indices_count"; + private static final String INFERENCE_ID_COUNT = "inference_id_count"; + + private long fieldCount; + private long indicesCount; + private long inferenceIdCount; + + public SemanticTextStats() {} + + public SemanticTextStats(long fieldCount, long indicesCount, long inferenceIdCount) { + this.fieldCount = fieldCount; + this.indicesCount = indicesCount; + this.inferenceIdCount = inferenceIdCount; + } + + public SemanticTextStats(StreamInput in) throws IOException { + fieldCount = in.readVLong(); + indicesCount = in.readVLong(); + inferenceIdCount = in.readVLong(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVLong(fieldCount); + out.writeVLong(indicesCount); + out.writeVLong(inferenceIdCount); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_COUNT, fieldCount); + builder.field(INDICES_COUNT, indicesCount); + builder.field(INFERENCE_ID_COUNT, inferenceIdCount); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SemanticTextStats that = (SemanticTextStats) o; + return fieldCount == that.fieldCount && indicesCount == that.indicesCount && inferenceIdCount == that.inferenceIdCount; + } + + @Override 
+ public int hashCode() { + return Objects.hash(fieldCount, indicesCount, inferenceIdCount); + } + + public long getFieldCount() { + return fieldCount; + } + + public long getIndicesCount() { + return indicesCount; + } + + public long getInferenceIdCount() { + return inferenceIdCount; + } + + public void addFieldCount(long fieldCount) { + this.fieldCount += fieldCount; + } + + public void incIndicesCount() { + this.indicesCount++; + } + + public void setInferenceIdCount(long inferenceIdCount) { + this.inferenceIdCount = inferenceIdCount; + } + + public boolean isEmpty() { + return fieldCount == 0 && indicesCount == 0 && inferenceIdCount == 0; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java index e020749b3e71e..83838bda39cee 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java @@ -7,16 +7,17 @@ package org.elasticsearch.xpack.core.inference.usage; +import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; import java.io.IOException; import static org.hamcrest.Matchers.equalTo; -public class ModelStatsTests extends AbstractWireSerializingTestCase { +public class ModelStatsTests extends AbstractBWCWireSerializationTestCase { @Override protected Writeable.Reader instanceReader() { @@ -33,16 +34,28 @@ protected ModelStats mutateInstance(ModelStats modelStats) throws IOException { String service = modelStats.service(); TaskType taskType = modelStats.taskType(); long count = 
modelStats.count(); - return switch (randomInt(2)) { - case 0 -> new ModelStats(randomValueOtherThan(service, ESTestCase::randomIdentifier), taskType, count); - case 1 -> new ModelStats(service, randomValueOtherThan(taskType, () -> randomFrom(TaskType.values())), count); - case 2 -> new ModelStats(service, taskType, randomValueOtherThan(count, ESTestCase::randomLong)); + SemanticTextStats semanticTextStats = modelStats.semanticTextStats(); + return switch (randomInt(3)) { + case 0 -> new ModelStats(randomValueOtherThan(service, ESTestCase::randomIdentifier), taskType, count, semanticTextStats); + case 1 -> new ModelStats( + service, + randomValueOtherThan(taskType, () -> randomFrom(TaskType.values())), + count, + semanticTextStats + ); + case 2 -> new ModelStats(service, taskType, randomValueOtherThan(count, ESTestCase::randomLong), semanticTextStats); + case 3 -> new ModelStats( + service, + taskType, + count, + randomValueOtherThan(semanticTextStats, SemanticTextStatsTests::createRandomInstance) + ); default -> throw new IllegalArgumentException(); }; } public void testAdd() { - ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values())); + ModelStats stats = new ModelStats("test_service", randomFrom(TaskType.values()), 0, null); assertThat(stats.count(), equalTo(0L)); stats.add(); @@ -56,6 +69,20 @@ public void testAdd() { } public static ModelStats createRandomInstance() { - return new ModelStats(randomIdentifier(), randomFrom(TaskType.values()), randomLong()); + TaskType taskType = randomValueOtherThan(TaskType.ANY, () -> randomFrom(TaskType.values())); + return new ModelStats( + randomIdentifier(), + taskType, + randomLong(), + randomBoolean() ? 
SemanticTextStatsTests.createRandomInstance() : null + ); + } + + @Override + protected ModelStats mutateInstanceForVersion(ModelStats instance, TransportVersion version) { + if (version.supports(ModelStats.INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS) == false) { + return new ModelStats(instance.service(), instance.taskType(), instance.count(), null); + } + return instance; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStatsTests.java new file mode 100644 index 0000000000000..9a49b4f495099 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/SemanticTextStatsTests.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.usage; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; + +public class SemanticTextStatsTests extends AbstractBWCWireSerializationTestCase { + + @Override + protected Writeable.Reader instanceReader() { + return SemanticTextStats::new; + } + + @Override + protected SemanticTextStats createTestInstance() { + return createRandomInstance(); + } + + static SemanticTextStats createRandomInstance() { + return new SemanticTextStats(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong()); + } + + @Override + protected SemanticTextStats mutateInstance(SemanticTextStats instance) throws IOException { + return switch (randomInt(2)) { + case 0 -> new SemanticTextStats( + randomValueOtherThan(instance.getFieldCount(), ESTestCase::randomNonNegativeLong), + instance.getIndicesCount(), + instance.getInferenceIdCount() + ); + case 1 -> new SemanticTextStats( + instance.getFieldCount(), + randomValueOtherThan(instance.getIndicesCount(), ESTestCase::randomNonNegativeLong), + instance.getInferenceIdCount() + ); + case 2 -> new SemanticTextStats( + instance.getFieldCount(), + instance.getIndicesCount(), + randomValueOtherThan(instance.getInferenceIdCount(), ESTestCase::randomNonNegativeLong) + ); + default -> throw new IllegalArgumentException(); + }; + } + + public void testDefaultConstructor() { + var stats = new SemanticTextStats(); + assertThat(stats.getFieldCount(), equalTo(0L)); + assertThat(stats.getIndicesCount(), equalTo(0L)); + assertThat(stats.getInferenceIdCount(), equalTo(0L)); + } + + public void testAddFieldCount() { + var stats = new SemanticTextStats(); + stats.addFieldCount(10L); + 
assertThat(stats.getFieldCount(), equalTo(10L)); + stats.addFieldCount(32L); + assertThat(stats.getFieldCount(), equalTo(42L)); + } + + public void testIsEmpty() { + assertThat(new SemanticTextStats().isEmpty(), is(true)); + assertThat(new SemanticTextStats(randomLongBetween(1, Long.MAX_VALUE), 0, 0).isEmpty(), is(false)); + assertThat(new SemanticTextStats(0, randomLongBetween(1, Long.MAX_VALUE), 0).isEmpty(), is(false)); + assertThat(new SemanticTextStats(0, 0, randomLongBetween(1, Long.MAX_VALUE)).isEmpty(), is(false)); + } + + @Override + protected SemanticTextStats mutateInstanceForVersion(SemanticTextStats instance, TransportVersion version) { + return instance; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 3ade24c89c82e..cf794d0e124e1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -9,6 +9,7 @@ import org.elasticsearch.features.FeatureSpecification; import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xpack.core.inference.usage.ModelStats; import org.elasticsearch.xpack.inference.mapper.SemanticInferenceMetadataFieldsMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.queries.InterceptedInferenceQueryBuilder; @@ -94,7 +95,8 @@ public Set getTestFeatures() { SemanticQueryBuilder.SEMANTIC_QUERY_MULTIPLE_INFERENCE_IDS, SemanticQueryBuilder.SEMANTIC_QUERY_FILTER_FIELD_CAPS_FIX, InterceptedInferenceQueryBuilder.NEW_SEMANTIC_QUERY_INTERCEPTORS, - TEXT_SIMILARITY_RERANKER_SNIPPETS + TEXT_SIMILARITY_RERANKER_SNIPPETS, + ModelStats.SEMANTIC_TEXT_USAGE ) ); testFeatures.addAll(getFeatures()); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java index 32a3888368d2e..609a1e4df62d8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java @@ -14,6 +14,9 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.ModelConfigurations; @@ -29,9 +32,21 @@ import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; import org.elasticsearch.xpack.core.inference.usage.ModelStats; +import org.elasticsearch.xpack.core.inference.usage.SemanticTextStats; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import java.util.TreeMap; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; @@ -39,6 +54,15 @@ public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAct private final Logger logger = LogManager.getLogger(TransportInferenceUsageAction.class); 
+ // Some of the default models have optimized variants for linux that will have the following suffix. + private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64"; + + private static final EnumSet TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT = EnumSet.of( + TaskType.TEXT_EMBEDDING, + TaskType.SPARSE_EMBEDDING + ); + + private final ModelRegistry modelRegistry; private final Client client; @Inject @@ -47,9 +71,11 @@ public TransportInferenceUsageAction( ClusterService clusterService, ThreadPool threadPool, ActionFilters actionFilters, + ModelRegistry modelRegistry, Client client ) { super(XPackUsageFeatureAction.INFERENCE.name(), transportService, clusterService, threadPool, actionFilters); + this.modelRegistry = modelRegistry; this.client = new OriginSettingClient(client, ML_ORIGIN); } @@ -62,17 +88,235 @@ protected void localClusterStateOperation( ) { GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false); client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, ActionListener.wrap(response -> { - Map stats = new TreeMap<>(); - for (ModelConfigurations model : response.getEndpoints()) { - String statKey = model.getService() + ":" + model.getTaskType().name(); - ModelStats stat = stats.computeIfAbsent(statKey, key -> new ModelStats(model.getService(), model.getTaskType())); - stat.add(); - } - InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values()); - listener.onResponse(new XPackUsageFeatureResponse(usage)); + listener.onResponse( + new XPackUsageFeatureResponse(collectUsage(response.getEndpoints(), state.getMetadata().indicesAllProjects())) + ); }, e -> { logger.warn(Strings.format("Retrieving inference usage failed with error: %s", e.getMessage()), e); listener.onResponse(new XPackUsageFeatureResponse(InferenceFeatureSetUsage.EMPTY)); })); } + + private InferenceFeatureSetUsage collectUsage(List endpoints, Iterable indicesMetadata) { + Map>> 
inferenceFieldsByIndexServiceAndTask = + mapInferenceFieldsByIndexServiceAndTask(indicesMetadata, endpoints); + Map endpointStats = new TreeMap<>(); + addStatsByServiceAndTask(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats); + addStatsForDefaultModelsCompatibleWithSemanticText(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats); + return new InferenceFeatureSetUsage(endpointStats.values()); + } + + /** + * Returns a map whose keys are the inference service and task_type and the values are maps of index names to inference fields. + * Inference fields in system or hidden indices are excluded. + */ + private static Map>> mapInferenceFieldsByIndexServiceAndTask( + Iterable indicesMetadata, + List endpoints + ) { + Map inferenceIdToEndpoint = endpoints.stream() + .collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, Function.identity())); + Map>> inferenceFieldByIndexServiceAndTask = new HashMap<>(); + for (IndexMetadata indexMetadata : indicesMetadata) { + if (indexMetadata.isSystem() || indexMetadata.isHidden()) { + // Usage for system or hidden indices should be reported through the corresponding application usage + continue; + } + indexMetadata.getInferenceFields() + .values() + .stream() + .filter(field -> inferenceIdToEndpoint.containsKey(field.getInferenceId())) + .forEach(field -> { + ModelConfigurations endpoint = inferenceIdToEndpoint.get(field.getInferenceId()); + Map> fieldsByIndex = inferenceFieldByIndexServiceAndTask.computeIfAbsent( + new ServiceAndTaskType(endpoint.getService(), endpoint.getTaskType()), + key -> new HashMap<>() + ); + fieldsByIndex.computeIfAbsent(indexMetadata.getIndex().getName(), key -> new ArrayList<>()).add(field); + }); + } + return inferenceFieldByIndexServiceAndTask; + } + + /** + * Adds inference usage stats for each service and task type combination. + * In addition, adds aggregate usage stats per task type across all services. + * Those aggregate stats have "_all" as the service name. 
+ */ + private static void addStatsByServiceAndTask( + Map>> inferenceFieldsByIndexServiceAndTask, + List endpoints, + Map endpointStats + ) { + for (ModelConfigurations model : endpoints) { + endpointStats.computeIfAbsent( + new ServiceAndTaskType(model.getService(), model.getTaskType()).toString(), + key -> createEmptyStats(model) + ).add(); + + endpointStats.computeIfAbsent( + new ServiceAndTaskType(Metadata.ALL, model.getTaskType()).toString(), + key -> createEmptyStats(Metadata.ALL, model.getTaskType()) + ).add(); + } + + inferenceFieldsByIndexServiceAndTask.forEach( + (serviceAndTaskType, inferenceFieldsByIndex) -> addSemanticTextStats( + inferenceFieldsByIndex, + endpointStats.get(serviceAndTaskType.toString()) + ) + ); + addTopLevelStatsByTask(inferenceFieldsByIndexServiceAndTask, endpointStats); + } + + private static ModelStats createEmptyStats(ModelConfigurations model) { + return createEmptyStats(model.getService(), model.getTaskType()); + } + + private static ModelStats createEmptyStats(String service, TaskType taskType) { + return new ModelStats( + service, + taskType, + 0, + TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType) ? 
new SemanticTextStats() : null + ); + } + + private static void addTopLevelStatsByTask( + Map>> inferenceFieldsByIndexServiceAndTask, + Map endpointStats + ) { + for (TaskType taskType : TaskType.values()) { + if (taskType == TaskType.ANY) { + continue; + } + ModelStats allStatsForTaskType = endpointStats.computeIfAbsent( + new ServiceAndTaskType(Metadata.ALL, taskType).toString(), + key -> createEmptyStats(Metadata.ALL, taskType) + ); + if (TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(taskType)) { + Map> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask.entrySet() + .stream() + .filter(e -> e.getKey().taskType == taskType) + .flatMap(m -> m.getValue().entrySet().stream()) + .collect( + Collectors.toMap( + Map.Entry::getKey, + Map.Entry::getValue, + (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).toList() + ) + ); + addSemanticTextStats(inferenceFieldsByIndex, allStatsForTaskType); + } + } + } + + private static void addSemanticTextStats(Map> inferenceFieldsByIndex, ModelStats stat) { + Set inferenceIds = new HashSet<>(); + for (List inferenceFields : inferenceFieldsByIndex.values()) { + stat.semanticTextStats().addFieldCount(inferenceFields.size()); + stat.semanticTextStats().incIndicesCount(); + inferenceFields.forEach(field -> inferenceIds.add(field.getInferenceId())); + } + stat.semanticTextStats().setInferenceIdCount(inferenceIds.size()); + } + + /** + * Adds stats for default models that are compatible with semantic_text. + * In particular, default models are considered models that are associated with default inference + * endpoints as per the {@code ModelRegistry}. The service name for default model stats is "_{service}_{modelId}". + * Each of those stats contains usage for all endpoints that use that model, including non-default endpoints. 
+ */ + private void addStatsForDefaultModelsCompatibleWithSemanticText( + Map>> inferenceFieldsByIndexServiceAndTask, + List endpoints, + Map endpointStats + ) { + Map endpointIdToModelId = endpoints.stream() + .filter(endpoint -> endpoint.getServiceSettings().modelId() != null) + .collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, e -> stripLinuxSuffix(e.getServiceSettings().modelId()))); + Map defaultModelsToEndpointCount = + createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText(endpoints); + for (Map.Entry defaultModelStatsKeyToEndpointCount : defaultModelsToEndpointCount.entrySet()) { + DefaultModelStatsKey statKey = defaultModelStatsKeyToEndpointCount.getKey(); + Map> fieldsByIndex = inferenceFieldsByIndexServiceAndTask.getOrDefault( + new ServiceAndTaskType(statKey.service, statKey.taskType), + Map.of() + ); + // Now that we have all inference fields for this service and task type, we want to keep only the ones that + // reference the current default model. + fieldsByIndex = filterFields(fieldsByIndex, f -> statKey.modelId.equals(endpointIdToModelId.get(f.getInferenceId()))); + ModelStats stats = new ModelStats( + statKey.toString(), + statKey.taskType, + defaultModelStatsKeyToEndpointCount.getValue(), + new SemanticTextStats() + ); + addSemanticTextStats(fieldsByIndex, stats); + endpointStats.put(statKey.toString(), stats); + } + } + + private Map createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText( + List endpoints + ) { + // We consider models to be default if they are associated with a default inference endpoint. + // Note that endpoints could have a null model id, in which case we don't consider them default as this + // may only happen for external services. 
+ Set modelIds = endpoints.stream() + .filter(endpoint -> TASK_TYPES_WITH_SEMANTIC_TEXT_SUPPORT.contains(endpoint.getTaskType())) + .filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId())) + .filter(endpoint -> endpoint.getServiceSettings().modelId() != null) + .map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId())) + .collect(Collectors.toSet()); + return endpoints.stream() + .filter(endpoint -> endpoint.getServiceSettings().modelId() != null) + .filter(endpoint -> modelIds.contains(stripLinuxSuffix(endpoint.getServiceSettings().modelId()))) + .map( + endpoint -> new DefaultModelStatsKey( + endpoint.getService(), + endpoint.getTaskType(), + stripLinuxSuffix(endpoint.getServiceSettings().modelId()) + ) + ) + .collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + } + + private static Map> filterFields( + Map> fieldsByIndex, + Predicate predicate + ) { + Map> filtered = new HashMap<>(); + for (Map.Entry> entry : fieldsByIndex.entrySet()) { + List filteredFields = entry.getValue().stream().filter(predicate).toList(); + if (filteredFields.isEmpty() == false) { + filtered.put(entry.getKey(), filteredFields); + } + } + return filtered; + } + + private static String stripLinuxSuffix(String modelId) { + if (modelId.endsWith(MODEL_ID_LINUX_SUFFIX)) { + return modelId.substring(0, modelId.length() - MODEL_ID_LINUX_SUFFIX.length()); + } + return modelId; + } + + private record DefaultModelStatsKey(String service, TaskType taskType, String modelId) { + + @Override + public String toString() { + // Inference ids cannot start with '_'. Thus, default stats keys start with '_' to avoid conflicts with user-defined inference ids. 
+ return "_" + service + "_" + modelId.replace('.', '_'); + } + } + + private record ServiceAndTaskType(String service, TaskType taskType) { + + @Override + public String toString() { + return service + ":" + taskType.name(); + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java index 4fe6cd7995850..d56b3fd8037c7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java @@ -12,8 +12,14 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.metadata.ProjectId; +import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.index.IndexVersion; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskType; @@ -32,14 +38,26 @@ import org.elasticsearch.xpack.core.action.XPackUsageFeatureResponse; import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage; import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction; +import org.elasticsearch.xpack.core.inference.usage.ModelStats; +import org.elasticsearch.xpack.core.inference.usage.SemanticTextStats; import 
org.elasticsearch.xpack.core.watcher.support.xcontent.XContentSource; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.After; import org.junit.Before; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; +import static org.elasticsearch.cluster.metadata.IndexMetadata.INDEX_HIDDEN_SETTING; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -48,7 +66,11 @@ public class TransportInferenceUsageActionTests extends ESTestCase { + private static final SemanticTextStats EMPTY_SEMANTIC_TEXT_STATS = new SemanticTextStats(); + private Client client; + private ModelRegistry modelRegistry; + private ClusterState clusterState; private TransportInferenceUsageAction action; @Before @@ -57,6 +79,10 @@ public void init() { ThreadPool threadPool = new TestThreadPool("test"); when(client.threadPool()).thenReturn(threadPool); + modelRegistry = mock(ModelRegistry.class); + + givenClusterState(Map.of()); + TransportService transportService = MockUtils.setupTransportServiceWithThreadpoolExecutor(mock(ThreadPool.class)); action = new TransportInferenceUsageAction( @@ -64,6 +90,7 @@ public void init() { mock(ClusterService.class), mock(ThreadPool.class), mock(ActionFilters.class), + modelRegistry, client ); } @@ -73,27 +100,399 @@ public void close() { client.threadPool().shutdown(); } - public void test() throws Exception { + public void testGivenServices_NoInferenceFields() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations("model-001", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-001")), + 
new ModelConfigurations("model-002", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-002")), + new ModelConfigurations("model-003", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mockServiceSettings("model-id-003")), + new ModelConfigurations("model-004", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-004")), + new ModelConfigurations("model-005", TaskType.SPARSE_EMBEDDING, "openai", mockServiceSettings("model-id-005")), + new ModelConfigurations("model-006", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mockServiceSettings("model-id-006")) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(8)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 3, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 3, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 5, new ModelStats("hugging_face_elser", TaskType.SPARSE_EMBEDDING, 2, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 6, new ModelStats("openai", TaskType.SPARSE_EMBEDDING, 1, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 7, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 3, EMPTY_SEMANTIC_TEXT_STATS)); + } + + public void testGivenFieldRefersToMissingInferenceEndpoint() throws Exception { + givenInferenceEndpoints(); + givenInferenceFields(Map.of("index_1", List.of(new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null)))); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(5)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", 
TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + } + + public void testGivenVariousServicesAndInferenceFields() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-001")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-002")), + new ModelConfigurations("endpoint-003", TaskType.SPARSE_EMBEDDING, "openai", mockServiceSettings("model-id-003")), + new ModelConfigurations("endpoint-004", TaskType.SPARSE_EMBEDDING, "openai", mockServiceSettings("model-id-004")), + new ModelConfigurations("endpoint-005", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings("model-id-005")), + new ModelConfigurations("endpoint-006", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings("model-id-006")), + new ModelConfigurations("endpoint-007", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings("model-id-007")) // unused + ); + + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-4", "endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-5", "endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-6", "endpoint-003", new String[0], null), + new InferenceFieldMetadata("semantic-7", "endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-8", "endpoint-005", new String[0], null) + ), + "index_2", + List.of( + new 
InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-003", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-003", new String[0], null), + new InferenceFieldMetadata("semantic-4", "endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-5", "endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-6", "endpoint-005", new String[0], null) + ), + "index_3", + List.of(new InferenceFieldMetadata("semantic-1", "endpoint-006", new String[0], null)) + ) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(8)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 5, new SemanticTextStats(9, 3, 4))); + assertStats(response, 5, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 3, new SemanticTextStats(3, 3, 2))); + assertStats(response, 6, new ModelStats("openai", TaskType.SPARSE_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 7, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + } + + public void testGivenServices_InferenceFieldsReferencingDefaultModels() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations(".endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001")), + new ModelConfigurations(".endpoint-002", TaskType.SPARSE_EMBEDDING, "eis", mockServiceSettings(".model-id-002")), + new ModelConfigurations("endpoint-003", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("model-id-003")), + new 
ModelConfigurations(".endpoint-004", TaskType.SPARSE_EMBEDDING, "openai", mockServiceSettings("model-id-004")), + new ModelConfigurations("endpoint-005", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001")) + ); + + givenDefaultEndpoints(".endpoint-001", ".endpoint-002", ".endpoint-004"); + + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", ".endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-4", ".endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-5", "endpoint-005", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-3", ".endpoint-004", new String[0], null), + new InferenceFieldMetadata("semantic-4", "endpoint-005", new String[0], null), + new InferenceFieldMetadata("semantic-5", "endpoint-005", new String[0], null) + ) + ) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(12)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 2, new SemanticTextStats(4, 2, 2))); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 3, new SemanticTextStats(6, 2, 2))); + assertStats(response, 5, new ModelStats("_eis__model-id-001", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 6, new ModelStats("_eis__model-id-002", 
TaskType.SPARSE_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + assertStats(response, 7, new ModelStats("_openai_model-id-004", TaskType.SPARSE_EMBEDDING, 1, new SemanticTextStats(3, 2, 1))); + assertStats(response, 8, new ModelStats("eis", TaskType.SPARSE_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + assertStats(response, 9, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 10, new ModelStats("openai", TaskType.SPARSE_EMBEDDING, 1, new SemanticTextStats(3, 2, 1))); + assertStats(response, 11, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 1, EMPTY_SEMANTIC_TEXT_STATS)); + } + + public void testGivenDefaultModelWithLinuxSuffix() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations(".endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001_linux-x86_64")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001_linux-x86_64")) + ); + + givenDefaultEndpoints(".endpoint-001"); + + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ) + ) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(7)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + 
assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 5, new ModelStats("_eis__model-id-001", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 6, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + } + + public void testGivenSameDefaultModelWithAndWithoutLinuxSuffix() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations(".endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001_linux-x86_64")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001")) + ); + + givenDefaultEndpoints(".endpoint-001"); + + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ) + ) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(7)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new 
SemanticTextStats(6, 2, 2))); + assertStats(response, 5, new ModelStats("_eis__model-id-001", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + assertStats(response, 6, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2))); + } + + public void testGivenExternalServiceModelIsNull() throws Exception { + givenInferenceEndpoints(new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings(null))); + givenInferenceFields(Map.of("index_1", List.of(new InferenceFieldMetadata("semantic", "endpoint-001", new String[0], null)))); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(6)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + assertStats(response, 5, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + } + + public void testGivenDuplicateServices() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("some-model")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("some-model")) + ); + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-002", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", 
"endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ) + ) + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(6)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(5, 2, 2))); + assertStats(response, 5, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(5, 2, 2))); + } + + public void testShouldExcludeSystemIndexFields() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings("some-model")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("some-model")) + ); + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-002", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ) + ), + Set.of("index_2"), + Set.of() + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(7)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new 
ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(2, 1, 2))); + assertStats(response, 5, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + assertStats(response, 6, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + } + + public void testShouldExcludeHiddenIndexFields() throws Exception { + givenInferenceEndpoints( + new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings("some-model")), + new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings("some-model")) + ); + givenInferenceFields( + Map.of( + "index_1", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-002", new String[0], null) + ), + "index_2", + List.of( + new InferenceFieldMetadata("semantic-1", "endpoint-001", new String[0], null), + new InferenceFieldMetadata("semantic-2", "endpoint-002", new String[0], null), + new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null) + ) + ), + Set.of(), + Set.of("index_2") + ); + + XContentSource response = executeAction(); + + assertThat(response.getValue("models"), hasSize(7)); + assertStats(response, 0, new ModelStats("_all", TaskType.CHAT_COMPLETION, 0, null)); + assertStats(response, 1, new ModelStats("_all", TaskType.COMPLETION, 0, null)); + assertStats(response, 2, new ModelStats("_all", TaskType.RERANK, 0, null)); + assertStats(response, 3, new ModelStats("_all", TaskType.SPARSE_EMBEDDING, 0, EMPTY_SEMANTIC_TEXT_STATS)); + assertStats(response, 4, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(2, 1, 2))); + assertStats(response, 5, new ModelStats("eis", 
TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + assertStats(response, 6, new ModelStats("openai", TaskType.TEXT_EMBEDDING, 1, new SemanticTextStats(1, 1, 1))); + } + + public void testFailureReturnsEmptyUsage() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("invalid field")); + return Void.TYPE; + }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); + + var future = new PlainActionFuture(); + action.localClusterStateOperation(mock(Task.class), mock(XPackUsageRequest.class), clusterState, future); + + var usage = future.actionGet(TIMEOUT); + var inferenceUsage = (InferenceFeatureSetUsage) usage.getUsage(); + assertThat(inferenceUsage, is(InferenceFeatureSetUsage.EMPTY)); + } + + private void givenClusterState(Map indices) { + clusterState = ClusterState.builder(ClusterState.EMPTY_STATE) + .metadata(Metadata.builder().put(ProjectMetadata.builder(ProjectId.DEFAULT).indices(indices).build())) + .build(); + } + + private static ServiceSettings mockServiceSettings(String modelId) { + ServiceSettings serviceSettings = mock(ServiceSettings.class); + when(serviceSettings.modelId()).thenReturn(modelId); + return serviceSettings; + } + + private void givenInferenceEndpoints(ModelConfigurations... 
endpoints) { doAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; - listener.onResponse( - new GetInferenceModelAction.Response( - List.of( - new ModelConfigurations("model-001", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), - new ModelConfigurations("model-002", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), - new ModelConfigurations("model-003", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class)), - new ModelConfigurations("model-004", TaskType.TEXT_EMBEDDING, "openai", mock(ServiceSettings.class)), - new ModelConfigurations("model-005", TaskType.SPARSE_EMBEDDING, "openai", mock(ServiceSettings.class)), - new ModelConfigurations("model-006", TaskType.SPARSE_EMBEDDING, "hugging_face_elser", mock(ServiceSettings.class)) - ) - ) - ); + listener.onResponse(new GetInferenceModelAction.Response(Arrays.asList(endpoints))); return Void.TYPE; }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); + } + + private void givenInferenceFields(Map> inferenceFieldsByIndex) { + givenInferenceFields(inferenceFieldsByIndex, Set.of(), Set.of()); + } + private void givenInferenceFields( + Map> inferenceFieldsByIndex, + Set systemIndices, + Set hiddenIndices + ) { + Map indices = new HashMap<>(); + for (Map.Entry> entry : inferenceFieldsByIndex.entrySet()) { + String index = entry.getKey(); + IndexMetadata.Builder indexMetadata = IndexMetadata.builder(index) + .settings( + ESTestCase.settings(IndexVersion.current()) + .put(INDEX_HIDDEN_SETTING.getKey(), hiddenIndices.contains(index) ? 
"true" : "false") + ) + .numberOfShards(randomIntBetween(1, 5)) + .system(systemIndices.contains(index)) + .numberOfReplicas(1); + entry.getValue().forEach(indexMetadata::putInferenceField); + indices.put(index, indexMetadata.build()); + } + givenClusterState(indices); + } + + private XContentSource executeAction() throws ExecutionException, InterruptedException, IOException { PlainActionFuture future = new PlainActionFuture<>(); - action.localClusterStateOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future); + action.localClusterStateOperation(mock(Task.class), mock(XPackUsageRequest.class), clusterState, future); BytesStreamOutput out = new BytesStreamOutput(); future.get().getUsage().writeTo(out); @@ -105,31 +504,34 @@ public void test() throws Exception { XContentBuilder builder = XContentFactory.jsonBuilder(); usage.toXContent(builder, ToXContent.EMPTY_PARAMS); - XContentSource source = new XContentSource(builder); - assertThat(source.getValue("models"), hasSize(3)); - assertThat(source.getValue("models.0.service"), is("hugging_face_elser")); - assertThat(source.getValue("models.0.task_type"), is("SPARSE_EMBEDDING")); - assertThat(source.getValue("models.0.count"), is(2)); - assertThat(source.getValue("models.1.service"), is("openai")); - assertThat(source.getValue("models.1.task_type"), is("SPARSE_EMBEDDING")); - assertThat(source.getValue("models.1.count"), is(1)); - assertThat(source.getValue("models.2.service"), is("openai")); - assertThat(source.getValue("models.2.task_type"), is("TEXT_EMBEDDING")); - assertThat(source.getValue("models.2.count"), is(3)); + return new XContentSource(builder); } - public void testFailureReturnsEmptyUsage() { - doAnswer(invocation -> { - ActionListener listener = invocation.getArgument(2); - listener.onFailure(new IllegalArgumentException("invalid field")); - return Void.TYPE; - }).when(client).execute(any(GetInferenceModelAction.class), any(), any()); - - var future = new 
PlainActionFuture(); - action.localClusterStateOperation(mock(Task.class), mock(XPackUsageRequest.class), mock(ClusterState.class), future); + private void givenDefaultEndpoints(String... ids) { + for (String id : ids) { + when(modelRegistry.containsDefaultConfigId(id)).thenReturn(true); + } + } - var usage = future.actionGet(TIMEOUT); - var inferenceUsage = (InferenceFeatureSetUsage) usage.getUsage(); - assertThat(inferenceUsage, is(InferenceFeatureSetUsage.EMPTY)); + private static void assertStats(XContentSource source, int index, ModelStats stats) { + assertThat(source.getValue("models." + index + ".service"), is(stats.service())); + assertThat(source.getValue("models." + index + ".task_type"), is(stats.taskType().name())); + assertThat(((Integer) source.getValue("models." + index + ".count")).longValue(), equalTo(stats.count())); + if (stats.semanticTextStats() == null) { + assertThat(source.getValue("models." + index + ".semantic_text"), is(nullValue())); + } else { + assertThat( + ((Integer) source.getValue("models." + index + ".semantic_text.field_count")).longValue(), + equalTo(stats.semanticTextStats().getFieldCount()) + ); + assertThat( + ((Integer) source.getValue("models." + index + ".semantic_text.indices_count")).longValue(), + equalTo(stats.semanticTextStats().getIndicesCount()) + ); + assertThat( + ((Integer) source.getValue("models." 
+ index + ".semantic_text.inference_id_count")).longValue(), + equalTo(stats.semanticTextStats().getInferenceIdCount()) + ); + } } } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_usage.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_usage.yml new file mode 100644 index 0000000000000..a2ba07c2b1e02 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/inference/inference_usage.yml @@ -0,0 +1,145 @@ +setup: + - requires: + cluster_features: "inference.semantic_text_usage" + reason: semantic text stats added to inference usage + +--- +"Test usage given default endpoints": + + - do: + xpack.usage: {} + + - match: { inference.available: true } + - match: { inference.enabled: true } + - length: { inference.models: 10 } + + - match: { inference.models.0.service: "_all" } + - match: { inference.models.0.task_type: "CHAT_COMPLETION" } + - match: { inference.models.0.count: 0 } + - not_exists: inference.models.0.semantic_text + + - match: { inference.models.1.service: "_all" } + - match: { inference.models.1.task_type: "COMPLETION" } + - match: { inference.models.1.count: 0 } + - not_exists: inference.models.1.semantic_text + + - match: { inference.models.2.service: "_all" } + - match: { inference.models.2.task_type: "RERANK" } + - match: { inference.models.2.count: 1 } + - not_exists: inference.models.2.semantic_text + + - match: { inference.models.3.service: "_all" } + - match: { inference.models.3.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.3.count: 1 } + - match: { inference.models.3.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + + - match: { inference.models.4.service: "_all" } + - match: { inference.models.4.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.4.count: 1 } + - match: { inference.models.4.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + + - match: { 
inference.models.5.service: "_elasticsearch__elser_model_2" } + - match: { inference.models.5.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.5.count: 1 } + - match: { inference.models.5.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + + - match: { inference.models.6.service: "_elasticsearch__multilingual-e5-small" } + - match: { inference.models.6.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.6.count: 1 } + - match: { inference.models.6.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + + - match: { inference.models.7.service: "elasticsearch" } + - match: { inference.models.7.task_type: "RERANK" } + - match: { inference.models.7.count: 1 } + - not_exists: inference.models.7.semantic_text + + - match: { inference.models.8.service: "elasticsearch" } + - match: { inference.models.8.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.8.count: 1 } + - match: { inference.models.8.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + + - match: { inference.models.9.service: "elasticsearch" } + - match: { inference.models.9.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.9.count: 1 } + - match: { inference.models.9.semantic_text: { field_count: 0, indices_count: 0, inference_id_count: 0} } + +--- +"Test usage given default endpoints and semantic_text fields": + + - do: + indices.create: + index: test-index-1 + body: + mappings: + properties: + field_1: + type: semantic_text + field_2: + type: semantic_text + inference_id: .multilingual-e5-small-elasticsearch + + - do: + indices.create: + index: test-index-2 + body: + mappings: + properties: + field_1: + type: semantic_text + + - do: + xpack.usage: {} + + - match: { inference.available: true } + - match: { inference.enabled: true } + - length: { inference.models: 10 } + + - match: { inference.models.0.service: "_all" } + - match: { inference.models.0.task_type: "CHAT_COMPLETION" } + - match: { 
inference.models.0.count: 0 } + - not_exists: inference.models.0.semantic_text + + - match: { inference.models.1.service: "_all" } + - match: { inference.models.1.task_type: "COMPLETION" } + - match: { inference.models.1.count: 0 } + - not_exists: inference.models.1.semantic_text + + - match: { inference.models.2.service: "_all" } + - match: { inference.models.2.task_type: "RERANK" } + - match: { inference.models.2.count: 1 } + - not_exists: inference.models.2.semantic_text + + - match: { inference.models.3.service: "_all" } + - match: { inference.models.3.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.3.count: 1 } + - match: { inference.models.3.semantic_text: { field_count: 2, indices_count: 2, inference_id_count: 1} } + + - match: { inference.models.4.service: "_all" } + - match: { inference.models.4.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.4.count: 1 } + - match: { inference.models.4.semantic_text: { field_count: 1, indices_count: 1, inference_id_count: 1} } + + - match: { inference.models.5.service: "_elasticsearch__elser_model_2" } + - match: { inference.models.5.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.5.count: 1 } + - match: { inference.models.5.semantic_text: { field_count: 2, indices_count: 2, inference_id_count: 1} } + + - match: { inference.models.6.service: "_elasticsearch__multilingual-e5-small" } + - match: { inference.models.6.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.6.count: 1 } + - match: { inference.models.6.semantic_text: { field_count: 1, indices_count: 1, inference_id_count: 1} } + + - match: { inference.models.7.service: "elasticsearch" } + - match: { inference.models.7.task_type: "RERANK" } + - match: { inference.models.7.count: 1 } + - not_exists: inference.models.7.semantic_text + + - match: { inference.models.8.service: "elasticsearch" } + - match: { inference.models.8.task_type: "SPARSE_EMBEDDING" } + - match: { inference.models.8.count: 1 } + - match: { 
inference.models.8.semantic_text: { field_count: 2, indices_count: 2, inference_id_count: 1} } + + - match: { inference.models.9.service: "elasticsearch" } + - match: { inference.models.9.task_type: "TEXT_EMBEDDING" } + - match: { inference.models.9.count: 1 } + - match: { inference.models.9.semantic_text: { field_count: 1, indices_count: 1, inference_id_count: 1} }