Skip to content

Commit 9b29009

Browse files
timgrein authored and afoucret committed
[Inference API] Record re-routing attributes as part of inference request metrics (elastic#122350)
Record inference API re-routing attributes as part of request metrics.
1 parent 24e016f commit 9b29009

File tree

5 files changed

+141
-43
lines changed

5 files changed

+141
-43
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 48 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,8 @@
4747
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
4848

4949
import java.io.IOException;
50+
import java.util.HashMap;
51+
import java.util.Map;
5052
import java.util.Random;
5153
import java.util.concurrent.Executor;
5254
import java.util.concurrent.Flow;
@@ -58,6 +60,7 @@
5860
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
5961
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
6062
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
63+
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes;
6164

6265
/**
6366
* Base class for transport actions that handle inference requests.
@@ -138,13 +141,14 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
138141
try {
139142
validateRequest(request, unparsedModel);
140143
} catch (Exception e) {
141-
recordMetrics(unparsedModel, timer, e);
144+
recordRequestDurationMetrics(unparsedModel, timer, e);
142145
listener.onFailure(e);
143146
return;
144147
}
145148

146149
var service = serviceRegistry.getService(serviceName).get();
147-
var routingDecision = determineRouting(serviceName, request, unparsedModel);
150+
var localNodeId = nodeClient.getLocalNodeId();
151+
var routingDecision = determineRouting(serviceName, request, unparsedModel, localNodeId);
148152

149153
if (routingDecision.currentNodeShouldHandleRequest()) {
150154
var model = service.parsePersistedConfigWithSecrets(
@@ -153,7 +157,7 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
153157
unparsedModel.settings(),
154158
unparsedModel.secrets()
155159
);
156-
inferOnServiceWithMetrics(model, request, service, timer, listener);
160+
inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);
157161
} else {
158162
// Reroute request
159163
request.setHasBeenRerouted(true);
@@ -187,7 +191,7 @@ private void validateRequest(Request request, UnparsedModel unparsedModel) {
187191
);
188192
}
189193

190-
private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
194+
private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel, String localNodeId) {
191195
var modelTaskType = unparsedModel.taskType();
192196

193197
// Rerouting not supported or request was already rerouted
@@ -211,7 +215,6 @@ private NodeRoutingDecision determineRouting(String serviceName, Request request
211215
}
212216

213217
var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));
214-
String localNodeId = nodeClient.getLocalNodeId();
215218

216219
// The drawn node is the current node
217220
if (nodeToHandleRequest.getId().equals(localNodeId)) {
@@ -257,9 +260,13 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
257260
);
258261
}
259262

260-
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
263+
private void recordRequestDurationMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
261264
try {
262-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
265+
Map<String, Object> metricAttributes = new HashMap<>();
266+
metricAttributes.putAll(modelAttributes(model));
267+
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
268+
269+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
263270
} catch (Exception e) {
264271
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
265272
}
@@ -270,9 +277,10 @@ private void inferOnServiceWithMetrics(
270277
Request request,
271278
InferenceService service,
272279
InferenceTimer timer,
280+
String localNodeId,
273281
ActionListener<InferenceAction.Response> listener
274282
) {
275-
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
283+
recordRequestCountMetrics(model, request, localNodeId);
276284
inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> {
277285
if (request.isStreaming()) {
278286
var taskProcessor = streamingTaskManager.<InferenceServiceResults.Result>create(
@@ -281,18 +289,18 @@ private void inferOnServiceWithMetrics(
281289
);
282290
inferenceResults.publisher().subscribe(taskProcessor);
283291

284-
var instrumentedStream = new PublisherWithMetrics(timer, model);
292+
var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
285293
taskProcessor.subscribe(instrumentedStream);
286294

287295
var streamErrorHandler = streamErrorHandler(instrumentedStream);
288296

289297
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
290298
} else {
291-
recordMetrics(model, timer, null);
299+
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
292300
listener.onResponse(new InferenceAction.Response(inferenceResults));
293301
}
294302
}, e -> {
295-
recordMetrics(model, timer, e);
303+
recordRequestDurationMetrics(model, timer, request, localNodeId, e);
296304
listener.onFailure(e);
297305
}));
298306
}
@@ -301,9 +309,28 @@ protected <T> Flow.Publisher<T> streamErrorHandler(Flow.Processor<T, T> upstream
301309
return upstream;
302310
}
303311

304-
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
312+
private void recordRequestCountMetrics(Model model, Request request, String localNodeId) {
313+
Map<String, Object> requestCountAttributes = new HashMap<>();
314+
requestCountAttributes.putAll(modelAttributes(model));
315+
requestCountAttributes.putAll(routingAttributes(request, localNodeId));
316+
317+
inferenceStats.requestCount().incrementBy(1, requestCountAttributes);
318+
}
319+
320+
private void recordRequestDurationMetrics(
321+
Model model,
322+
InferenceTimer timer,
323+
Request request,
324+
String localNodeId,
325+
@Nullable Throwable t
326+
) {
305327
try {
306-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
328+
Map<String, Object> metricAttributes = new HashMap<>();
329+
metricAttributes.putAll(modelAttributes(model));
330+
metricAttributes.putAll(routingAttributes(request, localNodeId));
331+
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
332+
333+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
307334
} catch (Exception e) {
308335
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
309336
}
@@ -355,10 +382,14 @@ private class PublisherWithMetrics extends DelegatingProcessor<InferenceServiceR
355382

356383
private final InferenceTimer timer;
357384
private final Model model;
385+
private final Request request;
386+
private final String localNodeId;
358387

359-
private PublisherWithMetrics(InferenceTimer timer, Model model) {
388+
private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
360389
this.timer = timer;
361390
this.model = model;
391+
this.request = request;
392+
this.localNodeId = localNodeId;
362393
}
363394

364395
@Override
@@ -368,19 +399,19 @@ protected void next(InferenceServiceResults.Result item) {
368399

369400
@Override
370401
public void onError(Throwable throwable) {
371-
recordMetrics(model, timer, throwable);
402+
recordRequestDurationMetrics(model, timer, request, localNodeId, throwable);
372403
super.onError(throwable);
373404
}
374405

375406
@Override
376407
protected void onCancel() {
377-
recordMetrics(model, timer, null);
408+
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
378409
super.onCancel();
379410
}
380411

381412
@Override
382413
public void onComplete() {
383-
recordMetrics(model, timer, null);
414+
recordRequestDurationMetrics(model, timer, request, localNodeId, null);
384415
super.onComplete();
385416
}
386417
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import org.elasticsearch.telemetry.metric.LongCounter;
1515
import org.elasticsearch.telemetry.metric.LongHistogram;
1616
import org.elasticsearch.telemetry.metric.MeterRegistry;
17+
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
1718

1819
import java.util.Map;
1920
import java.util.Objects;
2021
import java.util.stream.Collectors;
2122
import java.util.stream.Stream;
2223

2324
import static java.util.Map.entry;
24-
import static java.util.stream.Stream.concat;
2525

2626
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
2727

@@ -45,49 +45,43 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
4545
);
4646
}
4747

48-
public static Map<String, Object> modelAttributes(Model model) {
49-
return toMap(modelAttributeEntries(model));
48+
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
49+
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
5050
}
5151

52-
private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
52+
public static Map<String, Object> modelAttributes(Model model) {
5353
var stream = Stream.<Map.Entry<String, Object>>builder()
5454
.add(entry("service", model.getConfigurations().getService()))
5555
.add(entry("task_type", model.getTaskType().toString()));
5656
if (model.getServiceSettings().modelId() != null) {
5757
stream.add(entry("model_id", model.getServiceSettings().modelId()));
5858
}
59-
return stream.build();
59+
return toMap(stream.build());
6060
}
6161

62-
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
63-
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
62+
public static Map<String, Object> routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) {
63+
return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest);
6464
}
6565

66-
public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
67-
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
68-
}
69-
70-
public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
66+
public static Map<String, Object> modelAttributes(UnparsedModel model) {
7167
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
7268
.add(entry("service", model.service()))
7369
.add(entry("task_type", model.taskType().toString()))
7470
.build();
7571

76-
return toMap(concat(unknownModelAttributes, errorAttributes(t)));
72+
return toMap(unknownModelAttributes);
7773
}
7874

7975
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
80-
return toMap(errorAttributes(t));
81-
}
82-
83-
private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
84-
return switch (t) {
85-
case null -> Stream.of(entry("status_code", 200));
76+
var stream = switch (t) {
77+
case null -> Stream.<Map.Entry<String, Object>>of(entry("status_code", 200));
8678
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
8779
.add(entry("status_code", ese.status().getStatus()))
8880
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
8981
.build();
90-
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
82+
default -> Stream.<Map.Entry<String, Object>>of(entry("error.type", t.getClass().getSimpleName()));
9183
};
84+
85+
return toMap(stream);
9286
}
9387
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
6262
protected static final String serviceId = "serviceId";
6363
protected final TaskType taskType;
6464
protected static final String inferenceId = "inferenceEntityId";
65+
protected static final String localNodeId = "local-node-id";
6566
protected InferenceServiceRegistry serviceRegistry;
6667
protected InferenceStats inferenceStats;
6768
protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
@@ -100,6 +101,7 @@ public void setUp() throws Exception {
100101
);
101102

102103
mockValidLicenseState();
104+
mockNodeClient();
103105
}
104106

105107
protected abstract BaseTransportInferenceAction<Request> createAction(
@@ -135,6 +137,8 @@ public void testMetricsAfterModelRegistryError() {
135137
assertThat(attributes.get("model_id"), nullValue());
136138
assertThat(attributes.get("status_code"), nullValue());
137139
assertThat(attributes.get("error.type"), is(expectedError));
140+
assertThat(attributes.get("rerouted"), nullValue());
141+
assertThat(attributes.get("node_id"), nullValue());
138142
}));
139143
}
140144

@@ -176,6 +180,8 @@ public void testMetricsAfterMissingService() {
176180
assertThat(attributes.get("model_id"), nullValue());
177181
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
178182
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
183+
assertThat(attributes.get("rerouted"), nullValue());
184+
assertThat(attributes.get("node_id"), nullValue());
179185
}));
180186
}
181187

@@ -216,6 +222,8 @@ public void testMetricsAfterUnknownTaskType() {
216222
assertThat(attributes.get("model_id"), nullValue());
217223
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
218224
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
225+
assertThat(attributes.get("rerouted"), nullValue());
226+
assertThat(attributes.get("node_id"), nullValue());
219227
}));
220228
}
221229

@@ -232,6 +240,8 @@ public void testMetricsAfterInferError() {
232240
assertThat(attributes.get("model_id"), nullValue());
233241
assertThat(attributes.get("status_code"), nullValue());
234242
assertThat(attributes.get("error.type"), is(expectedError));
243+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
244+
assertThat(attributes.get("node_id"), is(localNodeId));
235245
}));
236246
}
237247

@@ -254,6 +264,8 @@ public void testMetricsAfterStreamUnsupported() {
254264
assertThat(attributes.get("model_id"), nullValue());
255265
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
256266
assertThat(attributes.get("error.type"), is(expectedError));
267+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
268+
assertThat(attributes.get("node_id"), is(localNodeId));
257269
}));
258270
}
259271

@@ -269,6 +281,8 @@ public void testMetricsAfterInferSuccess() {
269281
assertThat(attributes.get("model_id"), nullValue());
270282
assertThat(attributes.get("status_code"), is(200));
271283
assertThat(attributes.get("error.type"), nullValue());
284+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
285+
assertThat(attributes.get("node_id"), is(localNodeId));
272286
}));
273287
}
274288

@@ -280,6 +294,8 @@ public void testMetricsAfterStreamInferSuccess() {
280294
assertThat(attributes.get("model_id"), nullValue());
281295
assertThat(attributes.get("status_code"), is(200));
282296
assertThat(attributes.get("error.type"), nullValue());
297+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
298+
assertThat(attributes.get("node_id"), is(localNodeId));
283299
}));
284300
}
285301

@@ -296,6 +312,8 @@ public void testMetricsAfterStreamInferFailure() {
296312
assertThat(attributes.get("model_id"), nullValue());
297313
assertThat(attributes.get("status_code"), nullValue());
298314
assertThat(attributes.get("error.type"), is(expectedError));
315+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
316+
assertThat(attributes.get("node_id"), is(localNodeId));
299317
}));
300318
}
301319

@@ -329,6 +347,8 @@ public void onComplete() {
329347
assertThat(attributes.get("model_id"), nullValue());
330348
assertThat(attributes.get("status_code"), is(200));
331349
assertThat(attributes.get("error.type"), nullValue());
350+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
351+
assertThat(attributes.get("node_id"), is(localNodeId));
332352
}));
333353
}
334354

@@ -404,4 +424,8 @@ protected void mockModelAndServiceRegistry(InferenceService service) {
404424
protected void mockValidLicenseState() {
405425
when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true);
406426
}
427+
428+
private void mockNodeClient(){
429+
when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
430+
}
407431
}

0 commit comments

Comments
 (0)