Skip to content

Commit 69471b1

Browse files
committed
We have to stash the security headers
1 parent 9c91f4c commit 69471b1

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,6 @@ public BaseTransportInferenceAction(
102102
NodeClient nodeClient,
103103
ThreadPool threadPool
104104
) {
105-
// TransportInferenceActionProxy depends on this action passing EsExecutors.DIRECT_EXECUTOR_SERVICE to preserve the headers.
106-
// If we change the ExecutorService, change the listener in TransportInferenceActionProxy to ContextPreservingActionListener.
107105
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
108106
this.licenseState = licenseState;
109107
this.modelRegistry = modelRegistry;

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionProxy.java

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,15 @@
99

1010
import org.elasticsearch.ElasticsearchStatusException;
1111
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRequest;
13+
import org.elasticsearch.action.ActionResponse;
14+
import org.elasticsearch.action.ActionType;
1215
import org.elasticsearch.action.support.ActionFilters;
16+
import org.elasticsearch.action.support.ContextPreservingActionListener;
1317
import org.elasticsearch.action.support.HandledTransportAction;
1418
import org.elasticsearch.client.internal.Client;
1519
import org.elasticsearch.common.util.concurrent.EsExecutors;
20+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1621
import org.elasticsearch.common.xcontent.XContentHelper;
1722
import org.elasticsearch.inference.TaskType;
1823
import org.elasticsearch.inference.UnparsedModel;
@@ -29,6 +34,8 @@
2934

3035
import java.io.IOException;
3136

37+
import static org.elasticsearch.xpack.core.ClientHelper.INFERENCE_ORIGIN;
38+
3239
public class TransportInferenceActionProxy extends HandledTransportAction<InferenceActionProxy.Request, InferenceAction.Response> {
3340
private final ModelRegistry modelRegistry;
3441
private final Client client;
@@ -100,9 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
100107
);
101108
}
102109

103-
// TransportUnifiedCompletionInferenceAction currently runs on this thread with this thread context. If this changes,
104-
// change this listener to a ContextPreservingActionListener to preserve the response headers.
105-
client.execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
110+
execute(UnifiedCompletionAction.INSTANCE, unifiedRequest, listener);
106111
} catch (Exception e) {
107112
unifiedErrorFormatListener.onFailure(e);
108113
}
@@ -121,8 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac
121126
inferenceActionRequestBuilder.setInferenceTimeout(request.getTimeout()).setStream(request.isStreaming());
122127
}
123128

124-
// TransportInferenceAction currently runs on this thread with this thread context. If this changes,
125-
// change this listener to a ContextPreservingActionListener to preserve the response headers.
126-
client.execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
129+
execute(InferenceAction.INSTANCE, inferenceActionRequestBuilder.build(), listener);
130+
}
131+
132+
private <Request extends ActionRequest, Response extends ActionResponse> void execute(
133+
ActionType<Response> action,
134+
Request request,
135+
ActionListener<Response> listener
136+
) {
137+
var threadContext = client.threadPool().getThreadContext();
138+
// stash the context so we clear the user's security headers, then restore and copy the response headers
139+
var supplier = threadContext.newRestorableContext(true);
140+
try (ThreadContext.StoredContext ignore = threadContext.stashWithOrigin(INFERENCE_ORIGIN)) {
141+
client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
142+
}
127143
}
128144
}

0 commit comments

Comments
 (0)