From 996d51538b027e7ca507caa1ad0ca6af5e814275 Mon Sep 17 00:00:00 2001
From: Pat Whelan
Date: Fri, 25 Apr 2025 09:48:15 -0400
Subject: [PATCH] [ML] Directly call Inference API from Proxy (#127342)

In order to propagate response headers back from the proxied actions, we
are directly calling the Transport actions via the NodeClient.

The unified completion request is still sent with
unifiedErrorFormatListener, the same listener that the removed
executeAsyncWithOrigin call used, so only the way the action is invoked
changes.
---
 .../action/TransportInferenceActionProxy.java | 23 +++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java
index 88a9927db2d9f..b18fa25b3bd69 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java
@@ -9,10 +9,15 @@
 
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequest;
+import org.elasticsearch.action.ActionResponse;
+import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.support.ActionFilters;
+import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
 import org.elasticsearch.client.internal.Client;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
+import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.inference.UnparsedModel;
@@ -30,7 +35,6 @@
 import java.io.IOException;
 
 import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
-import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
 
 public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
     private final ModelRegistry modelRegistry;
@@ -103,7 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
                 );
             }
 
-            executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
+            execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener);
         } catch (Exception e) {
             unifiedErrorFormatListener.onFailure(e);
         }
@@ -122,6 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac
             inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
         }
 
-        executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
+        execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
+    }
+
+    private <Request extends ActionRequest, Response extends ActionResponse> void execute(
+        ActionType<Response> action,
+        Request request,
+        ActionListener<Response> listener
+    ) {
+        var threadContext = client.threadPool().getThreadContext();
+        // stash the context so we clear the user's security headers, then restore and copy the response headers
+        var supplier = threadContext.newRestorableContext(true);
+        try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) {
+            client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
+        }
     }
 }