From 9c91f4c418367338aca9ace112a1c7863d9cce2d Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 24 Apr 2025 12:26:35 -0400 Subject: [PATCH 1/2] [ML] Directly call Inference API from Proxy In order to propagate response headers back from the proxied actions, we are directly calling the Transport actions via the NodeClient. --- .../action/BaseTransportInferenceAction.java | 2 ++ .../action/TransportInferenceActionProxy.java | 11 ++++++----- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 34d2ef0843cfc..4019e0e2cf62e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -102,6 +102,8 @@ public BaseTransportInferenceAction( NodeClient nodeClient, ThreadPool threadPool ) { + // TransportInferenceActionProxy depends on this action passing EsExecutors.DIRECT_EXECUTOR_SERVICE to preserve the headers. + // If we change the ExecutorService, change the listener in TransportInferenceActionProxy to ContextPreservingActionListener. 
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.licenseState = licenseState; this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index 88a9927db2d9f..0d84f2870c227 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -29,9 +29,6 @@ import java.io.IOException; -import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; -import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; - public class TransportInferenceActionProxy extends HandledTransportAction { private final ModelRegistry modelRegistry; private final Client client; @@ -103,7 +100,9 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ); } - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, UnifiedCompletionAction.INSTANCE, unifiedRequest, unifiedErrorFormatListener); + // TransportUnifiedCompletionInferenceAction currently runs on this thread with this thread context. If this changes, + // change this listener to a ContextPreservingActionListener to preserve the response headers. 
+ client.execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); } catch (Exception e) { unifiedErrorFormatListener.onFailure(e); } @@ -122,6 +121,8 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); } - executeAsyncWithOrigin(client, INFERENCE_ORIGIN, InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); + // TransportInferenceAction currently runs on this thread with this thread context. If this changes, + // change this listener to a ContextPreservingActionListener to preserve the response headers. + client.execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); } } From 69471b10303a755cb9e9fe56806725027f830113 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 24 Apr 2025 14:48:43 -0400 Subject: [PATCH 2/2] We have to stash the security headers --- .../action/BaseTransportInferenceAction.java | 2 -- .../action/TransportInferenceActionProxy.java | 28 +++++++++++++++---- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java index 4019e0e2cf62e..34d2ef0843cfc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -102,8 +102,6 @@ public BaseTransportInferenceAction( NodeClient nodeClient, ThreadPool threadPool ) { - // TransportInferenceActionProxy depends on this action passing EsExecutors.DIRECT_EXECUTOR_SERVICE to preserve the headers. 
- // If we change the ExecutorService, change the listener in TransportInferenceActionProxy to ContextPreservingActionListener. super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); this.licenseState = licenseState; this.modelRegistry = modelRegistry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java index 0d84f2870c227..b18fa25b3bd69 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java @@ -9,10 +9,15 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.client.internal.Client; import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; @@ -29,6 +34,8 @@ import java.io.IOException; +import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN; + public class TransportInferenceActionProxy extends HandledTransportAction { private final ModelRegistry modelRegistry; private final Client client; @@ -100,9 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request, ); } - // 
TransportUnifiedCompletionInferenceAction currently runs on this thread with this thread context. If this changes, - // change this listener to a ContextPreservingActionListener to preserve the response headers. - client.execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); + execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener); } catch (Exception e) { unifiedErrorFormatListener.onFailure(e); } @@ -121,8 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming()); } - // TransportInferenceAction currently runs on this thread with this thread context. If this changes, - // change this listener to a ContextPreservingActionListener to preserve the response headers. - client.execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); + execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener); + } + + private <Request extends ActionRequest, Response extends ActionResponse> void execute( + ActionType<Response> action, + Request request, + ActionListener<Response> listener + ) { + var threadContext = client.threadPool().getThreadContext(); + // stash the context so we clear the user's security headers, then restore and copy the response headers + var supplier = threadContext.newRestorableContext(true); + try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) { + client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener)); + } } }