77import static io .a2a .util .Assert .checkNotNullParam ;
88
99import java .util .List ;
10+ import java .util .Map ;
1011import java .util .function .Consumer ;
1112import java .util .stream .Collectors ;
1213
1314import io .a2a .client .transport .spi .interceptors .ClientCallContext ;
1415import io .a2a .client .transport .spi .ClientTransport ;
16+ import io .a2a .client .transport .spi .interceptors .ClientCallInterceptor ;
17+ import io .a2a .client .transport .spi .interceptors .PayloadAndHeaders ;
18+ import io .a2a .client .transport .spi .interceptors .auth .AuthInterceptor ;
1519import io .a2a .common .A2AHeaders ;
1620import io .a2a .grpc .A2AServiceGrpc ;
1721import io .a2a .grpc .CancelTaskRequest ;
3236import io .a2a .spec .GetTaskPushNotificationConfigParams ;
3337import io .a2a .spec .ListTaskPushNotificationConfigParams ;
3438import io .a2a .spec .MessageSendParams ;
39+ import io .a2a .spec .SendStreamingMessageRequest ;
40+ import io .a2a .spec .SetTaskPushNotificationConfigRequest ;
3541import io .a2a .spec .StreamingEventKind ;
3642import io .a2a .spec .Task ;
3743import io .a2a .spec .TaskIdParams ;
3844import io .a2a .spec .TaskPushNotificationConfig ;
3945import io .a2a .spec .TaskQueryParams ;
46+ import io .a2a .spec .TaskResubscriptionRequest ;
4047import io .grpc .Channel ;
4148import io .grpc .Metadata ;
4249import io .grpc .StatusRuntimeException ;
4552
4653public class GrpcTransport implements ClientTransport {
4754
55+ private static final Metadata .Key <String > AUTHORIZATION_METADATA_KEY = Metadata .Key .of (
56+ AuthInterceptor .AUTHORIZATION ,
57+ Metadata .ASCII_STRING_MARSHALLER );
58+ private static final Metadata .Key <String > EXTENSIONS_KEY = Metadata .Key .of (
59+ A2AHeaders .X_A2A_EXTENSIONS ,
60+ Metadata .ASCII_STRING_MARSHALLER );
4861 private final A2AServiceBlockingV2Stub blockingStub ;
4962 private final A2AServiceStub asyncStub ;
63+ private final List <ClientCallInterceptor > interceptors ;
5064 private AgentCard agentCard ;
5165
5266 public GrpcTransport (Channel channel , AgentCard agentCard ) {
67+ this (channel , agentCard , null );
68+ }
69+
70+ public GrpcTransport (Channel channel , AgentCard agentCard , List <ClientCallInterceptor > interceptors ) {
5371 checkNotNullParam ("channel" , channel );
5472 this .asyncStub = A2AServiceGrpc .newStub (channel );
5573 this .blockingStub = A2AServiceGrpc .newBlockingV2Stub (channel );
5674 this .agentCard = agentCard ;
75+ this .interceptors = interceptors ;
5776 }
5877
5978 @ Override
6079 public EventKind sendMessage (MessageSendParams request , ClientCallContext context ) throws A2AClientException {
6180 checkNotNullParam ("request" , request );
6281
6382 SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest (request , context );
83+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .SendMessageRequest .METHOD , sendMessageRequest ,
84+ agentCard , context );
6485
6586 try {
66- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
87+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
6788 SendMessageResponse response = stubWithMetadata .sendMessage (sendMessageRequest );
6889 if (response .hasMsg ()) {
6990 return FromProto .message (response .getMsg ());
@@ -83,10 +104,12 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
83104 checkNotNullParam ("request" , request );
84105 checkNotNullParam ("eventConsumer" , eventConsumer );
85106 SendMessageRequest grpcRequest = createGrpcSendMessageRequest (request , context );
107+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (SendStreamingMessageRequest .METHOD ,
108+ grpcRequest , agentCard , context );
86109 StreamObserver <StreamResponse > streamObserver = new EventStreamObserver (eventConsumer , errorConsumer );
87110
88111 try {
89- A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata (context );
112+ A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata (context , payloadAndHeaders );
90113 stubWithMetadata .sendStreamingMessage (grpcRequest , streamObserver );
91114 } catch (StatusRuntimeException e ) {
92115 throw GrpcErrorMapper .mapGrpcError (e , "Failed to send streaming message request: " );
@@ -103,9 +126,11 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
103126 requestBuilder .setHistoryLength (request .historyLength ());
104127 }
105128 GetTaskRequest getTaskRequest = requestBuilder .build ();
129+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .GetTaskRequest .METHOD , getTaskRequest ,
130+ agentCard , context );
106131
107132 try {
108- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
133+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
109134 return FromProto .task (stubWithMetadata .getTask (getTaskRequest ));
110135 } catch (StatusRuntimeException e ) {
111136 throw GrpcErrorMapper .mapGrpcError (e , "Failed to get task: " );
@@ -119,9 +144,11 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A
119144 CancelTaskRequest cancelTaskRequest = CancelTaskRequest .newBuilder ()
120145 .setName ("tasks/" + request .id ())
121146 .build ();
147+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .CancelTaskRequest .METHOD , cancelTaskRequest ,
148+ agentCard , context );
122149
123150 try {
124- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
151+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
125152 return FromProto .task (stubWithMetadata .cancelTask (cancelTaskRequest ));
126153 } catch (StatusRuntimeException e ) {
127154 throw GrpcErrorMapper .mapGrpcError (e , "Failed to cancel task: " );
@@ -139,9 +166,11 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
139166 .setConfig (ToProto .taskPushNotificationConfig (request ))
140167 .setConfigId (configId != null ? configId : request .taskId ())
141168 .build ();
169+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (SetTaskPushNotificationConfigRequest .METHOD ,
170+ grpcRequest , agentCard , context );
142171
143172 try {
144- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
173+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
145174 return FromProto .taskPushNotificationConfig (stubWithMetadata .createTaskPushNotificationConfig (grpcRequest ));
146175 } catch (StatusRuntimeException e ) {
147176 throw GrpcErrorMapper .mapGrpcError (e , "Failed to create task push notification config: " );
@@ -157,9 +186,11 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
157186 GetTaskPushNotificationConfigRequest grpcRequest = GetTaskPushNotificationConfigRequest .newBuilder ()
158187 .setName (getTaskPushNotificationConfigName (request ))
159188 .build ();
189+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .GetTaskPushNotificationConfigRequest .METHOD ,
190+ grpcRequest , agentCard , context );
160191
161192 try {
162- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
193+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
163194 return FromProto .taskPushNotificationConfig (stubWithMetadata .getTaskPushNotificationConfig (grpcRequest ));
164195 } catch (StatusRuntimeException e ) {
165196 throw GrpcErrorMapper .mapGrpcError (e , "Failed to get task push notification config: " );
@@ -175,9 +206,11 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
175206 ListTaskPushNotificationConfigRequest grpcRequest = ListTaskPushNotificationConfigRequest .newBuilder ()
176207 .setParent ("tasks/" + request .id ())
177208 .build ();
209+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .ListTaskPushNotificationConfigRequest .METHOD ,
210+ grpcRequest , agentCard , context );
178211
179212 try {
180- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
213+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
181214 return stubWithMetadata .listTaskPushNotificationConfig (grpcRequest ).getConfigsList ().stream ()
182215 .map (FromProto ::taskPushNotificationConfig )
183216 .collect (Collectors .toList ());
@@ -194,9 +227,11 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC
194227 DeleteTaskPushNotificationConfigRequest grpcRequest = DeleteTaskPushNotificationConfigRequest .newBuilder ()
195228 .setName (getTaskPushNotificationConfigName (request .id (), request .pushNotificationConfigId ()))
196229 .build ();
230+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (io .a2a .spec .DeleteTaskPushNotificationConfigRequest .METHOD ,
231+ grpcRequest , agentCard , context );
197232
198233 try {
199- A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context );
234+ A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata (context , payloadAndHeaders );
200235 stubWithMetadata .deleteTaskPushNotificationConfig (grpcRequest );
201236 } catch (StatusRuntimeException e ) {
202237 throw GrpcErrorMapper .mapGrpcError (e , "Failed to delete task push notification config: " );
@@ -212,11 +247,13 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
212247 TaskSubscriptionRequest grpcRequest = TaskSubscriptionRequest .newBuilder ()
213248 .setName ("tasks/" + request .id ())
214249 .build ();
250+ PayloadAndHeaders payloadAndHeaders = applyInterceptors (TaskResubscriptionRequest .METHOD ,
251+ grpcRequest , agentCard , context );
215252
216253 StreamObserver <StreamResponse > streamObserver = new EventStreamObserver (eventConsumer , errorConsumer );
217254
218255 try {
219- A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata (context );
256+ A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata (context , payloadAndHeaders );
220257 stubWithMetadata .taskSubscription (grpcRequest , streamObserver );
221258 } catch (StatusRuntimeException e ) {
222259 throw GrpcErrorMapper .mapGrpcError (e , "Failed to resubscribe task push notification config: " );
@@ -249,43 +286,64 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
249286 * Creates gRPC metadata from ClientCallContext headers.
250287 * Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
251288 */
252- private Metadata createGrpcMetadata (ClientCallContext context ) {
289+ private Metadata createGrpcMetadata (ClientCallContext context , PayloadAndHeaders payloadAndHeaders ) {
253290 Metadata metadata = new Metadata ();
254291
255292 if (context != null && context .getHeaders () != null ) {
256293 // Set X-A2A-Extensions header if present
257294 String extensionsHeader = context .getHeaders ().get (A2AHeaders .X_A2A_EXTENSIONS );
258295 if (extensionsHeader != null ) {
259- Metadata .Key <String > extensionsKey = Metadata .Key .of (A2AHeaders .X_A2A_EXTENSIONS , Metadata .ASCII_STRING_MARSHALLER );
260- metadata .put (extensionsKey , extensionsHeader );
296+ metadata .put (EXTENSIONS_KEY , extensionsHeader );
261297 }
262298
263299 // Add other headers as needed in the future
264300 // For now, we only handle X-A2A-Extensions
265301 }
302+ if (payloadAndHeaders != null && payloadAndHeaders .getHeaders () != null ) {
303+ // Handle all headers from interceptors (including auth headers)
304+ for (Map .Entry <String , String > headerEntry : payloadAndHeaders .getHeaders ().entrySet ()) {
305+ String headerName = headerEntry .getKey ();
306+ String headerValue = headerEntry .getValue ();
307+
308+ if (headerValue != null ) {
309+ // Use static key for common Authorization header, create dynamic keys for others
310+ if (AuthInterceptor .AUTHORIZATION .equals (headerName )) {
311+ metadata .put (AUTHORIZATION_METADATA_KEY , headerValue );
312+ } else {
313+ // Create a metadata key dynamically for API keys and other custom headers
314+ Metadata .Key <String > metadataKey = Metadata .Key .of (headerName , Metadata .ASCII_STRING_MARSHALLER );
315+ metadata .put (metadataKey , headerValue );
316+ }
317+ }
318+ }
319+ }
266320
267321 return metadata ;
268322 }
269323
270324 /**
271325 * Creates a blocking stub with metadata attached from the ClientCallContext.
272- *
273- * @param context the client call context
326+ *
327+ * @param context the client call context
328+ * @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
274329 * @return blocking stub with metadata interceptor
275330 */
276- private A2AServiceBlockingV2Stub createBlockingStubWithMetadata (ClientCallContext context ) {
277- Metadata metadata = createGrpcMetadata (context );
331+ private A2AServiceBlockingV2Stub createBlockingStubWithMetadata (ClientCallContext context ,
332+ PayloadAndHeaders payloadAndHeaders ) {
333+ Metadata metadata = createGrpcMetadata (context , payloadAndHeaders );
278334 return blockingStub .withInterceptors (MetadataUtils .newAttachHeadersInterceptor (metadata ));
279335 }
280336
281337 /**
282338 * Creates an async stub with metadata attached from the ClientCallContext.
283- *
284- * @param context the client call context
339+ *
340+ * @param context the client call context
341+ * @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
285342 * @return async stub with metadata interceptor
286343 */
287- private A2AServiceStub createAsyncStubWithMetadata (ClientCallContext context ) {
288- Metadata metadata = createGrpcMetadata (context );
344+ private A2AServiceStub createAsyncStubWithMetadata (ClientCallContext context ,
345+ PayloadAndHeaders payloadAndHeaders ) {
346+ Metadata metadata = createGrpcMetadata (context , payloadAndHeaders );
289347 return asyncStub .withInterceptors (MetadataUtils .newAttachHeadersInterceptor (metadata ));
290348 }
291349
@@ -307,4 +365,17 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif
307365 return name .toString ();
308366 }
309367
368+ private PayloadAndHeaders applyInterceptors (String methodName , Object payload ,
369+ AgentCard agentCard , ClientCallContext clientCallContext ) {
370+ PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders (payload ,
371+ clientCallContext != null ? clientCallContext .getHeaders () : null );
372+ if (interceptors != null && ! interceptors .isEmpty ()) {
373+ for (ClientCallInterceptor interceptor : interceptors ) {
374+ payloadAndHeaders = interceptor .intercept (methodName , payloadAndHeaders .getPayload (),
375+ payloadAndHeaders .getHeaders (), agentCard , clientCallContext );
376+ }
377+ }
378+ return payloadAndHeaders ;
379+ }
380+
310381}
0 commit comments