Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicLong;
Expand All @@ -18,6 +21,9 @@
import com.fasterxml.jackson.core.io.JsonEOFException;
import com.fasterxml.jackson.databind.JsonNode;
import io.a2a.server.ExtendedAgentCard;
import io.a2a.server.ServerCallContext;
import io.a2a.server.auth.UnauthenticatedUser;
import io.a2a.server.auth.User;
import io.a2a.server.requesthandlers.JSONRPCHandler;
import io.a2a.server.util.async.Internal;
import io.a2a.spec.AgentCard;
Expand Down Expand Up @@ -78,9 +84,13 @@ public class A2AServerRoutes {
@Internal
Executor executor;

@Inject
Instance<CallContextFactory> callContextFactory;

@Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING)
public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
boolean streaming = false;
ServerCallContext context = createCallContext(rc);
JSONRPCResponse<?> nonStreamingResponse = null;
Multi<? extends JSONRPCResponse<?>> streamingResponse = null;
JSONRPCErrorResponse error = null;
Expand All @@ -89,10 +99,10 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) {
if (isStreamingRequest(body)) {
streaming = true;
StreamingJSONRPCRequest<?> request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class);
streamingResponse = processStreamingRequest(request);
streamingResponse = processStreamingRequest(request, context);
} else {
NonStreamingJSONRPCRequest<?> request = Utils.OBJECT_MAPPER.readValue(body, NonStreamingJSONRPCRequest.class);
nonStreamingResponse = processNonStreamingRequest(request);
nonStreamingResponse = processNonStreamingRequest(request, context);
}
} catch (JsonProcessingException e) {
error = handleError(e);
Expand Down Expand Up @@ -183,32 +193,34 @@ public void getAuthenticatedExtendedAgentCard(RoutingExchange re) {
}
}

private JSONRPCResponse<?> processNonStreamingRequest(NonStreamingJSONRPCRequest<?> request) {
if (request instanceof GetTaskRequest) {
return jsonRpcHandler.onGetTask((GetTaskRequest) request);
} else if (request instanceof CancelTaskRequest) {
return jsonRpcHandler.onCancelTask((CancelTaskRequest) request);
} else if (request instanceof SetTaskPushNotificationConfigRequest) {
return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request);
} else if (request instanceof GetTaskPushNotificationConfigRequest) {
return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request);
} else if (request instanceof SendMessageRequest) {
return jsonRpcHandler.onMessageSend((SendMessageRequest) request);
} else if (request instanceof ListTaskPushNotificationConfigRequest) {
return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request);
} else if (request instanceof DeleteTaskPushNotificationConfigRequest) {
return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request);
private JSONRPCResponse<?> processNonStreamingRequest(
NonStreamingJSONRPCRequest<?> request, ServerCallContext context) {
if (request instanceof GetTaskRequest req) {
return jsonRpcHandler.onGetTask(req, context);
} else if (request instanceof CancelTaskRequest req) {
return jsonRpcHandler.onCancelTask(req, context);
} else if (request instanceof SetTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.setPushNotificationConfig(req, context);
} else if (request instanceof GetTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.getPushNotificationConfig(req, context);
} else if (request instanceof SendMessageRequest req) {
return jsonRpcHandler.onMessageSend(req, context);
} else if (request instanceof ListTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.listPushNotificationConfig(req, context);
} else if (request instanceof DeleteTaskPushNotificationConfigRequest req) {
return jsonRpcHandler.deletePushNotificationConfig(req, context);
} else {
return generateErrorResponse(request, new UnsupportedOperationError());
}
}

private Multi<? extends JSONRPCResponse<?>> processStreamingRequest(JSONRPCRequest<?> request) {
private Multi<? extends JSONRPCResponse<?>> processStreamingRequest(
JSONRPCRequest<?> request, ServerCallContext context) {
Flow.Publisher<? extends JSONRPCResponse<?>> publisher;
if (request instanceof SendStreamingMessageRequest) {
publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request);
} else if (request instanceof TaskResubscriptionRequest) {
publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request);
if (request instanceof SendStreamingMessageRequest req) {
publisher = jsonRpcHandler.onMessageSendStream(req, context);
} else if (request instanceof TaskResubscriptionRequest req) {
publisher = jsonRpcHandler.onResubscribeToTask(req, context);
} else {
return Multi.createFrom().item(generateErrorResponse(request, new UnsupportedOperationError()));
}
Expand All @@ -234,6 +246,42 @@ static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) {
streamingMultiSseSupportSubscribedRunnable = runnable;
}

private ServerCallContext createCallContext(RoutingContext rc) {

if (callContextFactory.isUnsatisfied()) {
User user;
if (rc.user() == null) {
user = UnauthenticatedUser.INSTANCE;
} else {
user = new User() {
@Override
public boolean isAuthenticated() {
return rc.userContext().authenticated();
}

@Override
public String getUsername() {
return rc.user().subject();
}
};
}
Map<String, Object> state = new HashMap<>();
// TODO Python's impl has
// state['auth'] = request.auth
// in jsonrpc_app.py. Figure out what this maps to in what Vert.X gives us
Comment on lines +269 to +271
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There's a TODO comment here about mapping request.auth from the Python implementation. It's important to investigate what this maps to in Vert.x to ensure the authentication information is correctly propagated. Neglecting this could lead to security issues or incorrect user context.


Map<String, String> headers = new HashMap<>();
Set<String> headerNames = rc.request().headers().names();
headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name)));
state.put("headers", headers);

return new ServerCallContext(user, state);
} else {
CallContextFactory builder = callContextFactory.get();
return builder.build(rc);
}
}

// Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API
private static class MultiSseSupport {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package io.a2a.server.apps.quarkus;

import io.a2a.server.ServerCallContext;
import io.vertx.ext.web.RoutingContext;

public interface CallContextFactory {
ServerCallContext build(RoutingContext rc);
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,26 @@
package io.a2a.server;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import io.a2a.server.auth.User;

public class ServerCallContext {
// TODO port the fields
// TODO Not totally sure yet about these field types
private final Map<Object, Object> modelConfig = new ConcurrentHashMap<>();
private final Map<String, Object> state;
private final User user;

public ServerCallContext(User user, Map<String, Object> state) {
this.user = user;
this.state = state;
}
Comment on lines +14 to +17
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The state map is stored directly from the constructor argument, and the getter returns the same instance. This allows external code to modify the internal state of ServerCallContext, breaking encapsulation. To make the context immutable, you should store an unmodifiable copy of the map. This will prevent unintended side effects and ensure the context remains consistent throughout the request lifecycle.

Suggested change
public ServerCallContext(User user, Map<String, Object> state) {
this.user = user;
this.state = state;
}
public ServerCallContext(User user, Map<String, Object> state) {
this.user = user;
this.state = state != null ? Map.copyOf(state) : Map.of();
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure at this point whether the intent for this state map is to be mutable or not.


public Map<String, Object> getState() {
return state;
}

public User getUser() {
return user;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,21 @@ public class RequestContext {
private String contextId;
private Task task;
private List<Task> relatedTasks;

public RequestContext(MessageSendParams params, String taskId, String contextId, Task task, List<Task> relatedTasks) throws InvalidParamsError {
private final ServerCallContext callContext;

public RequestContext(
MessageSendParams params,
String taskId,
String contextId,
Task task,
List<Task> relatedTasks,
ServerCallContext callContext) throws InvalidParamsError {
this.params = params;
this.taskId = taskId;
this.contextId = contextId;
this.task = task;
this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks;
this.callContext = callContext;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The callContext field is declared as final, ensuring immutability after construction. This is good for preventing unintended modifications to the context within the RequestContext.


// if the taskId and contextId were specified, they must match the params
if (params != null) {
Expand Down Expand Up @@ -73,6 +81,10 @@ public MessageSendConfiguration getConfiguration() {
return params != null ? params.configuration() : null;
}

public ServerCallContext getCallContext() {
return callContext;
}

public String getUserInput(String delimiter) {
if (params == null) {
return "";
Expand Down Expand Up @@ -187,7 +199,7 @@ public ServerCallContext getServerCallContext() {
}

public RequestContext build() {
return new RequestContext(params, taskId, contextId, task, relatedTasks);
return new RequestContext(params, taskId, contextId, task, relatedTasks, serverCallContext);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
package io.a2a.server.auth;

public class UnauthenticatedUser implements User {

public static UnauthenticatedUser INSTANCE = new UnauthenticatedUser();

private UnauthenticatedUser() {
}

@Override
public boolean isAuthenticated() {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import io.a2a.server.ServerCallContext;
import io.a2a.server.agentexecution.AgentExecutor;
import io.a2a.server.agentexecution.RequestContext;
import io.a2a.server.agentexecution.SimpleRequestContextBuilder;
Expand Down Expand Up @@ -87,7 +88,7 @@ public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore,
}

@Override
public Task onGetTask(TaskQueryParams params) throws JSONRPCError {
public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws JSONRPCError {
LOGGER.debug("onGetTask {}", params.id());
Task task = taskStore.get(params.id());
if (task == null) {
Expand All @@ -114,7 +115,7 @@ public Task onGetTask(TaskQueryParams params) throws JSONRPCError {
}

@Override
public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws JSONRPCError {
Task task = taskStore.get(params.id());
if (task == null) {
throw new TaskNotFoundError();
Expand All @@ -136,6 +137,7 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
.setTaskId(task.getId())
.setContextId(task.getContextId())
.setTask(task)
.setServerCallContext(context)
.build(),
queue);

Expand All @@ -152,9 +154,9 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
}

@Override
public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError {
public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws JSONRPCError {
LOGGER.debug("onMessageSend - task: {}; context {}", params.message().getTaskId(), params.message().getContextId());
MessageSendSetup mss = initMessageSend(params);
MessageSendSetup mss = initMessageSend(params, context);

String taskId = mss.requestContext.getTaskId();
LOGGER.debug("Request context taskId: {}", taskId);
Expand Down Expand Up @@ -200,9 +202,10 @@ public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError {
}

@Override
public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams params) throws JSONRPCError {
public Flow.Publisher<StreamingEventKind> onMessageSendStream(
MessageSendParams params, ServerCallContext context) throws JSONRPCError {
LOGGER.debug("onMessageSendStream - task: {}; context {}", params.message().getTaskId(), params.message().getContextId());
MessageSendSetup mss = initMessageSend(params);
MessageSendSetup mss = initMessageSend(params, context);

AtomicReference<String> taskId = new AtomicReference<>(mss.requestContext.getTaskId());
EventQueue queue = queueManager.createOrTap(taskId.get());
Expand Down Expand Up @@ -260,7 +263,8 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams
}

@Override
public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError {
public TaskPushNotificationConfig onSetTaskPushNotificationConfig(
TaskPushNotificationConfig params, ServerCallContext context) throws JSONRPCError {
if (pushConfigStore == null) {
throw new UnsupportedOperationError();
}
Expand All @@ -275,7 +279,8 @@ public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotifi
}

@Override
public TaskPushNotificationConfig onGetTaskPushNotificationConfig(GetTaskPushNotificationConfigParams params) throws JSONRPCError {
public TaskPushNotificationConfig onGetTaskPushNotificationConfig(
GetTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError {
if (pushConfigStore == null) {
throw new UnsupportedOperationError();
}
Expand Down Expand Up @@ -305,7 +310,8 @@ private PushNotificationConfig getPushNotificationConfig(List<PushNotificationCo
}

@Override
public Flow.Publisher<StreamingEventKind> onResubscribeToTask(TaskIdParams params) throws JSONRPCError {
public Flow.Publisher<StreamingEventKind> onResubscribeToTask(
TaskIdParams params, ServerCallContext context) throws JSONRPCError {
Task task = taskStore.get(params.id());
if (task == null) {
throw new TaskNotFoundError();
Expand All @@ -325,7 +331,8 @@ public Flow.Publisher<StreamingEventKind> onResubscribeToTask(TaskIdParams param
}

@Override
public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(ListTaskPushNotificationConfigParams params) throws JSONRPCError {
public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(
ListTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError {
if (pushConfigStore == null) {
throw new UnsupportedOperationError();
}
Expand All @@ -347,7 +354,8 @@ public List<TaskPushNotificationConfig> onListTaskPushNotificationConfig(ListTas
}

@Override
public void onDeleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams params) {
public void onDeleteTaskPushNotificationConfig(
DeleteTaskPushNotificationConfigParams params, ServerCallContext context) {
if (pushConfigStore == null) {
throw new UnsupportedOperationError();
}
Expand Down Expand Up @@ -398,7 +406,7 @@ private void cleanupProducer(String taskId) {
});
}

private MessageSendSetup initMessageSend(MessageSendParams params) {
private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context) {
TaskManager taskManager = new TaskManager(
params.message().getTaskId(),
params.message().getContextId(),
Expand All @@ -421,6 +429,7 @@ private MessageSendSetup initMessageSend(MessageSendParams params) {
.setTaskId(task == null ? null : task.getId())
.setContextId(params.message().getContextId())
.setTask(task)
.setServerCallContext(context)
.build();
return new MessageSendSetup(taskManager, task, requestContext);
}
Expand Down
Loading