Skip to content

Commit 9b8bbb2

Browse files
authored
feat: ServerCallContext ports from Python implementation (#285)
Port server-side extension support and implement complete gRPC context access equivalent to Python's ServicerContext. Adds client/server metadata support, rich context information, and interceptor infrastructure while maintaining backward compatibility.
1 parent a201168 commit 9b8bbb2

File tree

17 files changed

+687
-41
lines changed

17 files changed

+687
-41
lines changed

client/transport/grpc/src/main/java/io/a2a/client/transport/grpc/GrpcTransport.java

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import io.a2a.client.transport.spi.interceptors.ClientCallContext;
1414
import io.a2a.client.transport.spi.ClientTransport;
15+
import io.a2a.common.A2AHeaders;
1516
import io.a2a.grpc.A2AServiceGrpc;
1617
import io.a2a.grpc.CancelTaskRequest;
1718
import io.a2a.grpc.CreateTaskPushNotificationConfigRequest;
@@ -37,8 +38,9 @@
3738
import io.a2a.spec.TaskPushNotificationConfig;
3839
import io.a2a.spec.TaskQueryParams;
3940
import io.grpc.Channel;
40-
41+
import io.grpc.Metadata;
4142
import io.grpc.StatusRuntimeException;
43+
import io.grpc.stub.MetadataUtils;
4244
import io.grpc.stub.StreamObserver;
4345

4446
public class GrpcTransport implements ClientTransport {
@@ -61,7 +63,8 @@ public EventKind sendMessage(MessageSendParams request, ClientCallContext contex
6163
SendMessageRequest sendMessageRequest = createGrpcSendMessageRequest(request, context);
6264

6365
try {
64-
SendMessageResponse response = blockingStub.sendMessage(sendMessageRequest);
66+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
67+
SendMessageResponse response = stubWithMetadata.sendMessage(sendMessageRequest);
6568
if (response.hasMsg()) {
6669
return FromProto.message(response.getMsg());
6770
} else if (response.hasTask()) {
@@ -83,7 +86,8 @@ public void sendMessageStreaming(MessageSendParams request, Consumer<StreamingEv
8386
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);
8487

8588
try {
86-
asyncStub.sendStreamingMessage(grpcRequest, streamObserver);
89+
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
90+
stubWithMetadata.sendStreamingMessage(grpcRequest, streamObserver);
8791
} catch (StatusRuntimeException e) {
8892
throw GrpcErrorMapper.mapGrpcError(e, "Failed to send streaming message request: ");
8993
}
@@ -101,7 +105,8 @@ public Task getTask(TaskQueryParams request, ClientCallContext context) throws A
101105
GetTaskRequest getTaskRequest = requestBuilder.build();
102106

103107
try {
104-
return FromProto.task(blockingStub.getTask(getTaskRequest));
108+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
109+
return FromProto.task(stubWithMetadata.getTask(getTaskRequest));
105110
} catch (StatusRuntimeException e) {
106111
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task: ");
107112
}
@@ -116,7 +121,8 @@ public Task cancelTask(TaskIdParams request, ClientCallContext context) throws A
116121
.build();
117122

118123
try {
119-
return FromProto.task(blockingStub.cancelTask(cancelTaskRequest));
124+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
125+
return FromProto.task(stubWithMetadata.cancelTask(cancelTaskRequest));
120126
} catch (StatusRuntimeException e) {
121127
throw GrpcErrorMapper.mapGrpcError(e, "Failed to cancel task: ");
122128
}
@@ -135,7 +141,8 @@ public TaskPushNotificationConfig setTaskPushNotificationConfiguration(TaskPushN
135141
.build();
136142

137143
try {
138-
return FromProto.taskPushNotificationConfig(blockingStub.createTaskPushNotificationConfig(grpcRequest));
144+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
145+
return FromProto.taskPushNotificationConfig(stubWithMetadata.createTaskPushNotificationConfig(grpcRequest));
139146
} catch (StatusRuntimeException e) {
140147
throw GrpcErrorMapper.mapGrpcError(e, "Failed to create task push notification config: ");
141148
}
@@ -152,7 +159,8 @@ public TaskPushNotificationConfig getTaskPushNotificationConfiguration(
152159
.build();
153160

154161
try {
155-
return FromProto.taskPushNotificationConfig(blockingStub.getTaskPushNotificationConfig(grpcRequest));
162+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
163+
return FromProto.taskPushNotificationConfig(stubWithMetadata.getTaskPushNotificationConfig(grpcRequest));
156164
} catch (StatusRuntimeException e) {
157165
throw GrpcErrorMapper.mapGrpcError(e, "Failed to get task push notification config: ");
158166
}
@@ -169,7 +177,8 @@ public List<TaskPushNotificationConfig> listTaskPushNotificationConfigurations(
169177
.build();
170178

171179
try {
172-
return blockingStub.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
180+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
181+
return stubWithMetadata.listTaskPushNotificationConfig(grpcRequest).getConfigsList().stream()
173182
.map(FromProto::taskPushNotificationConfig)
174183
.collect(Collectors.toList());
175184
} catch (StatusRuntimeException e) {
@@ -187,7 +196,8 @@ public void deleteTaskPushNotificationConfigurations(DeleteTaskPushNotificationC
187196
.build();
188197

189198
try {
190-
blockingStub.deleteTaskPushNotificationConfig(grpcRequest);
199+
A2AServiceBlockingV2Stub stubWithMetadata = createBlockingStubWithMetadata(context);
200+
stubWithMetadata.deleteTaskPushNotificationConfig(grpcRequest);
191201
} catch (StatusRuntimeException e) {
192202
throw GrpcErrorMapper.mapGrpcError(e, "Failed to delete task push notification config: ");
193203
}
@@ -206,7 +216,8 @@ public void resubscribe(TaskIdParams request, Consumer<StreamingEventKind> event
206216
StreamObserver<StreamResponse> streamObserver = new EventStreamObserver(eventConsumer, errorConsumer);
207217

208218
try {
209-
asyncStub.taskSubscription(grpcRequest, streamObserver);
219+
A2AServiceStub stubWithMetadata = createAsyncStubWithMetadata(context);
220+
stubWithMetadata.taskSubscription(grpcRequest, streamObserver);
210221
} catch (StatusRuntimeException e) {
211222
throw GrpcErrorMapper.mapGrpcError(e, "Failed to resubscribe task push notification config: ");
212223
}
@@ -234,6 +245,50 @@ private SendMessageRequest createGrpcSendMessageRequest(MessageSendParams messag
234245
return builder.build();
235246
}
236247

248+
/**
249+
* Creates gRPC metadata from ClientCallContext headers.
250+
* Extracts headers like X-A2A-Extensions and sets them as gRPC metadata.
251+
*/
252+
private Metadata createGrpcMetadata(ClientCallContext context) {
253+
Metadata metadata = new Metadata();
254+
255+
if (context != null && context.getHeaders() != null) {
256+
// Set X-A2A-Extensions header if present
257+
String extensionsHeader = context.getHeaders().get(A2AHeaders.X_A2A_EXTENSIONS);
258+
if (extensionsHeader != null) {
259+
Metadata.Key<String> extensionsKey = Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
260+
metadata.put(extensionsKey, extensionsHeader);
261+
}
262+
263+
// Add other headers as needed in the future
264+
// For now, we only handle X-A2A-Extensions
265+
}
266+
267+
return metadata;
268+
}
269+
270+
/**
271+
* Creates a blocking stub with metadata attached from the ClientCallContext.
272+
*
273+
* @param context the client call context
274+
* @return blocking stub with metadata interceptor
275+
*/
276+
private A2AServiceBlockingV2Stub createBlockingStubWithMetadata(ClientCallContext context) {
277+
Metadata metadata = createGrpcMetadata(context);
278+
return blockingStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
279+
}
280+
281+
/**
282+
* Creates an async stub with metadata attached from the ClientCallContext.
283+
*
284+
* @param context the client call context
285+
* @return async stub with metadata interceptor
286+
*/
287+
private A2AServiceStub createAsyncStubWithMetadata(ClientCallContext context) {
288+
Metadata metadata = createGrpcMetadata(context);
289+
return asyncStub.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata));
290+
}
291+
237292
private String getTaskPushNotificationConfigName(GetTaskPushNotificationConfigParams params) {
238293
return getTaskPushNotificationConfigName(params.id(), params.pushNotificationConfigId());
239294
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package io.a2a.common;
2+
3+
/**
4+
* Common A2A protocol headers and constants.
5+
*/
6+
public final class A2AHeaders {
7+
8+
/**
9+
* HTTP header name for A2A extensions.
10+
* Used to communicate which extensions are requested by the client.
11+
*/
12+
public static final String X_A2A_EXTENSIONS = "X-A2A-Extensions";
13+
14+
private A2AHeaders() {
15+
// Utility class
16+
}
17+
}

reference/grpc/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
<groupId>${project.groupId}</groupId>
2020
<artifactId>a2a-java-sdk-reference-common</artifactId>
2121
</dependency>
22+
<dependency>
23+
<groupId>${project.groupId}</groupId>
24+
<artifactId>a2a-java-sdk-common</artifactId>
25+
</dependency>
2226
<dependency>
2327
<groupId>${project.groupId}</groupId>
2428
<artifactId>a2a-java-sdk-transport-grpc</artifactId>
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package io.a2a.server.grpc.quarkus;
2+
3+
import jakarta.enterprise.context.ApplicationScoped;
4+
import io.grpc.Context;
5+
import io.grpc.Contexts;
6+
import io.grpc.Metadata;
7+
import io.grpc.ServerCall;
8+
import io.grpc.ServerCallHandler;
9+
import io.grpc.ServerInterceptor;
10+
import io.a2a.common.A2AHeaders;
11+
import io.a2a.transport.grpc.context.GrpcContextKeys;
12+
13+
/**
14+
* gRPC server interceptor that captures request metadata and context information,
15+
* providing equivalent functionality to Python's grpc.aio.ServicerContext.
16+
*
17+
* This interceptor:
18+
* - Extracts A2A extension headers from incoming requests
19+
* - Captures ServerCall and Metadata for rich context access
20+
* - Stores context information in gRPC Context for service method access
21+
* - Provides proper equivalence to Python's ServicerContext
22+
*/
23+
@ApplicationScoped
24+
public class A2AExtensionsInterceptor implements ServerInterceptor {
25+
26+
27+
@Override
28+
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
29+
ServerCall<ReqT, RespT> serverCall,
30+
Metadata metadata,
31+
ServerCallHandler<ReqT, RespT> serverCallHandler) {
32+
33+
// Extract A2A extensions header
34+
Metadata.Key<String> extensionsKey =
35+
Metadata.Key.of(A2AHeaders.X_A2A_EXTENSIONS, Metadata.ASCII_STRING_MARSHALLER);
36+
String extensions = metadata.get(extensionsKey);
37+
38+
// Create enhanced context with rich information (equivalent to Python's ServicerContext)
39+
Context context = Context.current()
40+
// Store complete metadata for full header access
41+
.withValue(GrpcContextKeys.METADATA_KEY, metadata)
42+
// Store method name (equivalent to Python's context.method())
43+
.withValue(GrpcContextKeys.METHOD_NAME_KEY, serverCall.getMethodDescriptor().getFullMethodName())
44+
// Store peer information for client connection details
45+
.withValue(GrpcContextKeys.PEER_INFO_KEY, getPeerInfo(serverCall));
46+
47+
// Store A2A extensions if present
48+
if (extensions != null) {
49+
context = context.withValue(GrpcContextKeys.EXTENSIONS_HEADER_KEY, extensions);
50+
}
51+
52+
// Proceed with the call in the enhanced context
53+
return Contexts.interceptCall(context, serverCall, metadata, serverCallHandler);
54+
}
55+
56+
/**
57+
* Safely extracts peer information from the ServerCall.
58+
*
59+
* @param serverCall the gRPC ServerCall
60+
* @return peer information string, or "unknown" if not available
61+
*/
62+
private String getPeerInfo(ServerCall<?, ?> serverCall) {
63+
try {
64+
Object remoteAddr = serverCall.getAttributes().get(io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR);
65+
return remoteAddr != null ? remoteAddr.toString() : "unknown";
66+
} catch (Exception e) {
67+
return "unknown";
68+
}
69+
}
70+
}

reference/grpc/src/main/java/io/a2a/server/grpc/quarkus/QuarkusGrpcHandler.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
import io.a2a.transport.grpc.handler.CallContextFactory;
1010
import io.a2a.transport.grpc.handler.GrpcHandler;
1111
import io.quarkus.grpc.GrpcService;
12+
import io.quarkus.grpc.RegisterInterceptor;
1213

1314
@GrpcService
15+
@RegisterInterceptor(A2AExtensionsInterceptor.class)
1416
public class QuarkusGrpcHandler extends GrpcHandler {
1517

1618
private final AgentCard agentCard;

reference/jsonrpc/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
55

66
import java.util.HashMap;
7+
import java.util.List;
78
import java.util.Map;
89
import java.util.Set;
910
import java.util.concurrent.Executor;
@@ -19,9 +20,11 @@
1920
import com.fasterxml.jackson.core.JsonProcessingException;
2021
import com.fasterxml.jackson.core.io.JsonEOFException;
2122
import com.fasterxml.jackson.databind.JsonNode;
23+
import io.a2a.common.A2AHeaders;
2224
import io.a2a.server.ServerCallContext;
2325
import io.a2a.server.auth.UnauthenticatedUser;
2426
import io.a2a.server.auth.User;
27+
import io.a2a.server.extensions.A2AExtensions;
2528
import io.a2a.server.util.async.Internal;
2629
import io.a2a.spec.AgentCard;
2730
import io.a2a.spec.CancelTaskRequest;
@@ -241,7 +244,11 @@ public String getUsername() {
241244
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
242245
state.put("headers", headers);
243246

244-
return new ServerCallContext(user, state);
247+
// Extract requested extensions from X-A2A-Extensions header
248+
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS);
249+
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);
250+
251+
return new ServerCallContext(user, state, requestedExtensions);
245252
} else {
246253
CallContextFactory builder = callContextFactory.get();
247254
return builder.build(rc);

reference/rest/src/main/java/io/a2a/server/rest/quarkus/A2AServerRoutes.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import jakarta.inject.Inject;
1313
import jakarta.inject.Singleton;
1414

15+
import io.a2a.common.A2AHeaders;
1516
import io.a2a.server.ServerCallContext;
1617
import io.a2a.server.auth.UnauthenticatedUser;
1718
import io.a2a.server.auth.User;
@@ -34,9 +35,13 @@
3435
import io.vertx.core.http.HttpServerResponse;
3536
import io.vertx.ext.web.RoutingContext;
3637
import java.util.HashMap;
38+
import java.util.HashSet;
39+
import java.util.List;
3740
import java.util.Map;
3841
import java.util.Set;
3942

43+
import io.a2a.server.extensions.A2AExtensions;
44+
4045
@Singleton
4146
public class A2AServerRoutes {
4247

@@ -308,7 +313,11 @@ public String getUsername() {
308313
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
309314
state.put("headers", headers);
310315

311-
return new ServerCallContext(user, state);
316+
// Extract requested extensions from X-A2A-Extensions header
317+
List<String> extensionHeaderValues = rc.request().headers().getAll(A2AHeaders.X_A2A_EXTENSIONS);
318+
Set<String> requestedExtensions = A2AExtensions.getRequestedExtensions(extensionHeaderValues);
319+
320+
return new ServerCallContext(user, state, requestedExtensions);
312321
} else {
313322
CallContextFactory builder = callContextFactory.get();
314323
return builder.build(rc);

server-common/src/main/java/io/a2a/server/ServerCallContext.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package io.a2a.server;
22

3+
import java.util.HashSet;
34
import java.util.Map;
5+
import java.util.Set;
46
import java.util.concurrent.ConcurrentHashMap;
57

68
import io.a2a.server.auth.User;
@@ -10,10 +12,14 @@ public class ServerCallContext {
1012
private final Map<Object, Object> modelConfig = new ConcurrentHashMap<>();
1113
private final Map<String, Object> state;
1214
private final User user;
15+
private final Set<String> requestedExtensions;
16+
private final Set<String> activatedExtensions;
1317

14-
public ServerCallContext(User user, Map<String, Object> state) {
18+
public ServerCallContext(User user, Map<String, Object> state, Set<String> requestedExtensions) {
1519
this.user = user;
1620
this.state = state;
21+
this.requestedExtensions = new HashSet<>(requestedExtensions);
22+
this.activatedExtensions = new HashSet<>(); // Always starts empty, populated later by application code
1723
}
1824

1925
public Map<String, Object> getState() {
@@ -23,4 +29,28 @@ public Map<String, Object> getState() {
2329
public User getUser() {
2430
return user;
2531
}
32+
33+
public Set<String> getRequestedExtensions() {
34+
return new HashSet<>(requestedExtensions);
35+
}
36+
37+
public Set<String> getActivatedExtensions() {
38+
return new HashSet<>(activatedExtensions);
39+
}
40+
41+
public void activateExtension(String extensionUri) {
42+
activatedExtensions.add(extensionUri);
43+
}
44+
45+
public void deactivateExtension(String extensionUri) {
46+
activatedExtensions.remove(extensionUri);
47+
}
48+
49+
public boolean isExtensionActivated(String extensionUri) {
50+
return activatedExtensions.contains(extensionUri);
51+
}
52+
53+
public boolean isExtensionRequested(String extensionUri) {
54+
return requestedExtensions.contains(extensionUri);
55+
}
2656
}

0 commit comments

Comments
 (0)