Commit 809631b

Only add semantic_text stats if task_type is compatible
1 parent 6f3b519 commit 809631b

7 files changed: +140 additions, -101 deletions

server/src/main/java/org/elasticsearch/inference/TaskType.java

Lines changed: 20 additions & 2 deletions
@@ -21,8 +21,8 @@
 import java.util.Objects;
 
 public enum TaskType implements Writeable {
-    TEXT_EMBEDDING,
-    SPARSE_EMBEDDING,
+    TEXT_EMBEDDING(true),
+    SPARSE_EMBEDDING(true),
     RERANK,
     COMPLETION,
     ANY {
@@ -52,6 +52,16 @@ public static TaskType fromStringOrStatusException(String name) {
         }
     }
 
+    private final boolean isCompatibleWithSemanticText;
+
+    TaskType(boolean isCompatibleWithSemanticText) {
+        this.isCompatibleWithSemanticText = isCompatibleWithSemanticText;
+    }
+
+    TaskType() {
+        this(false);
+    }
+
     /**
      * Return true if the {@code other} is the {@link #ANY} type
      * or the same as this.
@@ -62,6 +72,14 @@ public boolean isAnyOrSame(TaskType other) {
         return other == TaskType.ANY || other == this;
     }
 
+    /**
+     * Returns true if this task type is compatible with semantic text.
+     * @return True if this task type is compatible with semantic text.
+     */
+    public boolean isCompatibleWithSemanticText() {
+        return isCompatibleWithSemanticText;
+    }
+
     @Override
     public String toString() {
         return name().toLowerCase(Locale.ROOT);
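
For context, the flag introduced above lets callers check semantic_text compatibility on the task type itself instead of hard-coding the embedding types. A minimal sketch of how the accessor behaves after this change; the demo class below is not part of the commit and only assumes the Elasticsearch server module is on the classpath:

import org.elasticsearch.inference.TaskType;

public class SemanticTextCompatibilityDemo {
    public static void main(String[] args) {
        // Prints true only for the two embedding task types marked with (true) above;
        // every other constant uses the no-arg constructor and therefore reports false.
        for (TaskType taskType : TaskType.values()) {
            System.out.println(taskType + " -> " + taskType.isCompatibleWithSemanticText());
        }
    }
}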

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/usage/ModelStats.java

Lines changed: 13 additions & 8 deletions
@@ -11,6 +11,7 @@
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
 import org.elasticsearch.features.NodeFeature;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.xcontent.ToXContentObject;
@@ -30,31 +31,32 @@ public class ModelStats implements ToXContentObject, Writeable {
     private final String service;
     private final TaskType taskType;
     private long count;
+    @Nullable
     private final SemanticTextStats semanticTextStats;
 
     public ModelStats(String service, TaskType taskType) {
-        this(service, taskType, 0L, new SemanticTextStats());
+        this(service, taskType, 0L);
     }
 
     public ModelStats(String service, TaskType taskType, long count) {
-        this(service, taskType, count, new SemanticTextStats());
+        this(service, taskType, count, taskType.isCompatibleWithSemanticText() ? new SemanticTextStats() : null);
     }
 
-    public ModelStats(String service, TaskType taskType, long count, SemanticTextStats semanticTextStats) {
+    public ModelStats(String service, TaskType taskType, long count, @Nullable SemanticTextStats semanticTextStats) {
         this.service = service;
         this.taskType = taskType;
         this.count = count;
-        this.semanticTextStats = Objects.requireNonNull(semanticTextStats);
+        this.semanticTextStats = semanticTextStats;
     }
 
     public ModelStats(StreamInput in) throws IOException {
         this.service = in.readString();
         this.taskType = in.readEnum(TaskType.class);
         this.count = in.readLong();
         if (in.getTransportVersion().supports(INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS)) {
-            this.semanticTextStats = new SemanticTextStats(in);
+            this.semanticTextStats = in.readOptional(SemanticTextStats::new);
         } else {
-            semanticTextStats = new SemanticTextStats();
+            this.semanticTextStats = null;
         }
     }
 
@@ -74,6 +76,7 @@ public long count() {
         return count;
     }
 
+    @Nullable
     public SemanticTextStats semanticTextStats() {
         return semanticTextStats;
     }
@@ -90,7 +93,9 @@ public void addXContentFragment(XContentBuilder builder, Params params) throws I
         builder.field("service", service);
         builder.field("task_type", taskType.name());
         builder.field("count", count);
-        builder.field("semantic_text", semanticTextStats);
+        if (semanticTextStats != null) {
+            builder.field("semantic_text", semanticTextStats);
+        }
     }
 
     @Override
@@ -99,7 +104,7 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeEnum(taskType);
         out.writeLong(count);
         if (out.getTransportVersion().supports(INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS)) {
-            semanticTextStats.writeTo(out);
+            out.writeOptionalWriteable(semanticTextStats);
         }
     }
 
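
The practical effect of making the field nullable is that stats rows for incompatible task types simply omit the semantic_text object, both on the wire (writeOptionalWriteable) and in the usage XContent. A rough sketch under the assumption that the x-pack usage classes are available; the service name "my-service" and the counts are made up:

// Uses the (service, taskType, count) constructor, which after this change only
// creates SemanticTextStats when the task type supports semantic_text.
ModelStats rerankStats = new ModelStats("my-service", TaskType.RERANK, 3L);
assert rerankStats.semanticTextStats() == null;    // no "semantic_text" field is emitted

ModelStats embeddingStats = new ModelStats("my-service", TaskType.TEXT_EMBEDDING, 3L);
assert embeddingStats.semanticTextStats() != null; // "semantic_text" is still reported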

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/usage/ModelStatsTests.java

Lines changed: 4 additions & 3 deletions
@@ -69,18 +69,19 @@ public void testAdd() {
     }
 
     public static ModelStats createRandomInstance() {
+        TaskType taskType = randomValueOtherThan(TaskType.ANY, () -> randomFrom(TaskType.values()));
         return new ModelStats(
             randomIdentifier(),
-            randomFrom(TaskType.values()),
+            taskType,
             randomLong(),
-            SemanticTextStatsTests.createRandomInstance()
+            taskType.isCompatibleWithSemanticText() ? SemanticTextStatsTests.createRandomInstance() : null
         );
     }
 
     @Override
     protected ModelStats mutateInstanceForVersion(ModelStats instance, TransportVersion version) {
         if (version.supports(ModelStats.INFERENCE_TELEMETRY_ADDED_SEMANTIC_TEXT_STATS) == false) {
-            return new ModelStats(instance.service(), instance.taskType(), instance.count(), new SemanticTextStats());
+            return new ModelStats(instance.service(), instance.taskType(), instance.count(), null);
         }
         return instance;
     }

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 28 additions & 18 deletions
@@ -95,7 +95,7 @@ private InferenceFeatureSetUsage collectUsage(List<ModelConfigurations> endpoint
             mapInferenceFieldsByIndexServiceAndTask(indicesMetadata, endpoints);
         Map<String, ModelStats> endpointStats = new TreeMap<>();
         addStatsByServiceAndTask(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats);
-        addStatsForDefaultModels(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats);
+        addStatsForDefaultModelsCompatibleWithSemanticText(inferenceFieldsByIndexServiceAndTask, endpoints, endpointStats);
         return new InferenceFeatureSetUsage(endpointStats.values());
     }
 
@@ -159,10 +159,10 @@ private static void addStatsByServiceAndTask(
                 endpointStats.get(serviceAndTaskType.toString())
             )
         );
-        addTopLevelSemanticTextStatsByTask(inferenceFieldsByIndexServiceAndTask, endpointStats);
+        addTopLevelStatsByTask(inferenceFieldsByIndexServiceAndTask, endpointStats);
     }
 
-    private static void addTopLevelSemanticTextStatsByTask(
+    private static void addTopLevelStatsByTask(
         Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
         Map<String, ModelStats> endpointStats
     ) {
@@ -174,14 +174,20 @@ private static void addTopLevelSemanticTextStatsByTask(
                 new ServiceAndTaskType(Metadata.ALL, taskType).toString(),
                 key -> new ModelStats(Metadata.ALL, taskType)
             );
-            Map<String, List<InferenceFieldMetadata>> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask.entrySet()
-                .stream()
-                .filter(e -> e.getKey().taskType == taskType)
-                .flatMap(m -> m.getValue().entrySet().stream())
-                .collect(
-                    Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).toList())
-                );
-            addSemanticTextStats(inferenceFieldsByIndex, allStatsForTaskType);
+            if (taskType.isCompatibleWithSemanticText()) {
+                Map<String, List<InferenceFieldMetadata>> inferenceFieldsByIndex = inferenceFieldsByIndexServiceAndTask.entrySet()
+                    .stream()
+                    .filter(e -> e.getKey().taskType == taskType)
+                    .flatMap(m -> m.getValue().entrySet().stream())
+                    .collect(
+                        Collectors.toMap(
+                            Map.Entry::getKey,
+                            Map.Entry::getValue,
+                            (l1, l2) -> Stream.concat(l1.stream(), l2.stream()).toList()
+                        )
+                    );
+                addSemanticTextStats(inferenceFieldsByIndex, allStatsForTaskType);
+            }
         }
     }
 
@@ -196,20 +202,21 @@ private static void addSemanticTextStats(Map<String, List<InferenceFieldMetadata
     }
 
     /**
-     * Adds stats for default models. In particular, default models are considered models that are
-     * associated with default inference endpoints as per the {@code ModelRegistry}. The service name
-     * for default model stats is "_{service}_{modelId}". Each of those stats contains usage for all
-     * endpoints that use that model, including non-default endpoints.
+     * Adds stats for default models that are compatible with semantic_text.
+     * In particular, default models are considered models that are associated with default inference
+     * endpoints as per the {@code ModelRegistry}. The service name for default model stats is "_{service}_{modelId}".
+     * Each of those stats contains usage for all endpoints that use that model, including non-default endpoints.
      */
-    private void addStatsForDefaultModels(
+    private void addStatsForDefaultModelsCompatibleWithSemanticText(
         Map<ServiceAndTaskType, Map<String, List<InferenceFieldMetadata>>> inferenceFieldsByIndexServiceAndTask,
         List<ModelConfigurations> endpoints,
         Map<String, ModelStats> endpointStats
     ) {
         Map<String, String> endpointIdToModelId = endpoints.stream()
             .filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
             .collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, e -> stripLinuxSuffix(e.getServiceSettings().modelId())));
-        Map<DefaultModelStatsKey, Long> defaultModelsToEndpointCount = createDefaultStatsKeysWithEndpointCounts(endpoints);
+        Map<DefaultModelStatsKey, Long> defaultModelsToEndpointCount =
+            createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText(endpoints);
         for (Map.Entry<DefaultModelStatsKey, Long> defaultModelStatsKeyToEndpointCount : defaultModelsToEndpointCount.entrySet()) {
             DefaultModelStatsKey statKey = defaultModelStatsKeyToEndpointCount.getKey();
             Map<String, List<InferenceFieldMetadata>> fieldsByIndex = inferenceFieldsByIndexServiceAndTask.getOrDefault(
@@ -225,11 +232,14 @@ private void addStatsForDefaultModels(
         }
     }
 
-    private Map<DefaultModelStatsKey, Long> createDefaultStatsKeysWithEndpointCounts(List<ModelConfigurations> endpoints) {
+    private Map<DefaultModelStatsKey, Long> createStatsKeysWithEndpointCountsForDefaultModelsCompatibleWithSemanticText(
+        List<ModelConfigurations> endpoints
+    ) {
         // We consider models to be default if they are associated with a default inference endpoint.
         // Note that endpoints could have a null model id, in which case we don't consider them default as this
         // may only happen for external services.
         Set<String> modelIds = endpoints.stream()
+            .filter(endpoint -> endpoint.getTaskType().isCompatibleWithSemanticText())
            .filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId()))
            .filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
            .map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId()))
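
The new filter on endpoint.getTaskType().isCompatibleWithSemanticText() means endpoints whose task type cannot back a semantic_text field (rerank, completion, and similar) no longer contribute default-model stats keys. A small illustration of that gating in isolation, as a method-body fragment rather than a complete class; the list of endpoint task types is hypothetical:

// Only semantic_text-compatible endpoints survive the filter added above.
List<TaskType> endpointTaskTypes = List.of(TaskType.TEXT_EMBEDDING, TaskType.RERANK, TaskType.SPARSE_EMBEDDING, TaskType.COMPLETION);
List<TaskType> countedForDefaultModelStats = endpointTaskTypes.stream()
    .filter(TaskType::isCompatibleWithSemanticText)
    .toList();
// countedForDefaultModelStats == [TEXT_EMBEDDING, SPARSE_EMBEDDING]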

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/TaskTypeTests.java

Lines changed: 10 additions & 0 deletions
@@ -12,6 +12,8 @@
 import org.elasticsearch.test.ESTestCase;
 import org.hamcrest.Matchers;
 
+import static org.hamcrest.core.Is.is;
+
 public class TaskTypeTests extends ESTestCase {
 
     public void testFromStringOrStatusException() {
@@ -24,4 +26,12 @@ public void testFromStringOrStatusException() {
         assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY));
     }
 
+    public void testIsCompatibleWithSemanticText() {
+        assertThat(TaskType.ANY.isCompatibleWithSemanticText(), is(false));
+        assertThat(TaskType.CHAT_COMPLETION.isCompatibleWithSemanticText(), is(false));
+        assertThat(TaskType.COMPLETION.isCompatibleWithSemanticText(), is(false));
+        assertThat(TaskType.RERANK.isCompatibleWithSemanticText(), is(false));
+        assertThat(TaskType.TEXT_EMBEDDING.isCompatibleWithSemanticText(), is(true));
+        assertThat(TaskType.SPARSE_EMBEDDING.isCompatibleWithSemanticText(), is(true));
+    }
 }
