Skip to content

Commit 4f79558

Browse files
Address evil edge case
1 parent 86e022e commit 4f79558

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageAction.java

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.cluster.metadata.Metadata;
2020
import org.elasticsearch.cluster.service.ClusterService;
2121
import org.elasticsearch.common.Strings;
22+
import org.elasticsearch.core.Nullable;
2223
import org.elasticsearch.inference.ModelConfigurations;
2324
import org.elasticsearch.inference.TaskType;
2425
import org.elasticsearch.injection.guice.Inject;
@@ -52,6 +53,9 @@ public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAct
5253

5354
private final Logger logger = LogManager.getLogger(TransportInferenceUsageAction.class);
5455

56+
// Some of the default models have optimized variants for linux that will have the following suffix.
57+
private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64";
58+
5559
private final ModelRegistry modelRegistry;
5660
private final Client client;
5761

@@ -190,7 +194,7 @@ private void addStatsForDefaultModels(
190194
) {
191195
Map<String, String> endpointIdToModelId = endpoints.stream()
192196
.filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
193-
.collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, e -> e.getServiceSettings().modelId()));
197+
.collect(Collectors.toMap(ModelConfigurations::getInferenceEntityId, e -> stripLinuxSuffix(e.getServiceSettings().modelId())));
194198
Map<DefaultModelStatsKey, Long> defaultModelsToEndpointCount = createDefaultStatsKeysWithEndpointCounts(endpoints);
195199
for (Map.Entry<DefaultModelStatsKey, Long> defaultModelStatsKeyToEndpointCount : defaultModelsToEndpointCount.entrySet()) {
196200
DefaultModelStatsKey statKey = defaultModelStatsKeyToEndpointCount.getKey();
@@ -208,12 +212,18 @@ private void addStatsForDefaultModels(
208212
private Map<DefaultModelStatsKey, Long> createDefaultStatsKeysWithEndpointCounts(List<ModelConfigurations> endpoints) {
209213
Set<String> modelIds = endpoints.stream()
210214
.filter(endpoint -> modelRegistry.containsDefaultConfigId(endpoint.getInferenceEntityId()))
211-
.map(endpoint -> endpoint.getServiceSettings().modelId())
215+
.filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
216+
.map(endpoint -> stripLinuxSuffix(endpoint.getServiceSettings().modelId()))
212217
.collect(Collectors.toSet());
213218
return endpoints.stream()
214-
.filter(endpoint -> modelIds.contains(endpoint.getServiceSettings().modelId()))
219+
.filter(endpoint -> endpoint.getServiceSettings().modelId() != null)
220+
.filter(endpoint -> modelIds.contains(stripLinuxSuffix(endpoint.getServiceSettings().modelId())))
215221
.map(
216-
endpoint -> new DefaultModelStatsKey(endpoint.getService(), endpoint.getTaskType(), endpoint.getServiceSettings().modelId())
222+
endpoint -> new DefaultModelStatsKey(
223+
endpoint.getService(),
224+
endpoint.getTaskType(),
225+
stripLinuxSuffix(endpoint.getServiceSettings().modelId())
226+
)
217227
)
218228
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
219229
}
@@ -232,22 +242,20 @@ private static Map<String, List<InferenceFieldMetadata>> filterFields(
232242
return filtered;
233243
}
234244

235-
private record DefaultModelStatsKey(String service, TaskType taskType, String modelId) {
245+
@Nullable
246+
private static String stripLinuxSuffix(@Nullable String modelId) {
247+
if (modelId.endsWith(MODEL_ID_LINUX_SUFFIX)) {
248+
return modelId.substring(0, modelId.length() - MODEL_ID_LINUX_SUFFIX.length());
249+
}
250+
return modelId;
251+
}
236252

237-
// Some of the default models have optimized variants for linux that will have the following suffix.
238-
private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64";
253+
private record DefaultModelStatsKey(String service, TaskType taskType, String modelId) {
239254

240255
@Override
241256
public String toString() {
242257
// Inference ids cannot start with '_', so default model stats names are prefixed with '_' to avoid conflicts with user-defined inference ids.
243-
return "_" + service + "_" + stripLinuxSuffix(modelId).replace('.', '_');
244-
}
245-
246-
private static String stripLinuxSuffix(String modelId) {
247-
if (modelId.endsWith(MODEL_ID_LINUX_SUFFIX)) {
248-
return modelId.substring(0, modelId.length() - MODEL_ID_LINUX_SUFFIX.length());
249-
}
250-
return modelId;
258+
return "_" + service + "_" + modelId.replace('.', '_');
251259
}
252260
}
253261

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceUsageActionTests.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,39 @@ public void testGivenDefaultModelWithLinuxSuffix() throws Exception {
253253
assertStats(response, 2, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2)));
254254
}
255255

256+
public void testGivenSameDefaultModelWithAndWithoutLinuxSuffix() throws Exception {
257+
givenInferenceEndpoints(
258+
new ModelConfigurations(".endpoint-001", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001_linux-x86_64")),
259+
new ModelConfigurations("endpoint-002", TaskType.TEXT_EMBEDDING, "eis", mockServiceSettings(".model-id-001"))
260+
);
261+
262+
givenDefaultEndpoints(".endpoint-001");
263+
264+
givenInferenceFields(
265+
Map.of(
266+
"index_1",
267+
List.of(
268+
new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null),
269+
new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null),
270+
new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null)
271+
),
272+
"index_2",
273+
List.of(
274+
new InferenceFieldMetadata("semantic-1", ".endpoint-001", new String[0], null),
275+
new InferenceFieldMetadata("semantic-2", ".endpoint-001", new String[0], null),
276+
new InferenceFieldMetadata("semantic-3", "endpoint-002", new String[0], null)
277+
)
278+
)
279+
);
280+
281+
XContentSource response = executeAction();
282+
283+
assertThat(response.getValue("models"), hasSize(3));
284+
assertStats(response, 0, new ModelStats("_all", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2)));
285+
assertStats(response, 1, new ModelStats("_eis__model-id-001", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2)));
286+
assertStats(response, 2, new ModelStats("eis", TaskType.TEXT_EMBEDDING, 2, new SemanticTextStats(6, 2, 2)));
287+
}
288+
256289
public void testGivenExternalServiceModelIsNull() throws Exception {
257290
givenInferenceEndpoints(new ModelConfigurations("endpoint-001", TaskType.TEXT_EMBEDDING, "openai", mockServiceSettings(null)));
258291
givenInferenceFields(Map.of("index_1", List.of(new InferenceFieldMetadata("semantic", "endpoint-001", new String[0], null))));

0 commit comments

Comments
 (0)