99
1010import org .elasticsearch .ElasticsearchStatusException ;
1111import org .elasticsearch .action .ActionListener ;
12+ import org .elasticsearch .action .ActionRequest ;
13+ import org .elasticsearch .action .ActionResponse ;
14+ import org .elasticsearch .action .ActionType ;
1215import org .elasticsearch .action .support .ActionFilters ;
16+ import org .elasticsearch .action .support .ContextPreservingActionListener ;
1317import org .elasticsearch .action .support .HandledTransportAction ;
1418import org .elasticsearch .client .internal .Client ;
1519import org .elasticsearch .common .util .concurrent .EsExecutors ;
20+ import org .elasticsearch .common .util .concurrent .ThreadContext ;
1621import org .elasticsearch .common .xcontent .XContentHelper ;
1722import org .elasticsearch .inference .TaskType ;
1823import org .elasticsearch .inference .UnparsedModel ;
3035import java .io .IOException ;
3136
3237import static org .elasticsearch .xpack .core .ClientHelper .INFERENCE_ORIGIN ;
33- import static org .elasticsearch .xpack .core .ClientHelper .executeAsyncWithOrigin ;
3438
3539public class TransportInferenceActionProxy extends HandledTransportAction <InferenceActionProxy .Request , InferenceAction .Response > {
3640 private final ModelRegistry modelRegistry ;
@@ -103,7 +107,7 @@ private void sendUnifiedCompletionRequest(InferenceActionProxy.Request request,
103107 );
104108 }
105109
106- executeAsyncWithOrigin ( client , INFERENCE_ORIGIN , UnifiedCompletionAction .INSTANCE , unifiedRequest , unifiedErrorFormatListener );
110+ execute ( UnifiedCompletionAction .INSTANCE , unifiedRequest , listener );
107111 } catch (Exception e ) {
108112 unifiedErrorFormatListener .onFailure (e );
109113 }
@@ -122,6 +126,19 @@ private void sendInferenceActionRequest(InferenceActionProxy.Request request, Ac
122126 inferenceActionRequestBuilder .setInferenceTimeout (request .getTimeout ()).setStream (request .isStreaming ());
123127 }
124128
125- executeAsyncWithOrigin (client , INFERENCE_ORIGIN , InferenceAction .INSTANCE , inferenceActionRequestBuilder .build (), listener );
129+ execute (InferenceAction .INSTANCE , inferenceActionRequestBuilder .build (), listener );
130+ }
131+
132+ private <Request extends ActionRequest , Response extends ActionResponse > void execute (
133+ ActionType <Response > action ,
134+ Request request ,
135+ ActionListener <Response > listener
136+ ) {
137+ var threadContext = client .threadPool ().getThreadContext ();
138+ // stash the context so we clear the user's security headers, then restore and copy the response headers
139+ var supplier = threadContext .newRestorableContext (true );
140+ try (ThreadContext .StoredContext ignore = threadContext .stashWithOrigin (INFERENCE_ORIGIN )) {
141+ client .execute (action , request , new ContextPreservingActionListener <>(supplier , listener ));
142+ }
126143 }
127144}
0 commit comments