Skip to content

Commit ea56419

Browse files
authored
feat: Implement ServerCallContext (#206)
Port the ServerCallContext from Python
1 parent 65f4c93 commit ea56419

File tree

10 files changed

+284
-133
lines changed

10 files changed

+284
-133
lines changed

reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
44
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
55

6+
import java.util.HashMap;
7+
import java.util.Map;
8+
import java.util.Set;
69
import java.util.concurrent.Executor;
710
import java.util.concurrent.Flow;
811
import java.util.concurrent.atomic.AtomicLong;
@@ -18,6 +21,9 @@
1821
import com.fasterxml.jackson.core.io.JsonEOFException;
1922
import com.fasterxml.jackson.databind.JsonNode;
2023
import io.a2a.server.ExtendedAgentCard;
24+
import io.a2a.server.ServerCallContext;
25+
import io.a2a.server.auth.UnauthenticatedUser;
26+
import io.a2a.server.auth.User;
2127
import io.a2a.server.requesthandlers.JSONRPCHandler;
2228
import io.a2a.server.util.async.Internal;
2329
import io.a2a.spec.AgentCard;
@@ -78,9 +84,13 @@ public class A2AServerRoutes {
7884
@Internal
7985
Executor executor;
8086

87+
@Inject
88+
Instance<CallContextFactory> callContextFactory;
89+
8190
@Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
8291
public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
8392
boolean streaming = false;
93+
ServerCallContext context = createCallContext(rc);
8494
JSONRPCResponse<?> nonStreamingResponse = null;
8595
Multi<? extends JSONRPCResponse<?>> streamingResponse = null;
8696
JSONRPCErrorResponse error = null;
@@ -89,10 +99,10 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
8999
if (isStreamingRequest(body)) {
90100
streaming = true;
91101
StreamingJSONRPCRequest<?> request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class);
92-
streamingResponse = processStreamingRequest(request);
102+
streamingResponse = processStreamingRequest(request, context);
93103
} else {
94104
NonStreamingJSONRPCRequest<?> request = Utils.OBJECT_MAPPER.readValue(body, NonStreamingJSONRPCRequest.class);
95-
nonStreamingResponse = processNonStreamingRequest(request);
105+
nonStreamingResponse = processNonStreamingRequest(request, context);
96106
}
97107
} catch (JsonProcessingException e) {
98108
error = handleError(e);
@@ -183,32 +193,34 @@ public void getAuthenticatedExtendedAgentCard(RoutingExchange re) {
183193
}
184194
}
185195

186-
private JSONRPCResponse<?> processNonStreamingRequest(NonStreamingJSONRPCRequest<?> request) {
187-
if (request instanceof GetTaskRequest) {
188-
return jsonRpcHandler.onGetTask((GetTaskRequest) request);
189-
} else if (request instanceof CancelTaskRequest) {
190-
return jsonRpcHandler.onCancelTask((CancelTaskRequest) request);
191-
} else if (request instanceof SetTaskPushNotificationConfigRequest) {
192-
return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request);
193-
} else if (request instanceof GetTaskPushNotificationConfigRequest) {
194-
return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request);
195-
} else if (request instanceof SendMessageRequest) {
196-
return jsonRpcHandler.onMessageSend((SendMessageRequest) request);
197-
} else if (request instanceof ListTaskPushNotificationConfigRequest) {
198-
return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request);
199-
} else if (request instanceof DeleteTaskPushNotificationConfigRequest) {
200-
return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request);
196+
private JSONRPCResponse<?> processNonStreamingRequest(
197+
NonStreamingJSONRPCRequest<?> request, ServerCallContext context) {
198+
if (request instanceof GetTaskRequest req) {
199+
return jsonRpcHandler.onGetTask(req, context);
200+
} else if (request instanceof CancelTaskRequest req) {
201+
return jsonRpcHandler.onCancelTask(req, context);
202+
} else if (request instanceof SetTaskPushNotificationConfigRequest req) {
203+
return jsonRpcHandler.setPushNotificationConfig(req, context);
204+
} else if (request instanceof GetTaskPushNotificationConfigRequest req) {
205+
return jsonRpcHandler.getPushNotificationConfig(req, context);
206+
} else if (request instanceof SendMessageRequest req) {
207+
return jsonRpcHandler.onMessageSend(req, context);
208+
} else if (request instanceof ListTaskPushNotificationConfigRequest req) {
209+
return jsonRpcHandler.listPushNotificationConfig(req, context);
210+
} else if (request instanceof DeleteTaskPushNotificationConfigRequest req) {
211+
return jsonRpcHandler.deletePushNotificationConfig(req, context);
201212
} else {
202213
return generateErrorResponse(request, new UnsupportedOperationError());
203214
}
204215
}
205216

206-
private Multi<? extends JSONRPCResponse<?>> processStreamingRequest(JSONRPCRequest<?> request) {
217+
private Multi<? extends JSONRPCResponse<?>> processStreamingRequest(
218+
JSONRPCRequest<?> request, ServerCallContext context) {
207219
Flow.Publisher<? extends JSONRPCResponse<?>> publisher;
208-
if (request instanceof SendStreamingMessageRequest) {
209-
publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request);
210-
} else if (request instanceof TaskResubscriptionRequest) {
211-
publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request);
220+
if (request instanceof SendStreamingMessageRequest req) {
221+
publisher = jsonRpcHandler.onMessageSendStream(req, context);
222+
} else if (request instanceof TaskResubscriptionRequest req) {
223+
publisher = jsonRpcHandler.onResubscribeToTask(req, context);
212224
} else {
213225
return Multi.createFrom().item(generateErrorResponse(request, new UnsupportedOperationError()));
214226
}
@@ -234,6 +246,42 @@ static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) {
234246
streamingMultiSseSupportSubscribedRunnable = runnable;
235247
}
236248

249+
private ServerCallContext createCallContext(RoutingContext rc) {
250+
251+
if (callContextFactory.isUnsatisfied()) {
252+
User user;
253+
if (rc.user() == null) {
254+
user = UnauthenticatedUser.INSTANCE;
255+
} else {
256+
user = new User() {
257+
@Override
258+
public boolean isAuthenticated() {
259+
return rc.userContext().authenticated();
260+
}
261+
262+
@Override
263+
public String getUsername() {
264+
return rc.user().subject();
265+
}
266+
};
267+
}
268+
Map<String, Object> state = new HashMap<>();
269+
// TODO Python's impl has
270+
// state['auth'] = request.auth
271+
// in jsonrpc_app.py. Figure out what this maps to in what Vert.X gives us
272+
273+
Map<String, String> headers = new HashMap<>();
274+
Set<String> headerNames = rc.request().headers().names();
275+
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
276+
state.put("headers", headers);
277+
278+
return new ServerCallContext(user, state);
279+
} else {
280+
CallContextFactory builder = callContextFactory.get();
281+
return builder.build(rc);
282+
}
283+
}
284+
237285
// Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API
238286
private static class MultiSseSupport {
239287

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package io.a2a.server.apps.quarkus;
2+
3+
import io.a2a.server.ServerCallContext;
4+
import io.vertx.ext.web.RoutingContext;
5+
6+
public interface CallContextFactory {
7+
ServerCallContext build(RoutingContext rc);
8+
}
Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
package io.a2a.server;
22

3+
import java.util.Map;
4+
import java.util.concurrent.ConcurrentHashMap;
5+
6+
import io.a2a.server.auth.User;
7+
38
public class ServerCallContext {
4-
// TODO port the fields
9+
// TODO Not totally sure yet about these field types
10+
private final Map<Object, Object> modelConfig = new ConcurrentHashMap<>();
11+
private final Map<String, Object> state;
12+
private final User user;
13+
14+
public ServerCallContext(User user, Map<String, Object> state) {
15+
this.user = user;
16+
this.state = state;
17+
}
18+
19+
public Map<String, Object> getState() {
20+
return state;
21+
}
22+
23+
public User getUser() {
24+
return user;
25+
}
526
}

sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,21 @@ public class RequestContext {
2222
private String contextId;
2323
private Task task;
2424
private List<Task> relatedTasks;
25-
26-
public RequestContext(MessageSendParams params, String taskId, String contextId, Task task, List<Task> relatedTasks) throws InvalidParamsError {
25+
private final ServerCallContext callContext;
26+
27+
public RequestContext(
28+
MessageSendParams params,
29+
String taskId,
30+
String contextId,
31+
Task task,
32+
List<Task> relatedTasks,
33+
ServerCallContext callContext) throws InvalidParamsError {
2734
this.params = params;
2835
this.taskId = taskId;
2936
this.contextId = contextId;
3037
this.task = task;
3138
this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks;
39+
this.callContext = callContext;
3240

3341
// if the taskId and contextId were specified, they must match the params
3442
if (params != null) {
@@ -73,6 +81,10 @@ public MessageSendConfiguration getConfiguration() {
7381
return params != null ? params.configuration() : null;
7482
}
7583

84+
public ServerCallContext getCallContext() {
85+
return callContext;
86+
}
87+
7688
public String getUserInput(String delimiter) {
7789
if (params == null) {
7890
return "";
@@ -187,7 +199,7 @@ public ServerCallContext getServerCallContext() {
187199
}
188200

189201
public RequestContext build() {
190-
return new RequestContext(params, taskId, contextId, task, relatedTasks);
202+
return new RequestContext(params, taskId, contextId, task, relatedTasks, serverCallContext);
191203
}
192204
}
193205

sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package io.a2a.server.auth;
22

33
public class UnauthenticatedUser implements User {
4+
5+
public static UnauthenticatedUser INSTANCE = new UnauthenticatedUser();
6+
7+
private UnauthenticatedUser() {
8+
}
9+
410
@Override
511
public boolean isAuthenticated() {
612
return false;

sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import java.util.concurrent.atomic.AtomicReference;
1717
import java.util.function.Supplier;
1818

19+
import io.a2a.server.ServerCallContext;
1920
import io.a2a.server.agentexecution.AgentExecutor;
2021
import io.a2a.server.agentexecution.RequestContext;
2122
import io.a2a.server.agentexecution.SimpleRequestContextBuilder;
@@ -87,7 +88,7 @@ public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore,
8788
}
8889

8990
@Override
90-
public Task onGetTask(TaskQueryParams params) throws JSONRPCError {
91+
public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws JSONRPCError {
9192
LOGGER.debug("onGetTask {}", params.id());
9293
Task task = taskStore.get(params.id());
9394
if (task == null) {
@@ -114,7 +115,7 @@ public Task onGetTask(TaskQueryParams params) throws JSONRPCError {
114115
}
115116

116117
@Override
117-
public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
118+
public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws JSONRPCError {
118119
Task task = taskStore.get(params.id());
119120
if (task == null) {
120121
throw new TaskNotFoundError();
@@ -136,6 +137,7 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
136137
.setTaskId(task.getId())
137138
.setContextId(task.getContextId())
138139
.setTask(task)
140+
.setServerCallContext(context)
139141
.build(),
140142
queue);
141143

@@ -152,9 +154,9 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
152154
}
153155

154156
@Override
155-
public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError {
157+
public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws JSONRPCError {
156158
LOGGER.debug("onMessageSend - task: {}; context {}", params.message().getTaskId(), params.message().getContextId());
157-
MessageSendSetup mss = initMessageSend(params);
159+
MessageSendSetup mss = initMessageSend(params, context);
158160

159161
String taskId = mss.requestContext.getTaskId();
160162
LOGGER.debug("Request context taskId: {}", taskId);
@@ -200,9 +202,10 @@ public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError {
200202
}
201203

202204
@Override
203-
public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams params) throws JSONRPCError {
205+
public Flow.Publisher<StreamingEventKind> onMessageSendStream(
206+
MessageSendParams params, ServerCallContext context) throws JSONRPCError {
204207
LOGGER.debug("onMessageSendStream - task: {}; context {}", params.message().getTaskId(), params.message().getContextId());
205-
MessageSendSetup mss = initMessageSend(params);
208+
MessageSendSetup mss = initMessageSend(params, context);
206209

207210
AtomicReference<String> taskId = new AtomicReference<>(mss.requestContext.getTaskId());
208211
EventQueue queue = queueManager.createOrTap(taskId.get());
@@ -260,7 +263,8 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams
260263
}
261264

262265
@Override
263-
public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError {
266+
public TaskPushNotificationConfig onSetTaskPushNotificationConfig(
267+
TaskPushNotificationConfig params, ServerCallContext context) throws JSONRPCError {
264268
if (pushConfigStore == null) {
265269
throw new UnsupportedOperationError();
266270
}
@@ -275,7 +279,8 @@ public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotifi
275279
}
276280

277281
@Override
278-
public TaskPushNotificationConfig onGetTaskPushNotificationConfig(GetTaskPushNotificationConfigParams params) throws JSONRPCError {
282+
public TaskPushNotificationConfig onGetTaskPushNotificationConfig(
283+
GetTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError {
279284
if (pushConfigStore == null) {
280285
throw new UnsupportedOperationError();
281286
}
@@ -305,7 +310,8 @@ private PushNotificationConfig getPushNotificationConfig(List<PushNotificationCo
305310
}
306311

307312
@Override
308-
public Flow.Publisher<StreamingEventKind> onResubscribeToTask(TaskIdParams params) throws JSONRPCError {
313+
public Flow.Publisher<StreamingEventKind> onResubscribeToTask(
314+
TaskIdParams params, ServerCallContext context) throws JSONRPCError {
309315
Task task = taskStore.get(params.id());
310316
if (task == null) {
311317
throw new TaskNotFoundError();
@@ -325,7 +331,8 @@ public Flow.Publisher<StreamingEventKind> onResubscribeToTask(TaskIdParams param
325331
}
326332

327333
@Override
328-
public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(ListTaskPushNotificationConfigParams params) throws JSONRPCError {
334+
public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(
335+
ListTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError {
329336
if (pushConfigStore == null) {
330337
throw new UnsupportedOperationError();
331338
}
@@ -347,7 +354,8 @@ public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(ListTas
347354
}
348355

349356
@Override
350-
public void onDeleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams params) {
357+
public void onDeleteTaskPushNotificationConfig(
358+
DeleteTaskPushNotificationConfigParams params, ServerCallContext context) {
351359
if (pushConfigStore == null) {
352360
throw new UnsupportedOperationError();
353361
}
@@ -398,7 +406,7 @@ private void cleanupProducer(String taskId) {
398406
});
399407
}
400408

401-
private MessageSendSetup initMessageSend(MessageSendParams params) {
409+
private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context) {
402410
TaskManager taskManager = new TaskManager(
403411
params.message().getTaskId(),
404412
params.message().getContextId(),
@@ -421,6 +429,7 @@ private MessageSendSetup initMessageSend(MessageSendParams params) {
421429
.setTaskId(task == null ? null : task.getId())
422430
.setContextId(params.message().getContextId())
423431
.setTask(task)
432+
.setServerCallContext(context)
424433
.build();
425434
return new MessageSendSetup(taskManager, task, requestContext);
426435
}

0 commit comments

Comments
 (0)