Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import org.elasticsearch.xpack.inference.telemetry.InferenceTimer;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
Expand All @@ -59,6 +61,7 @@
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.routingAttributes;

/**
* Base class for transport actions that handle inference requests.
Expand Down Expand Up @@ -145,7 +148,8 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
}

var service = serviceRegistry.getService(serviceName).get();
var routingDecision = determineRouting(serviceName, request, unparsedModel);
var localNodeId = nodeClient.getLocalNodeId();
var routingDecision = determineRouting(serviceName, request, unparsedModel, localNodeId);

if (routingDecision.currentNodeShouldHandleRequest()) {
var model = service.parsePersistedConfigWithSecrets(
Expand All @@ -154,7 +158,7 @@ protected void doExecute(Task task, Request request, ActionListener<InferenceAct
unparsedModel.settings(),
unparsedModel.secrets()
);
inferOnServiceWithMetrics(model, request, service, timer, listener);
inferOnServiceWithMetrics(model, request, service, timer, localNodeId, listener);
} else {
// Reroute request
request.setHasBeenRerouted(true);
Expand Down Expand Up @@ -188,7 +192,7 @@ private void validateRequest(Request request, UnparsedModel unparsedModel) {
);
}

private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel) {
private NodeRoutingDecision determineRouting(String serviceName, Request request, UnparsedModel unparsedModel, String localNodeId) {
var modelTaskType = unparsedModel.taskType();

// Rerouting not supported or request was already rerouted
Expand All @@ -212,7 +216,6 @@ private NodeRoutingDecision determineRouting(String serviceName, Request request
}

var nodeToHandleRequest = responsibleNodes.get(random.nextInt(responsibleNodes.size()));
String localNodeId = nodeClient.getLocalNodeId();

// The drawn node is the current node
if (nodeToHandleRequest.getId().equals(localNodeId)) {
Expand Down Expand Up @@ -260,7 +263,11 @@ public InferenceAction.Response read(StreamInput in) throws IOException {

private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t));
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(responseAttributes(unwrapCause(t)));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics");
}
Expand All @@ -271,6 +278,7 @@ private void inferOnServiceWithMetrics(
Request request,
InferenceService service,
InferenceTimer timer,
String localNodeId,
ActionListener<InferenceAction.Response> listener
) {
inferenceStats.requestCount().incrementBy(1, modelAttributes(model));
Expand All @@ -279,18 +287,18 @@ private void inferOnServiceWithMetrics(
var taskProcessor = streamingTaskManager.<ChunkedToXContent>create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION);
inferenceResults.publisher().subscribe(taskProcessor);

var instrumentedStream = new PublisherWithMetrics(timer, model);
var instrumentedStream = new PublisherWithMetrics(timer, model, request, localNodeId);
taskProcessor.subscribe(instrumentedStream);

var streamErrorHandler = streamErrorHandler(instrumentedStream);

listener.onResponse(new InferenceAction.Response(inferenceResults, streamErrorHandler));
} else {
recordMetrics(model, timer, null);
recordMetrics(model, timer, request, localNodeId, null);
listener.onResponse(new InferenceAction.Response(inferenceResults));
}
}, e -> {
recordMetrics(model, timer, e);
recordMetrics(model, timer, request, localNodeId, e);
listener.onFailure(e);
}));
}
Expand All @@ -299,9 +307,14 @@ protected Flow.Publisher<ChunkedToXContent> streamErrorHandler(Flow.Processor<Ch
return upstream;
}

private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) {
private void recordMetrics(Model model, InferenceTimer timer, Request request, String localNodeId, @Nullable Throwable t) {
try {
inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, unwrapCause(t)));
Map<String, Object> metricAttributes = new HashMap<>();
metricAttributes.putAll(modelAttributes(model));
metricAttributes.putAll(routingAttributes(request, localNodeId));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m a bit concerned that the metric cardinality will grow extremely rapidly as you add the node_id. I'm not 100% sure if this is a problem for Elasticsearch (overview cluster).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the attribute will have a high cardinality and not the metric, right?

I don't think that should be inherently a problem vs for example Prometheus, which creates a timeseries automatically for each unique metric/attribute pair. We'll only do (manual) aggregations on node id over a limited time window + the number of unique node ids in a (serverless) cluster shouldn't be very high for a timeframe of let's say 10 minutes - 1 day. I've also checked other Elasticsearch metrics and a lot of them include the node id as an attribute, so I guess, if it's not a problem for them it shouldn't be one for us.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summarizing a slack conversation:
In addition to the cardinality risk during search, there also seems to be a risk of high cardinality when the metrics are pushed from the node: each attribute variation creates a new element that consumes capacity in an outbound queue. The queue has some maximum capacity and is flushed periodically.

We think (hope) the risk is relatively low, given that each node has a single id and routing should only involve a handful of distinct node ids.

metricAttributes.putAll(responseAttributes(unwrapCause(t)));

inferenceStats.inferenceDuration().record(timer.elapsedMillis(), metricAttributes);
} catch (Exception e) {
log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics");
}
Expand Down Expand Up @@ -353,10 +366,14 @@ private class PublisherWithMetrics extends DelegatingProcessor<ChunkedToXContent

private final InferenceTimer timer;
private final Model model;
private final Request request;
private final String localNodeId;

/**
 * Captures the state needed to record routing-aware duration metrics
 * once the streamed inference terminates (error, cancel, or complete).
 */
private PublisherWithMetrics(InferenceTimer timer, Model model, Request request, String localNodeId) {
    this.timer = timer;
    this.model = model;
    this.request = request;
    this.localNodeId = localNodeId;
}

@Override
Expand All @@ -366,19 +383,19 @@ protected void next(ChunkedToXContent item) {

@Override
public void onError(Throwable throwable) {
    // Record failure metrics for this stream exactly once, then propagate the error.
    recordMetrics(model, timer, request, localNodeId, throwable);
    super.onError(throwable);
}

@Override
protected void onCancel() {
    // A cancelled stream counts as a completed (non-error) request for metrics.
    recordMetrics(model, timer, request, localNodeId, null);
    super.onCancel();
}

@Override
public void onComplete() {
    // Null throwable signals success (recorded as status_code 200 downstream).
    recordMetrics(model, timer, request, localNodeId, null);
    super.onComplete();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
import org.elasticsearch.telemetry.metric.LongCounter;
import org.elasticsearch.telemetry.metric.LongHistogram;
import org.elasticsearch.telemetry.metric.MeterRegistry;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;

import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static java.util.Map.entry;
import static java.util.stream.Stream.concat;

public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) {

Expand All @@ -45,49 +45,43 @@ public static InferenceStats create(MeterRegistry meterRegistry) {
);
}

public static Map<String, Object> modelAttributes(Model model) {
return toMap(modelAttributeEntries(model));
// Materializes an attribute-entry stream into a mutable map; duplicate keys are
// not expected and would throw IllegalStateException (Collectors.toMap contract).
private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> entries) {
    return entries.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
}

private static Stream<Map.Entry<String, Object>> modelAttributeEntries(Model model) {
public static Map<String, Object> modelAttributes(Model model) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why not just create a new HashMap and conditionally put the entries in it?

I think the stream API is useful when we want to declaratively process a collection (filter, transformation, find one element etc), but in this case converting back and forth seems like an overhead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was originally used as a generator/builder to construct a map from multiple different objects in multiple different functions, but now that this change is moving away from that, we can probably just use HashMap for the conditionals and Map.of where there aren't. We can use Collections.unmodifiableMap around the HashMap if we want to be safe and/or don't trust APM

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to adapt it, I'll create a small follow-up PR tomorrow as the CI is green now and we want the metric attributes to appear pretty soon in EC Serverless and EC Hosted, so we can start on the integration tests.

var stream = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.getConfigurations().getService()))
.add(entry("task_type", model.getTaskType().toString()));
if (model.getServiceSettings().modelId() != null) {
stream.add(entry("model_id", model.getServiceSettings().modelId()));
}
return stream.build();
return toMap(stream.build());
}

private static Map<String, Object> toMap(Stream<Map.Entry<String, Object>> stream) {
return stream.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
/**
 * Attributes describing how the request was routed: whether it was rerouted
 * and which node ultimately handled it. Like {@code Map.of}, {@code Map.ofEntries}
 * rejects null values, so a null node id fails fast here.
 */
public static Map<String, Object> routingAttributes(BaseInferenceActionRequest request, String nodeIdHandlingRequest) {
    return Map.ofEntries(entry("rerouted", request.hasBeenRerouted()), entry("node_id", nodeIdHandlingRequest));
}

public static Map<String, Object> responseAttributes(Model model, @Nullable Throwable t) {
return toMap(concat(modelAttributeEntries(model), errorAttributes(t)));
}

public static Map<String, Object> responseAttributes(UnparsedModel model, @Nullable Throwable t) {
public static Map<String, Object> modelAttributes(UnparsedModel model) {
var unknownModelAttributes = Stream.<Map.Entry<String, Object>>builder()
.add(entry("service", model.service()))
.add(entry("task_type", model.taskType().toString()))
.build();

return toMap(concat(unknownModelAttributes, errorAttributes(t)));
return toMap(unknownModelAttributes);
}

/**
 * Builds response-outcome attributes from an optional failure: {@code status_code} 200
 * when {@code t} is null, otherwise status/error attributes derived from the throwable.
 */
public static Map<String, Object> responseAttributes(@Nullable Throwable t) {
    return toMap(errorAttributes(t));
}

private static Stream<Map.Entry<String, Object>> errorAttributes(@Nullable Throwable t) {
return switch (t) {
case null -> Stream.of(entry("status_code", 200));
var stream = switch (t) {
case null -> Stream.<Map.Entry<String, Object>>of(entry("status_code", 200));
case ElasticsearchStatusException ese -> Stream.<Map.Entry<String, Object>>builder()
.add(entry("status_code", ese.status().getStatus()))
.add(entry("error.type", String.valueOf(ese.status().getStatus())))
.build();
default -> Stream.of(entry("error.type", t.getClass().getSimpleName()));
default -> Stream.<Map.Entry<String, Object>>of(entry("error.type", t.getClass().getSimpleName()));
};

return toMap(stream);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ public abstract class BaseTransportInferenceActionTestCase<Request extends BaseI
protected static final String serviceId = "serviceId";
protected final TaskType taskType;
protected static final String inferenceId = "inferenceEntityId";
protected static final String localNodeId = "local-node-id";
protected InferenceServiceRegistry serviceRegistry;
protected InferenceStats inferenceStats;
protected InferenceServiceRateLimitCalculator inferenceServiceRateLimitCalculator;
Expand Down Expand Up @@ -100,6 +101,7 @@ public void setUp() throws Exception {
);

mockValidLicenseState();
mockNodeClient();
}

protected abstract BaseTransportInferenceAction<Request> createAction(
Expand Down Expand Up @@ -135,6 +137,8 @@ public void testMetricsAfterModelRegistryError() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
assertThat(attributes.get("rerouted"), nullValue());
assertThat(attributes.get("node_id"), nullValue());
}));
}

Expand Down Expand Up @@ -176,6 +180,8 @@ public void testMetricsAfterMissingService() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
assertThat(attributes.get("rerouted"), nullValue());
assertThat(attributes.get("node_id"), nullValue());
}));
}

Expand Down Expand Up @@ -216,6 +222,8 @@ public void testMetricsAfterUnknownTaskType() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus()));
assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus())));
assertThat(attributes.get("rerouted"), nullValue());
assertThat(attributes.get("node_id"), nullValue());
}));
}

Expand All @@ -232,6 +240,8 @@ public void testMetricsAfterInferError() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand All @@ -254,6 +264,8 @@ public void testMetricsAfterStreamUnsupported() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(expectedStatus.getStatus()));
assertThat(attributes.get("error.type"), is(expectedError));
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand All @@ -269,6 +281,8 @@ public void testMetricsAfterInferSuccess() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand All @@ -280,6 +294,8 @@ public void testMetricsAfterStreamInferSuccess() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand All @@ -296,6 +312,8 @@ public void testMetricsAfterStreamInferFailure() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), nullValue());
assertThat(attributes.get("error.type"), is(expectedError));
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand Down Expand Up @@ -329,6 +347,8 @@ public void onComplete() {
assertThat(attributes.get("model_id"), nullValue());
assertThat(attributes.get("status_code"), is(200));
assertThat(attributes.get("error.type"), nullValue());
assertThat(attributes.get("rerouted"), is(Boolean.FALSE));
assertThat(attributes.get("node_id"), is(localNodeId));
}));
}

Expand Down Expand Up @@ -404,4 +424,8 @@ protected void mockModelAndServiceRegistry(InferenceService service) {
/** Stubs the license state so the inference API feature is allowed in every test. */
protected void mockValidLicenseState() {
    when(licenseState.isAllowed(InferencePlugin.INFERENCE_API_FEATURE)).thenReturn(true);
}

/** Stubs the node client to report a fixed local node id, used by the routing-attribute assertions. */
private void mockNodeClient() {
    when(nodeClient.getLocalNodeId()).thenReturn(localNodeId);
}
}
Loading