Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,30 +102,92 @@ public int hashCode() {
}
}

public static final InferenceFeatureSetUsage EMPTY = new InferenceFeatureSetUsage(List.of());
/**
 * Usage stats for {@code semantic_text} fields: total field/index counts plus
 * per-embedding-type (sparse vs dense) field and inference-id counts.
 */
public static class SemanticTextStats implements ToXContentObject, Writeable {
    private final Long totalFieldCount;
    private final Long indexCount;
    private final Long denseFieldCount;
    private final Long sparseFieldCount;
    private final Long denseInferenceIdCount;
    private final Long sparseInferenceIdCount;

    /**
     * @param totalFieldCount        total number of semantic_text fields across all indices
     * @param indexCount             number of indices with at least one semantic_text field
     * @param sparseFieldCount       fields backed by sparse-embedding inference endpoints
     * @param denseFieldCount        fields backed by dense-embedding inference endpoints
     * @param denseInferenceIdCount  distinct dense inference ids referenced by fields
     * @param sparseInferenceIdCount distinct sparse inference ids referenced by fields
     */
    public SemanticTextStats(
        Long totalFieldCount,
        Long indexCount,
        Long sparseFieldCount,
        Long denseFieldCount,
        Long denseInferenceIdCount,
        Long sparseInferenceIdCount
    ) {
        this.totalFieldCount = totalFieldCount;
        this.indexCount = indexCount;
        this.sparseFieldCount = sparseFieldCount;
        this.denseFieldCount = denseFieldCount;
        this.denseInferenceIdCount = denseInferenceIdCount;
        this.sparseInferenceIdCount = sparseInferenceIdCount;
    }

    public SemanticTextStats(StreamInput in) throws IOException {
        // Read order must stay in sync with writeTo.
        this.totalFieldCount = in.readLong();
        this.indexCount = in.readLong();
        this.denseFieldCount = in.readLong();
        this.denseInferenceIdCount = in.readLong();
        this.sparseInferenceIdCount = in.readLong();
        this.sparseFieldCount = in.readLong();
    }

    @Override
    public void writeTo(StreamOutput out) throws IOException {
        // Write order must stay in sync with the StreamInput constructor.
        // NOTE(review): writeLong unboxes; a null count would NPE here — confirm
        // callers never pass null counts.
        out.writeLong(totalFieldCount);
        out.writeLong(indexCount);
        out.writeLong(denseFieldCount);
        out.writeLong(denseInferenceIdCount);
        out.writeLong(sparseInferenceIdCount);
        out.writeLong(sparseFieldCount);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field("total_fields", totalFieldCount);
        builder.field("indices", indexCount);
        builder.field("dense_fields", denseFieldCount);
        builder.field("dense_inference_ids", denseInferenceIdCount);
        builder.field("sparse_fields", sparseFieldCount);
        builder.field("sparse_inference_ids", sparseInferenceIdCount);
        builder.endObject();
        return builder;
    }

    // equals/hashCode are required: the enclosing InferenceFeatureSetUsage compares
    // instances via Objects.equals, which would otherwise be reference identity.
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        SemanticTextStats that = (SemanticTextStats) o;
        return Objects.equals(totalFieldCount, that.totalFieldCount)
            && Objects.equals(indexCount, that.indexCount)
            && Objects.equals(denseFieldCount, that.denseFieldCount)
            && Objects.equals(sparseFieldCount, that.sparseFieldCount)
            && Objects.equals(denseInferenceIdCount, that.denseInferenceIdCount)
            && Objects.equals(sparseInferenceIdCount, that.sparseInferenceIdCount);
    }

    @Override
    public int hashCode() {
        return Objects.hash(
            totalFieldCount,
            indexCount,
            denseFieldCount,
            sparseFieldCount,
            denseInferenceIdCount,
            sparseInferenceIdCount
        );
    }
}

public static final InferenceFeatureSetUsage EMPTY = new InferenceFeatureSetUsage(List.of(), null);

private final Collection<ModelStats> modelStats;
private final SemanticTextStats semanticTextStats;

/**
 * @param modelStats        per-service/task-type endpoint counts
 * @param semanticTextStats semantic_text field usage stats
 */
public InferenceFeatureSetUsage(Collection<ModelStats> modelStats, SemanticTextStats semanticTextStats) {
    super(XPackField.INFERENCE, true, true);
    this.modelStats = modelStats;
    this.semanticTextStats = semanticTextStats;
}

/**
 * Reads usage from the wire. A SemanticTextStats is read unconditionally, so
 * the sender must always write one (writeTo does).
 */
public InferenceFeatureSetUsage(StreamInput in) throws IOException {
super(in);
this.modelStats = in.readCollectionAsList(ModelStats::new);
this.semanticTextStats = new SemanticTextStats(in);
}

@Override
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
super.innerXContent(builder, params);
// Per-service/task-type endpoint counts.
builder.xContentList("models", modelStats);
// semantic_text field usage stats.
builder.field("semantic_text", semanticTextStats);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeCollection(modelStats);
// NOTE(review): assumes semanticTextStats is non-null; a null (e.g. an EMPTY
// built with null stats) would NPE here — confirm it is always populated.
semanticTextStats.writeTo(out);
}

@Override
Expand All @@ -137,11 +199,11 @@ public TransportVersion getMinimalSupportedVersion() {
// Equality covers both model stats and semantic_text stats so that the two
// halves of the usage payload are compared consistently with hashCode.
public boolean equals(Object o) {
    if (o == null || getClass() != o.getClass()) return false;
    InferenceFeatureSetUsage that = (InferenceFeatureSetUsage) o;
    return Objects.equals(modelStats, that.modelStats) && Objects.equals(semanticTextStats, that.semanticTextStats);
}

@Override
public int hashCode() {
    // Must stay consistent with equals: include semanticTextStats.
    return Objects.hash(modelStats, semanticTextStats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.ModelConfigurations;
Expand All @@ -29,6 +31,8 @@
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

Expand Down Expand Up @@ -62,19 +66,74 @@ protected void localClusterStateOperation(
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, ActionListener.wrap(response -> {
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
for (ModelConfigurations model : response.getEndpoints()) {
List<ModelConfigurations> endpoints = response.getEndpoints();
for (ModelConfigurations model : endpoints) {
String statKey = model.getService() + ":" + model.getTaskType().name();
InferenceFeatureSetUsage.ModelStats stat = stats.computeIfAbsent(
statKey,
key -> new InferenceFeatureSetUsage.ModelStats(model.getService(), model.getTaskType())
);
stat.add();
}
InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values());

InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(
stats.values(),
getSemanticTextStats(state.getMetadata().indicesAllProjects(), endpoints)
);
listener.onResponse(new XPackUsageFeatureResponse(usage));
}, e -> {
logger.warn(Strings.format("Retrieving inference usage failed with error: %s", e.getMessage()), e);
listener.onResponse(new XPackUsageFeatureResponse(InferenceFeatureSetUsage.EMPTY));
}));
}

/**
 * Builds semantic_text usage stats by joining the inference-field metadata of
 * every index with the set of known inference endpoints.
 *
 * @param indicesMetadata     metadata for all indices (across all projects)
 * @param modelConfigurations all inference endpoints returned by GetInferenceModelAction
 * @return aggregated semantic_text field/index/inference-id counts
 */
private static InferenceFeatureSetUsage.SemanticTextStats getSemanticTextStats(
    Iterable<IndexMetadata> indicesMetadata,
    List<ModelConfigurations> modelConfigurations
) {
    long fieldCount = 0;
    long indexCount = 0;
    // Number of semantic_text fields referencing each inference id.
    Map<String, Long> inferenceIdsCounts = new HashMap<>();

    for (IndexMetadata indexMetadata : indicesMetadata) {
        Map<String, InferenceFieldMetadata> inferenceFields = indexMetadata.getInferenceFields();
        fieldCount += inferenceFields.size();
        if (inferenceFields.isEmpty() == false) {
            indexCount++;
        }
        for (InferenceFieldMetadata inferenceFieldMetadata : inferenceFields.values()) {
            // merge() replaces the containsKey/compute dance with the standard idiom.
            inferenceIdsCounts.merge(inferenceFieldMetadata.getInferenceId(), 1L, Long::sum);
        }
    }

    long sparseFieldsCount = 0;
    long denseFieldsCount = 0;
    long denseInferenceIdCount = 0;
    long sparseInferenceIdCount = 0;
    for (ModelConfigurations model : modelConfigurations) {
        // Single map lookup instead of containsKey followed by get.
        Long fieldsUsingId = inferenceIdsCounts.get(model.getInferenceEntityId());
        if (fieldsUsingId == null) {
            continue; // endpoint not referenced by any semantic_text field
        }
        if (model.getTaskType() == TaskType.SPARSE_EMBEDDING) {
            sparseFieldsCount += fieldsUsingId;
            sparseInferenceIdCount++;
        } else {
            // NOTE(review): every non-sparse task type is counted as dense here;
            // assumes semantic_text fields only reference sparse/dense embedding
            // endpoints — confirm.
            denseFieldsCount += fieldsUsingId;
            denseInferenceIdCount++;
        }
    }

    return new InferenceFeatureSetUsage.SemanticTextStats(
        fieldCount,
        indexCount,
        sparseFieldsCount,
        denseFieldsCount,
        denseInferenceIdCount,
        sparseInferenceIdCount
    );
}
}