Skip to content

Commit fa9e19e

Browse files
Extract inferece model stats into its own file (#134634)
This commit extracts the inner class `ModelStats` of `InferenceFeatureSetUsage` into its own file. There are 2 reasons for doing this: - Allows testing of the `InferenceFeatureSetUsage` class - Introduces a `usage` package under `inference` that can be used to accommodate additional usage data.
1 parent 4ca03f7 commit fa9e19e

File tree

6 files changed

+159
-102
lines changed

6 files changed

+159
-102
lines changed

x-pack/plugin/core/src/main/java/module-info.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
exports org.elasticsearch.xpack.core.indexing;
8080
exports org.elasticsearch.xpack.core.inference.action;
8181
exports org.elasticsearch.xpack.core.inference.results;
82+
exports org.elasticsearch.xpack.core.inference.usage;
8283
exports org.elasticsearch.xpack.core.inference;
8384
exports org.elasticsearch.xpack.core.logstash;
8485
exports org.elasticsearch.xpack.core.ml.action;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsage.java

Lines changed: 5 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@
1111
import org.elasticsearch.TransportVersions;
1212
import org.elasticsearch.common.io.stream.StreamInput;
1313
import org.elasticsearch.common.io.stream.StreamOutput;
14-
import org.elasticsearch.common.io.stream.Writeable;
15-
import org.elasticsearch.inference.TaskType;
16-
import org.elasticsearch.xcontent.ToXContentObject;
1714
import org.elasticsearch.xcontent.XContentBuilder;
1815
import org.elasticsearch.xpack.core.XPackFeatureUsage;
1916
import org.elasticsearch.xpack.core.XPackField;
17+
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
2018

2119
import java.io.IOException;
2220
import java.util.Collection;
@@ -25,83 +23,6 @@
2523

2624
public class InferenceFeatureSetUsage extends XPackFeatureUsage {
2725

28-
public static class ModelStats implements ToXContentObject, Writeable {
29-
30-
private final String service;
31-
private final TaskType taskType;
32-
private long count;
33-
34-
public ModelStats(String service, TaskType taskType) {
35-
this(service, taskType, 0L);
36-
}
37-
38-
public ModelStats(String service, TaskType taskType, long count) {
39-
this.service = service;
40-
this.taskType = taskType;
41-
this.count = count;
42-
}
43-
44-
public ModelStats(ModelStats stats) {
45-
this(stats.service, stats.taskType, stats.count);
46-
}
47-
48-
public ModelStats(StreamInput in) throws IOException {
49-
this.service = in.readString();
50-
this.taskType = in.readEnum(TaskType.class);
51-
this.count = in.readLong();
52-
}
53-
54-
public void add() {
55-
count++;
56-
}
57-
58-
public String service() {
59-
return service;
60-
}
61-
62-
public TaskType taskType() {
63-
return taskType;
64-
}
65-
66-
public long count() {
67-
return count;
68-
}
69-
70-
@Override
71-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
72-
builder.startObject();
73-
addXContentFragment(builder, params);
74-
builder.endObject();
75-
return builder;
76-
}
77-
78-
public void addXContentFragment(XContentBuilder builder, Params params) throws IOException {
79-
builder.field("service", service);
80-
builder.field("task_type", taskType.name());
81-
builder.field("count", count);
82-
}
83-
84-
@Override
85-
public void writeTo(StreamOutput out) throws IOException {
86-
out.writeString(service);
87-
out.writeEnum(taskType);
88-
out.writeLong(count);
89-
}
90-
91-
@Override
92-
public boolean equals(Object o) {
93-
if (this == o) return true;
94-
if (o == null || getClass() != o.getClass()) return false;
95-
ModelStats that = (ModelStats) o;
96-
return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType;
97-
}
98-
99-
@Override
100-
public int hashCode() {
101-
return Objects.hash(service, taskType, count);
102-
}
103-
}
104-
10526
public static final InferenceFeatureSetUsage EMPTY = new InferenceFeatureSetUsage(List.of());
10627

10728
private final Collection<ModelStats> modelStats;
@@ -144,4 +65,8 @@ public boolean equals(Object o) {
14465
public int hashCode() {
14566
return Objects.hashCode(modelStats);
14667
}
68+
69+
Collection<ModelStats> modelStats() {
70+
return modelStats;
71+
}
14772
}
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.usage;
9+
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.inference.TaskType;
14+
import org.elasticsearch.xcontent.ToXContentObject;
15+
import org.elasticsearch.xcontent.XContentBuilder;
16+
17+
import java.io.IOException;
18+
import java.util.Objects;
19+
20+
public class ModelStats implements ToXContentObject, Writeable {
21+
22+
private final String service;
23+
private final TaskType taskType;
24+
private long count;
25+
26+
public ModelStats(String service, TaskType taskType) {
27+
this(service, taskType, 0L);
28+
}
29+
30+
public ModelStats(String service, TaskType taskType, long count) {
31+
this.service = service;
32+
this.taskType = taskType;
33+
this.count = count;
34+
}
35+
36+
public ModelStats(ModelStats stats) {
37+
this(stats.service, stats.taskType, stats.count);
38+
}
39+
40+
public ModelStats(StreamInput in) throws IOException {
41+
this.service = in.readString();
42+
this.taskType = in.readEnum(TaskType.class);
43+
this.count = in.readLong();
44+
}
45+
46+
public void add() {
47+
count++;
48+
}
49+
50+
public String service() {
51+
return service;
52+
}
53+
54+
public TaskType taskType() {
55+
return taskType;
56+
}
57+
58+
public long count() {
59+
return count;
60+
}
61+
62+
@Override
63+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
64+
builder.startObject();
65+
addXContentFragment(builder, params);
66+
builder.endObject();
67+
return builder;
68+
}
69+
70+
public void addXContentFragment(XContentBuilder builder, Params params) throws IOException {
71+
builder.field("service", service);
72+
builder.field("task_type", taskType.name());
73+
builder.field("count", count);
74+
}
75+
76+
@Override
77+
public void writeTo(StreamOutput out) throws IOException {
78+
out.writeString(service);
79+
out.writeEnum(taskType);
80+
out.writeLong(count);
81+
}
82+
83+
@Override
84+
public boolean equals(Object o) {
85+
if (this == o) return true;
86+
if (o == null || getClass() != o.getClass()) return false;
87+
ModelStats that = (ModelStats) o;
88+
return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType;
89+
}
90+
91+
@Override
92+
public int hashCode() {
93+
return Objects.hash(service, taskType, count);
94+
}
95+
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/InferenceFeatureSetUsageTests.java

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,35 @@
77

88
package org.elasticsearch.xpack.core.inference;
99

10-
import com.carrotsearch.randomizedtesting.generators.RandomStrings;
11-
1210
import org.elasticsearch.common.io.stream.Writeable;
13-
import org.elasticsearch.inference.TaskType;
1411
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
13+
import org.elasticsearch.xpack.core.inference.usage.ModelStatsTests;
1514

1615
import java.io.IOException;
16+
import java.util.ArrayList;
17+
import java.util.List;
1718

18-
public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase<InferenceFeatureSetUsage.ModelStats> {
19+
public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase<InferenceFeatureSetUsage> {
1920

2021
@Override
21-
protected Writeable.Reader<InferenceFeatureSetUsage.ModelStats> instanceReader() {
22-
return InferenceFeatureSetUsage.ModelStats::new;
22+
protected Writeable.Reader<InferenceFeatureSetUsage> instanceReader() {
23+
return InferenceFeatureSetUsage::new;
2324
}
2425

2526
@Override
26-
protected InferenceFeatureSetUsage.ModelStats createTestInstance() {
27-
RandomStrings.randomAsciiLettersOfLength(random(), 10);
28-
return new InferenceFeatureSetUsage.ModelStats(
29-
randomIdentifier(),
30-
TaskType.values()[randomInt(TaskType.values().length - 1)],
31-
randomInt(10)
32-
);
27+
protected InferenceFeatureSetUsage createTestInstance() {
28+
return new InferenceFeatureSetUsage(randomList(10, ModelStatsTests::createRandomInstance));
3329
}
3430

3531
@Override
36-
protected InferenceFeatureSetUsage.ModelStats mutateInstance(InferenceFeatureSetUsage.ModelStats modelStats) throws IOException {
37-
InferenceFeatureSetUsage.ModelStats newModelStats = new InferenceFeatureSetUsage.ModelStats(modelStats);
38-
newModelStats.add();
39-
return newModelStats;
32+
protected InferenceFeatureSetUsage mutateInstance(InferenceFeatureSetUsage instance) throws IOException {
33+
List<ModelStats> mutatedModelStats = new ArrayList<>(instance.modelStats());
34+
if (mutatedModelStats.isEmpty()) {
35+
mutatedModelStats.add(ModelStatsTests.createRandomInstance());
36+
} else {
37+
mutatedModelStats.remove(randomIntBetween(0, mutatedModelStats.size() - 1));
38+
}
39+
return new InferenceFeatureSetUsage(mutatedModelStats);
4040
}
4141
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.inference.usage;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.inference.TaskType;
12+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
13+
14+
import java.io.IOException;
15+
16+
public class ModelStatsTests extends AbstractWireSerializingTestCase<ModelStats> {
17+
18+
@Override
19+
protected Writeable.Reader<ModelStats> instanceReader() {
20+
return ModelStats::new;
21+
}
22+
23+
@Override
24+
protected ModelStats createTestInstance() {
25+
return createRandomInstance();
26+
}
27+
28+
@Override
29+
protected ModelStats mutateInstance(ModelStats modelStats) throws IOException {
30+
ModelStats newModelStats = new ModelStats(modelStats);
31+
newModelStats.add();
32+
return newModelStats;
33+
}
34+
35+
public static ModelStats createRandomInstance() {
36+
return new ModelStats(randomIdentifier(), TaskType.values()[randomInt(TaskType.values().length - 1)], randomInt(10));
37+
}
38+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction;
2929
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
3030
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
31+
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
3132

3233
import java.util.Map;
3334
import java.util.TreeMap;
@@ -61,13 +62,10 @@ protected void localClusterStateOperation(
6162
) {
6263
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
6364
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, ActionListener.wrap(response -> {
64-
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
65+
Map<String, ModelStats> stats = new TreeMap<>();
6566
for (ModelConfigurations model : response.getEndpoints()) {
6667
String statKey = model.getService() + ":" + model.getTaskType().name();
67-
InferenceFeatureSetUsage.ModelStats stat = stats.computeIfAbsent(
68-
statKey,
69-
key -> new InferenceFeatureSetUsage.ModelStats(model.getService(), model.getTaskType())
70-
);
68+
ModelStats stat = stats.computeIfAbsent(statKey, key -> new ModelStats(model.getService(), model.getTaskType()));
7169
stat.add();
7270
}
7371
InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values());

0 commit comments

Comments
 (0)