Skip to content

Commit 19119ab

Browse files
authored
fix: Ensure authentication is required for all endpoints that require it and add an AuthInterceptor that can obtain credentials from a CredentialService (#292)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](../CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests pass - [x] Appropriate READMEs were updated (if necessary) Fixes #275 and #274 🦕
1 parent 39f68f0 commit 19119ab

File tree

12 files changed

+586
-23
lines changed

12 files changed

+586
-23
lines changed

client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java

Lines changed: 91 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
import static io.a2a.util.Assert.checkNotNullParam;
88

99
import java.util.List;
10+
import java.util.Map;
1011
import java.util.function.Consumer;
1112
import java.util.stream.Collectors;
1213

1314
import io.a2a.client.transport.spi.interceptors.ClientCallContext;
1415
import io.a2a.client.transport.spi.ClientTransport;
16+
import io.a2a.client.transport.spi.interceptors.ClientCallInterceptor;
17+
import io.a2a.client.transport.spi.interceptors.PayloadAndHeaders;
18+
import io.a2a.client.transport.spi.interceptors.auth.AuthInterceptor;
1519
import io.a2a.common.A2AHeaders;
1620
import io.a2a.grpc.A2AServiceGrpc;
1721
import io.a2a.grpc.CancelTaskRequest;
@@ -32,11 +36,14 @@
3236
import io.a2a.spec.GetTaskPushNotificationConfigParams;
3337
import io.a2a.spec.ListTaskPushNotificationConfigParams;
3438
import io.a2a.spec.MessageSendParams;
39+
import io.a2a.spec.SendStreamingMessageRequest;
40+
import io.a2a.spec.SetTaskPushNotificationConfigRequest;
3541
import io.a2a.spec.StreamingEventKind;
3642
import io.a2a.spec.Task;
3743
import io.a2a.spec.TaskIdParams;
3844
import io.a2a.spec.TaskPushNotificationConfig;
3945
import io.a2a.spec.TaskQueryParams;
46+
import io.a2a.spec.TaskResubscriptionRequest;
4047
import io.grpc.Channel;
4148
import io.grpc.Metadata;
4249
import io.grpc.StatusRuntimeException;
@@ -45,25 +52,39 @@
4552

4653
public class GrpcTransport implements ClientTransport {
4754

55+
private static final Metadata.Key<String> AUTHORIZATION_METADATA_KEY = Metadata.Key.of(
56+
AuthInterceptor.AUTHORIZATION,
57+
Metadata.ASCII_STRING_MARSHALLER);
58+
private static final Metadata.Key<String> EXTENSIONS_KEY = Metadata.Key.of(
59+
A2AHeaders.X_A2A_EXTENSIONS,
60+
Metadata.ASCII_STRING_MARSHALLER);
4861
private final A2AServiceBlockingV2Stub blockingStub;
4962
private final A2AServiceStub asyncStub;
63+
private final List<ClientCallInterceptor> interceptors;
5064
private AgentCard agentCard;
5165

5266
public GrpcTransport(Channel channel, AgentCard agentCard) {
67+
this(channel, agentCard, null);
68+
}
69+
70+
public GrpcTransport(Channel channel, AgentCard agentCard, List<ClientCallInterceptor> interceptors) {
5371
checkNotNullParam("channel", channel);
5472
this.asyncStub = A2AServiceGrpc.newStub(channel);
5573
this.blockingStub = A2AServiceGrpc.newBlockingV2Stub(channel);
5674
this.agentCard = agentCard;
75+
this.interceptors = interceptors;
5776
}
5877

5978
@Override
6079
public EventKind sendMessage(MessageSendParams request, ClientCallContext context) throws A2AClientException {
6180
checkNotNullParam("request", request);
6281

6382
SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context);
83+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.SendMessageRequest.METHOD, sendMessageRequest,
84+
agentCard, context);
6485

6586
try {
66-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
87+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
6788
SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest);
6889
if (response.hasMsg()) {
6990
return FromProto.message(response.getMsg());
@@ -83,10 +104,12 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
83104
checkNotNullParam("request", request);
84105
checkNotNullParam("eventConsumer", eventConsumer);
85106
SendMessageRequest grpcRequest = createGrpcSendMessageRequest(request, context);
107+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SendStreamingMessageRequest.METHOD,
108+
grpcRequest, agentCard, context);
86109
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);
87110

88111
try {
89-
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
112+
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders);
90113
stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver);
91114
} catch (StatusRuntimeException e) {
92115
throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: ");
@@ -103,9 +126,11 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
103126
requestBuilder.setHistoryLength(request.historyLength());
104127
}
105128
GetTaskRequest getTaskRequest = requestBuilder.build();
129+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskRequest.METHOD, getTaskRequest,
130+
agentCard, context);
106131

107132
try {
108-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
133+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
109134
return FromProto.task(stubWithMetadata.getTask(getTaskRequest));
110135
} catch (StatusRuntimeException e) {
111136
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: ");
@@ -119,9 +144,11 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A
119144
CancelTaskRequest cancelTaskRequest = CancelTaskRequest.newBuilder()
120145
.setName("tasks/" + request.id())
121146
.build();
147+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.CancelTaskRequest.METHOD, cancelTaskRequest,
148+
agentCard, context);
122149

123150
try {
124-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
151+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
125152
return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest));
126153
} catch (StatusRuntimeException e) {
127154
throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: ");
@@ -139,9 +166,11 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
139166
.setConfig(ToProto.taskPushNotificationConfig(request))
140167
.setConfigId(configId != null ? configId : request.taskId())
141168
.build();
169+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(SetTaskPushNotificationConfigRequest.METHOD,
170+
grpcRequest, agentCard, context);
142171

143172
try {
144-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
173+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
145174
return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest));
146175
} catch (StatusRuntimeException e) {
147176
throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: ");
@@ -157,9 +186,11 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
157186
GetTaskPushNotificationConfigRequest grpcRequest = GetTaskPushNotificationConfigRequest.newBuilder()
158187
.setName(getTaskPushNotificationConfigName(request))
159188
.build();
189+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.GetTaskPushNotificationConfigRequest.METHOD,
190+
grpcRequest, agentCard, context);
160191

161192
try {
162-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
193+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
163194
return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest));
164195
} catch (StatusRuntimeException e) {
165196
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: ");
@@ -175,9 +206,11 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
175206
ListTaskPushNotificationConfigRequest grpcRequest = ListTaskPushNotificationConfigRequest.newBuilder()
176207
.setParent("tasks/" + request.id())
177208
.build();
209+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.ListTaskPushNotificationConfigRequest.METHOD,
210+
grpcRequest, agentCard, context);
178211

179212
try {
180-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
213+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
181214
return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
182215
.map(FromProto::taskPushNotificationConfig)
183216
.collect(Collectors.toList());
@@ -194,9 +227,11 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC
194227
DeleteTaskPushNotificationConfigRequest grpcRequest = DeleteTaskPushNotificationConfigRequest.newBuilder()
195228
.setName(getTaskPushNotificationConfigName(request.id(), request.pushNotificationConfigId()))
196229
.build();
230+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(io.a2a.spec.DeleteTaskPushNotificationConfigRequest.METHOD,
231+
grpcRequest, agentCard, context);
197232

198233
try {
199-
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
234+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context, payloadAndHeaders);
200235
stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest);
201236
} catch (StatusRuntimeException e) {
202237
throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: ");
@@ -212,11 +247,13 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
212247
TaskSubscriptionRequest grpcRequest = TaskSubscriptionRequest.newBuilder()
213248
.setName("tasks/" + request.id())
214249
.build();
250+
PayloadAndHeaders payloadAndHeaders = applyInterceptors(TaskResubscriptionRequest.METHOD,
251+
grpcRequest, agentCard, context);
215252

216253
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);
217254

218255
try {
219-
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
256+
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context, payloadAndHeaders);
220257
stubWithMetadata.taskSubscription(grpcRequest, streamObserver);
221258
} catch (StatusRuntimeException e) {
222259
throw GrpcErrorMapper.mapGrpcError(e, "Failed to resubscribe task push notification config: ");
@@ -249,43 +286,64 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
249286
* Creates gRPC metadata from ClientCallContext headers.
250287
* Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
251288
*/
252-
private Metadata createGrpcMetadata(ClientCallContext context) {
289+
private Metadata createGrpcMetadata(ClientCallContext context, PayloadAndHeaders payloadAndHeaders) {
253290
Metadata metadata = new Metadata();
254291

255292
if (context != null && context.getHeaders() != null) {
256293
// Set X-A2A-Extensions header if present
257294
String extensionsHeader = context.getHeaders().get(A2AHeaders.X_A2A_EXTENSIONS);
258295
if (extensionsHeader != null) {
259-
Metadata.Key<String> extensionsKey = Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
260-
metadata.put(extensionsKey, extensionsHeader);
296+
metadata.put(EXTENSIONS_KEY, extensionsHeader);
261297
}
262298

263299
// Add other headers as needed in the future
264300
// For now, we only handle X-A2A-Extensions
265301
}
302+
if (payloadAndHeaders != null && payloadAndHeaders.getHeaders() != null) {
303+
// Handle all headers from interceptors (including auth headers)
304+
for (Map.Entry<String, String> headerEntry : payloadAndHeaders.getHeaders().entrySet()) {
305+
String headerName = headerEntry.getKey();
306+
String headerValue = headerEntry.getValue();
307+
308+
if (headerValue != null) {
309+
// Use static key for common Authorization header, create dynamic keys for others
310+
if (AuthInterceptor.AUTHORIZATION.equals(headerName)) {
311+
metadata.put(AUTHORIZATION_METADATA_KEY, headerValue);
312+
} else {
313+
// Create a metadata key dynamically for API keys and other custom headers
314+
Metadata.Key<String> metadataKey = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER);
315+
metadata.put(metadataKey, headerValue);
316+
}
317+
}
318+
}
319+
}
266320

267321
return metadata;
268322
}
269323

270324
/**
271325
* Creates a blocking stub with metadata attached from the ClientCallContext.
272-
*
273-
* @param context the client call context
326+
*
327+
* @param context the client call context
328+
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
274329
* @return blocking stub with metadata interceptor
275330
*/
276-
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context) {
277-
Metadata metadata = createGrpcMetadata(context);
331+
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context,
332+
PayloadAndHeaders payloadAndHeaders) {
333+
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
278334
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
279335
}
280336

281337
/**
282338
* Creates an async stub with metadata attached from the ClientCallContext.
283-
*
284-
* @param context the client call context
339+
*
340+
* @param context the client call context
341+
* @param payloadAndHeaders the payloadAndHeaders after applying any interceptors
285342
* @return async stub with metadata interceptor
286343
*/
287-
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context) {
288-
Metadata metadata = createGrpcMetadata(context);
344+
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context,
345+
PayloadAndHeaders payloadAndHeaders) {
346+
Metadata metadata = createGrpcMetadata(context, payloadAndHeaders);
289347
return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
290348
}
291349

@@ -307,4 +365,17 @@ private String getTaskPushNotificationConfigName(String taskId, String pushNotif
307365
return name.toString();
308366
}
309367

368+
private PayloadAndHeaders applyInterceptors(String methodName, Object payload,
369+
AgentCard agentCard, ClientCallContext clientCallContext) {
370+
PayloadAndHeaders payloadAndHeaders = new PayloadAndHeaders(payload,
371+
clientCallContext != null ? clientCallContext.getHeaders() : null);
372+
if (interceptors != null && ! interceptors.isEmpty()) {
373+
for (ClientCallInterceptor interceptor : interceptors) {
374+
payloadAndHeaders = interceptor.intercept(methodName, payloadAndHeaders.getPayload(),
375+
payloadAndHeaders.getHeaders(), agentCard, clientCallContext);
376+
}
377+
}
378+
return payloadAndHeaders;
379+
}
380+
310381
}

client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransportConfigBuilder.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ public GrpcTransportConfigBuilder channelFactory(Function<String, Channel> chann
2020

2121
@Override
2222
public GrpcTransportConfig build() {
23-
return new GrpcTransportConfig(channelFactory);
23+
GrpcTransportConfig config = new GrpcTransportConfig(channelFactory);
24+
config.setInterceptors(interceptors);
25+
return config;
2426
}
2527
}

client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransportProvider.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public GrpcTransport create(GrpcTransportConfig grpcTransportConfig, AgentCard a
1717

1818
Channel channel = grpcTransportConfig.getChannelFactory().apply(agentUrl);
1919
if (channel != null) {
20-
return new GrpcTransport(channel, agentCard);
20+
return new GrpcTransport(channel, agentCard, grpcTransportConfig.getInterceptors());
2121
}
2222

2323
throw new A2AClientException("Missing required GrpcTransportConfig");

client/transport/spi/pom.xml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121
<groupId>io.github.a2asdk</groupId>
2222
<artifactId>a2a-java-sdk-spec</artifactId>
2323
</dependency>
24+
<dependency>
25+
<groupId>org.junit.jupiter</groupId>
26+
<artifactId>junit-jupiter-api</artifactId>
27+
<scope>test</scope>
28+
</dependency>
29+
<dependency>
30+
<groupId>org.junit.jupiter</groupId>
31+
<artifactId>junit-jupiter-params</artifactId>
32+
<scope>test</scope>
33+
</dependency>
2434
</dependencies>
2535

2636
</project>

0 commit comments

Comments
 (0)