1919import org .elasticsearch .cluster .metadata .Metadata ;
2020import org .elasticsearch .cluster .service .ClusterService ;
2121import org .elasticsearch .common .Strings ;
22+ import org .elasticsearch .core .Nullable ;
2223import org .elasticsearch .inference .ModelConfigurations ;
2324import org .elasticsearch .inference .TaskType ;
2425import org .elasticsearch .injection .guice .Inject ;
@@ -52,6 +53,9 @@ public class TransportInferenceUsageAction extends XPackUsageFeatureTransportAct
5253
5354 private final Logger logger = LogManager .getLogger (TransportInferenceUsageAction .class );
5455
56+ // Some of the default models have optimized variants for linux that will have the following suffix.
57+ private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64" ;
58+
5559 private final ModelRegistry modelRegistry ;
5660 private final Client client ;
5761
@@ -190,7 +194,7 @@ private void addStatsForDefaultModels(
190194 ) {
191195 Map <String , String > endpointIdToModelId = endpoints .stream ()
192196 .filter (endpoint -> endpoint .getServiceSettings ().modelId () != null )
193- .collect (Collectors .toMap (ModelConfigurations ::getInferenceEntityId , e -> e .getServiceSettings ().modelId ()));
197+ .collect (Collectors .toMap (ModelConfigurations ::getInferenceEntityId , e -> stripLinuxSuffix ( e .getServiceSettings ().modelId () )));
194198 Map <DefaultModelStatsKey , Long > defaultModelsToEndpointCount = createDefaultStatsKeysWithEndpointCounts (endpoints );
195199 for (Map .Entry <DefaultModelStatsKey , Long > defaultModelStatsKeyToEndpointCount : defaultModelsToEndpointCount .entrySet ()) {
196200 DefaultModelStatsKey statKey = defaultModelStatsKeyToEndpointCount .getKey ();
@@ -208,12 +212,18 @@ private void addStatsForDefaultModels(
208212 private Map <DefaultModelStatsKey , Long > createDefaultStatsKeysWithEndpointCounts (List <ModelConfigurations > endpoints ) {
209213 Set <String > modelIds = endpoints .stream ()
210214 .filter (endpoint -> modelRegistry .containsDefaultConfigId (endpoint .getInferenceEntityId ()))
211- .map (endpoint -> endpoint .getServiceSettings ().modelId ())
215+ .filter (endpoint -> endpoint .getServiceSettings ().modelId () != null )
216+ .map (endpoint -> stripLinuxSuffix (endpoint .getServiceSettings ().modelId ()))
212217 .collect (Collectors .toSet ());
213218 return endpoints .stream ()
214- .filter (endpoint -> modelIds .contains (endpoint .getServiceSettings ().modelId ()))
219+ .filter (endpoint -> endpoint .getServiceSettings ().modelId () != null )
220+ .filter (endpoint -> modelIds .contains (stripLinuxSuffix (endpoint .getServiceSettings ().modelId ())))
215221 .map (
216- endpoint -> new DefaultModelStatsKey (endpoint .getService (), endpoint .getTaskType (), endpoint .getServiceSettings ().modelId ())
222+ endpoint -> new DefaultModelStatsKey (
223+ endpoint .getService (),
224+ endpoint .getTaskType (),
225+ stripLinuxSuffix (endpoint .getServiceSettings ().modelId ())
226+ )
217227 )
218228 .collect (Collectors .groupingBy (Function .identity (), Collectors .counting ()));
219229 }
@@ -232,22 +242,20 @@ private static Map<String, List<InferenceFieldMetadata>> filterFields(
232242 return filtered ;
233243 }
234244
235- private record DefaultModelStatsKey (String service , TaskType taskType , String modelId ) {
245+ @ Nullable
246+ private static String stripLinuxSuffix (@ Nullable String modelId ) {
247+ if (modelId .endsWith (MODEL_ID_LINUX_SUFFIX )) {
248+ return modelId .substring (0 , modelId .length () - MODEL_ID_LINUX_SUFFIX .length ());
249+ }
250+ return modelId ;
251+ }
236252
237- // Some of the default models have optimized variants for linux that will have the following suffix.
238- private static final String MODEL_ID_LINUX_SUFFIX = "_linux-x86_64" ;
253+ private record DefaultModelStatsKey (String service , TaskType taskType , String modelId ) {
239254
240255 @ Override
241256 public String toString () {
242257 // Inference ids cannot start with '_'. Thus, default stats do to avoid conflicts with user-defined inference ids.
243- return "_" + service + "_" + stripLinuxSuffix (modelId ).replace ('.' , '_' );
244- }
245-
246- private static String stripLinuxSuffix (String modelId ) {
247- if (modelId .endsWith (MODEL_ID_LINUX_SUFFIX )) {
248- return modelId .substring (0 , modelId .length () - MODEL_ID_LINUX_SUFFIX .length ());
249- }
250- return modelId ;
258+ return "_" + service + "_" + modelId .replace ('.' , '_' );
251259 }
252260 }
253261
0 commit comments