Skip to content

Commit 528d676

Browse files
committed
Record inference API re-routing attributes as part of request metrics
1 parent e706193 commit 528d676

File tree

5 files changed

+124
-40
lines changed

5 files changed

+124
-40
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 31 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,8 @@
4848
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;
4949

5050
import java.io.IOException;
51+
import java.util.HashMap;
52+
import java.util.Map;
5153
import java.util.Random;
5254
import java.util.concurrent.Executor;
5355
import java.util.concurrent.Flow;
@@ -59,6 +61,7 @@
5961
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
6062
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
6163
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
64+
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes;
6265

6366
/**
6467
* Base class for transport actions that handle inference requests.
@@ -145,7 +148,8 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
145148
}
146149

147150
var service = serviceRegistry.getService(serviceName).get();
148-
var routingDecision = determineRouting(serviceName, request, unparsedModel);
151+
var localNodeId = nodeClient.getLocalNodeId();
152+
var routingDecision = determineRouting(serviceName, request, unparsedModel, localNodeId);
149153

150154
if (routingDecision.currentNodeShouldHandleRequest()) {
151155
var model = service.parsePersistedConfigWithSecrets(
@@ -154,7 +158,7 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
154158
unparsedModel.settings(),
155159
unparsedModel.secrets()
156160
);
157-
inferOnServiceWithMetrics(model, request, service, timer, listener);
161+
inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);
158162
} else {
159163
// Reroute request
160164
request.setHasBeenRerouted(true);
@@ -188,7 +192,7 @@ private void validateRequest(Request request, UnparsedModel unparsedModel) {
188192
);
189193
}
190194

191-
private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
195+
private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel, String localNodeId) {
192196
var modelTaskType = unparsedModel.taskType();
193197

194198
// Rerouting not supported or request was already rerouted
@@ -212,7 +216,6 @@ private NodeRoutingDecision determineRouting(String serviceName, Request request
212216
}
213217

214218
var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));
215-
String localNodeId = nodeClient.getLocalNodeId();
216219

217220
// The drawn node is the current node
218221
if (nodeToHandleRequest.getId().equals(localNodeId)) {
@@ -260,7 +263,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException {
260263

261264
private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
262265
try {
263-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
266+
Map<String, Object> metricAttributes = new HashMap<>();
267+
metricAttributes.putAll(modelAttributes(model));
268+
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
269+
270+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
264271
} catch (Exception e) {
265272
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
266273
}
@@ -271,6 +278,7 @@ private void inferOnServiceWithMetrics(
271278
Request request,
272279
InferenceService service,
273280
InferenceTimer timer,
281+
String localNodeId,
274282
ActionListener<InferenceAction.Response> listener
275283
) {
276284
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
@@ -279,18 +287,18 @@ private void inferOnServiceWithMetrics(
279287
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
280288
inferenceResults.publisher().subscribe(taskProcessor);
281289

282-
var instrumentedStream = new PublisherWithMetrics(timer, model);
290+
var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
283291
taskProcessor.subscribe(instrumentedStream);
284292

285293
var streamErrorHandler = streamErrorHandler(instrumentedStream);
286294

287295
listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
288296
} else {
289-
recordMetrics(model, timer, null);
297+
recordMetrics(model, timer, request, localNodeId, null);
290298
listener.onResponse(new InferenceAction.Response(inferenceResults));
291299
}
292300
}, e -> {
293-
recordMetrics(model, timer, e);
301+
recordMetrics(model, timer, request, localNodeId, e);
294302
listener.onFailure(e);
295303
}));
296304
}
@@ -299,9 +307,14 @@ protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<Ch
299307
return upstream;
300308
}
301309

302-
private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
310+
private void recordMetrics(Model model, InferenceTimer timer, Request request, String localNodeId, @Nullable Throwable t) {
303311
try {
304-
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
312+
Map<String, Object> metricAttributes = new HashMap<>();
313+
metricAttributes.putAll(modelAttributes(model));
314+
metricAttributes.putAll(routingAttributes(request, localNodeId));
315+
metricAttributes.putAll(responseAttributes(unwrapCause(t)));
316+
317+
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
305318
} catch (Exception e) {
306319
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
307320
}
@@ -353,10 +366,14 @@ private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent
353366

354367
private final InferenceTimer timer;
355368
private final Model model;
369+
private final Request request;
370+
private final String localNodeId;
356371

357-
private PublisherWithMetrics(InferenceTimer timer, Model model) {
372+
private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
358373
this.timer = timer;
359374
this.model = model;
375+
this.request = request;
376+
this.localNodeId = localNodeId;
360377
}
361378

362379
@Override
@@ -366,19 +383,19 @@ protected void next(ChunkedToXContent item) {
366383

367384
@Override
368385
public void onError(Throwable throwable) {
369-
recordMetrics(model, timer, throwable);
386+
recordMetrics(model, timer, request, localNodeId, throwable);
370387
super.onError(throwable);
371388
}
372389

373390
@Override
374391
protected void onCancel() {
375-
recordMetrics(model, timer, null);
392+
recordMetrics(model, timer, request, localNodeId, null);
376393
super.onCancel();
377394
}
378395

379396
@Override
380397
public void onComplete() {
381-
recordMetrics(model, timer, null);
398+
recordMetrics(model, timer, request, localNodeId, null);
382399
super.onComplete();
383400
}
384401
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
import org.elasticsearch.telemetry.metric.LongCounter;
1515
import org.elasticsearch.telemetry.metric.LongHistogram;
1616
import org.elasticsearch.telemetry.metric.MeterRegistry;
17+
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
1718

1819
import java.util.Map;
1920
import java.util.Objects;
2021
import java.util.stream.Collectors;
2122
import java.util.stream.Stream;
2223

2324
import static java.util.Map.entry;
24-
import static java.util.stream.Stream.concat;
2525

2626
public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {
2727

@@ -45,49 +45,43 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
4545
);
4646
}
4747

48-
public static Map<String, Object> modelAttributes(Model model) {
49-
return toMap(modelAttributeEntries(model));
48+
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
49+
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
5050
}
5151

52-
private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
52+
public static Map<String, Object> modelAttributes(Model model) {
5353
var stream = Stream.<Map.Entry<String, Object>>builder()
5454
.add(entry("service", model.getConfigurations().getService()))
5555
.add(entry("task_type", model.getTaskType().toString()));
5656
if (model.getServiceSettings().modelId() != null) {
5757
stream.add(entry("model_id", model.getServiceSettings().modelId()));
5858
}
59-
return stream.build();
59+
return toMap(stream.build());
6060
}
6161

62-
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
63-
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
62+
public static Map<String, Object> routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) {
63+
return Map.of("rerouted", request.hasBeenRerouted(), "node_id", nodeIdHandlingRequest);
6464
}
6565

66-
public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
67-
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
68-
}
69-
70-
public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
66+
public static Map<String, Object> modelAttributes(UnparsedModel model) {
7167
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
7268
.add(entry("service", model.service()))
7369
.add(entry("task_type", model.taskType().toString()))
7470
.build();
7571

76-
return toMap(concat(unknownModelAttributes, errorAttributes(t)));
72+
return toMap(unknownModelAttributes);
7773
}
7874

7975
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
80-
return toMap(errorAttributes(t));
81-
}
82-
83-
private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
84-
return switch (t) {
85-
case null -> Stream.of(entry("status_code", 200));
76+
var stream = switch (t) {
77+
case null -> Stream.<Map.Entry<String, Object>>of(entry("status_code", 200));
8678
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
8779
.add(entry("status_code", ese.status().getStatus()))
8880
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
8981
.build();
90-
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
82+
default -> Stream.<Map.Entry<String, Object>>of(entry("error.type", t.getClass().getSimpleName()));
9183
};
84+
85+
return toMap(stream);
9286
}
9387
}

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
6262
protected static final String serviceId = "serviceId";
6363
protected final TaskType taskType;
6464
protected static final String inferenceId = "inferenceEntityId";
65+
protected static final String localNodeId = "local-node-id";
6566
protected InferenceServiceRegistry serviceRegistry;
6667
protected InferenceStats inferenceStats;
6768
protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
@@ -100,6 +101,7 @@ public void setUp() throws Exception {
100101
);
101102

102103
mockValidLicenseState();
104+
mockNodeClient();
103105
}
104106

105107
protected abstract BaseTransportInferenceAction<Request> createAction(
@@ -135,6 +137,8 @@ public void testMetricsAfterModelRegistryError() {
135137
assertThat(attributes.get("model_id"), nullValue());
136138
assertThat(attributes.get("status_code"), nullValue());
137139
assertThat(attributes.get("error.type"), is(expectedError));
140+
assertThat(attributes.get("rerouted"), nullValue());
141+
assertThat(attributes.get("node_id"), nullValue());
138142
}));
139143
}
140144

@@ -176,6 +180,8 @@ public void testMetricsAfterMissingService() {
176180
assertThat(attributes.get("model_id"), nullValue());
177181
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
178182
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
183+
assertThat(attributes.get("rerouted"), nullValue());
184+
assertThat(attributes.get("node_id"), nullValue());
179185
}));
180186
}
181187

@@ -216,6 +222,8 @@ public void testMetricsAfterUnknownTaskType() {
216222
assertThat(attributes.get("model_id"), nullValue());
217223
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
218224
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
225+
assertThat(attributes.get("rerouted"), nullValue());
226+
assertThat(attributes.get("node_id"), nullValue());
219227
}));
220228
}
221229

@@ -232,6 +240,8 @@ public void testMetricsAfterInferError() {
232240
assertThat(attributes.get("model_id"), nullValue());
233241
assertThat(attributes.get("status_code"), nullValue());
234242
assertThat(attributes.get("error.type"), is(expectedError));
243+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
244+
assertThat(attributes.get("node_id"), is(localNodeId));
235245
}));
236246
}
237247

@@ -254,6 +264,8 @@ public void testMetricsAfterStreamUnsupported() {
254264
assertThat(attributes.get("model_id"), nullValue());
255265
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
256266
assertThat(attributes.get("error.type"), is(expectedError));
267+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
268+
assertThat(attributes.get("node_id"), is(localNodeId));
257269
}));
258270
}
259271

@@ -269,6 +281,8 @@ public void testMetricsAfterInferSuccess() {
269281
assertThat(attributes.get("model_id"), nullValue());
270282
assertThat(attributes.get("status_code"), is(200));
271283
assertThat(attributes.get("error.type"), nullValue());
284+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
285+
assertThat(attributes.get("node_id"), is(localNodeId));
272286
}));
273287
}
274288

@@ -280,6 +294,8 @@ public void testMetricsAfterStreamInferSuccess() {
280294
assertThat(attributes.get("model_id"), nullValue());
281295
assertThat(attributes.get("status_code"), is(200));
282296
assertThat(attributes.get("error.type"), nullValue());
297+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
298+
assertThat(attributes.get("node_id"), is(localNodeId));
283299
}));
284300
}
285301

@@ -296,6 +312,8 @@ public void testMetricsAfterStreamInferFailure() {
296312
assertThat(attributes.get("model_id"), nullValue());
297313
assertThat(attributes.get("status_code"), nullValue());
298314
assertThat(attributes.get("error.type"), is(expectedError));
315+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
316+
assertThat(attributes.get("node_id"), is(localNodeId));
299317
}));
300318
}
301319

@@ -329,6 +347,8 @@ public void onComplete() {
329347
assertThat(attributes.get("model_id"), nullValue());
330348
assertThat(attributes.get("status_code"), is(200));
331349
assertThat(attributes.get("error.type"), nullValue());
350+
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
351+
assertThat(attributes.get("node_id"), is(localNodeId));
332352
}));
333353
}
334354

@@ -404,4 +424,8 @@ protected void mockModelAndServiceRegistry(InferenceService service) {
404424
protected void mockValidLicenseState() {
405425
when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true);
406426
}
427+
428+
private void mockNodeClient(){
429+
when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
430+
}
407431
}

0 commit comments

Comments
 (0)