diff --git a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java index 81a0df3b3..988736c91 100644 --- a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java +++ b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/A2AServerRoutes.java @@ -3,6 +3,9 @@ import static io.vertx.core.http.HttpHeaders.CONTENT_TYPE; import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicLong; @@ -18,6 +21,9 @@ import com.fasterxml.jackson.core.io.JsonEOFException; import com.fasterxml.jackson.databind.JsonNode; import io.a2a.server.ExtendedAgentCard; +import io.a2a.server.ServerCallContext; +import io.a2a.server.auth.UnauthenticatedUser; +import io.a2a.server.auth.User; import io.a2a.server.requesthandlers.JSONRPCHandler; import io.a2a.server.util.async.Internal; import io.a2a.spec.AgentCard; @@ -78,9 +84,13 @@ public class A2AServerRoutes { @Internal Executor executor; + @Inject + Instance callContextFactory; + @Route(path = "/", methods = {Route.HttpMethod.POST}, consumes = {APPLICATION_JSON}, type = Route.HandlerType.BLOCKING) public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { boolean streaming = false; + ServerCallContext context = createCallContext(rc); JSONRPCResponse nonStreamingResponse = null; Multi> streamingResponse = null; JSONRPCErrorResponse error = null; @@ -89,10 +99,10 @@ public void invokeJSONRPCHandler(@Body String body, RoutingContext rc) { if (isStreamingRequest(body)) { streaming = true; StreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, StreamingJSONRPCRequest.class); - streamingResponse = processStreamingRequest(request); + streamingResponse = processStreamingRequest(request, context); } else { NonStreamingJSONRPCRequest request = Utils.OBJECT_MAPPER.readValue(body, NonStreamingJSONRPCRequest.class); - nonStreamingResponse = processNonStreamingRequest(request); + nonStreamingResponse = processNonStreamingRequest(request, context); } } catch (JsonProcessingException e) { error = handleError(e); @@ -183,32 +193,34 @@ public void getAuthenticatedExtendedAgentCard(RoutingExchange re) { } } - private JSONRPCResponse processNonStreamingRequest(NonStreamingJSONRPCRequest request) { - if (request instanceof GetTaskRequest) { - return jsonRpcHandler.onGetTask((GetTaskRequest) request); - } else if (request instanceof CancelTaskRequest) { - return jsonRpcHandler.onCancelTask((CancelTaskRequest) request); - } else if (request instanceof SetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.setPushNotificationConfig((SetTaskPushNotificationConfigRequest) request); - } else if (request instanceof GetTaskPushNotificationConfigRequest) { - return jsonRpcHandler.getPushNotificationConfig((GetTaskPushNotificationConfigRequest) request); - } else if (request instanceof SendMessageRequest) { - return jsonRpcHandler.onMessageSend((SendMessageRequest) request); - } else if (request instanceof ListTaskPushNotificationConfigRequest) { - return jsonRpcHandler.listPushNotificationConfig((ListTaskPushNotificationConfigRequest) request); - } else if (request instanceof DeleteTaskPushNotificationConfigRequest) { - return jsonRpcHandler.deletePushNotificationConfig((DeleteTaskPushNotificationConfigRequest) request); + private JSONRPCResponse processNonStreamingRequest( + NonStreamingJSONRPCRequest request, ServerCallContext context) { + if (request instanceof GetTaskRequest req) { + return jsonRpcHandler.onGetTask(req, context); + } else if (request instanceof CancelTaskRequest req) { + return jsonRpcHandler.onCancelTask(req, context); + } else if (request instanceof SetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.setPushNotificationConfig(req, context); + } else if (request instanceof GetTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.getPushNotificationConfig(req, context); + } else if (request instanceof SendMessageRequest req) { + return jsonRpcHandler.onMessageSend(req, context); + } else if (request instanceof ListTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.listPushNotificationConfig(req, context); + } else if (request instanceof DeleteTaskPushNotificationConfigRequest req) { + return jsonRpcHandler.deletePushNotificationConfig(req, context); } else { return generateErrorResponse(request, new UnsupportedOperationError()); } } - private Multi> processStreamingRequest(JSONRPCRequest request) { + private Multi> processStreamingRequest( + JSONRPCRequest request, ServerCallContext context) { Flow.Publisher> publisher; - if (request instanceof SendStreamingMessageRequest) { - publisher = jsonRpcHandler.onMessageSendStream((SendStreamingMessageRequest) request); - } else if (request instanceof TaskResubscriptionRequest) { - publisher = jsonRpcHandler.onResubscribeToTask((TaskResubscriptionRequest) request); + if (request instanceof SendStreamingMessageRequest req) { + publisher = jsonRpcHandler.onMessageSendStream(req, context); + } else if (request instanceof TaskResubscriptionRequest req) { + publisher = jsonRpcHandler.onResubscribeToTask(req, context); } else { return Multi.createFrom().item(generateErrorResponse(request, new UnsupportedOperationError())); } @@ -234,6 +246,42 @@ static void setStreamingMultiSseSupportSubscribedRunnable(Runnable runnable) { streamingMultiSseSupportSubscribedRunnable = runnable; } + private ServerCallContext createCallContext(RoutingContext rc) { + + if (callContextFactory.isUnsatisfied()) { + User user; + if (rc.user() == null) { + user = UnauthenticatedUser.INSTANCE; + } else { + user = new User() { + @Override + public boolean isAuthenticated() { + return rc.userContext().authenticated(); + } + + @Override + public String getUsername() { + return rc.user().subject(); + } + }; + } + Map state = new HashMap<>(); + // TODO Python's impl has + // state['auth'] = request.auth + // in jsonrpc_app.py. Figure out what this maps to in what Vert.X gives us + + Map headers = new HashMap<>(); + Set headerNames = rc.request().headers().names(); + headerNames.forEach(name -> headers.put(name, rc.request().getHeader(name))); + state.put("headers", headers); + + return new ServerCallContext(user, state); + } else { + CallContextFactory builder = callContextFactory.get(); + return builder.build(rc); + } + } + // Port of import io.quarkus.vertx.web.runtime.MultiSseSupport, which is considered internal API private static class MultiSseSupport { diff --git a/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java new file mode 100644 index 000000000..d40bc65f0 --- /dev/null +++ b/reference-impl/src/main/java/io/a2a/server/apps/quarkus/CallContextFactory.java @@ -0,0 +1,8 @@ +package io.a2a.server.apps.quarkus; + +import io.a2a.server.ServerCallContext; +import io.vertx.ext.web.RoutingContext; + +public interface CallContextFactory { + ServerCallContext build(RoutingContext rc); +} diff --git a/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java index 70fb344df..558f01eda 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java +++ b/sdk-server-common/src/main/java/io/a2a/server/ServerCallContext.java @@ -1,5 +1,26 @@ package io.a2a.server; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import io.a2a.server.auth.User; + public class ServerCallContext { - // TODO port the fields + // TODO Not totally sure yet about these field types + private final Map modelConfig = new ConcurrentHashMap<>(); + private final Map state; + private final User user; + + public ServerCallContext(User user, Map state) { + this.user = user; + this.state = state; + } + + public Map getState() { + return state; + } + + public User getUser() { + return user; + } } diff --git a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java index bac7673a1..585b4fce4 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java +++ b/sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java @@ -22,13 +22,21 @@ public class RequestContext { private String contextId; private Task task; private List relatedTasks; - - public RequestContext(MessageSendParams params, String taskId, String contextId, Task task, List relatedTasks) throws InvalidParamsError { + private final ServerCallContext callContext; + + public RequestContext( + MessageSendParams params, + String taskId, + String contextId, + Task task, + List relatedTasks, + ServerCallContext callContext) throws InvalidParamsError { this.params = params; this.taskId = taskId; this.contextId = contextId; this.task = task; this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks; + this.callContext = callContext; // if the taskId and contextId were specified, they must match the params if (params != null) { @@ -73,6 +81,10 @@ public MessageSendConfiguration getConfiguration() { return params != null ? params.configuration() : null; } + public ServerCallContext getCallContext() { + return callContext; + } + public String getUserInput(String delimiter) { if (params == null) { return ""; @@ -187,7 +199,7 @@ public ServerCallContext getServerCallContext() { } public RequestContext build() { - return new RequestContext(params, taskId, contextId, task, relatedTasks); + return new RequestContext(params, taskId, contextId, task, relatedTasks, serverCallContext); } } diff --git a/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java index e4ec7a69c..9988ebbcf 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java +++ b/sdk-server-common/src/main/java/io/a2a/server/auth/UnauthenticatedUser.java @@ -1,6 +1,12 @@ package io.a2a.server.auth; public class UnauthenticatedUser implements User { + + public static UnauthenticatedUser INSTANCE = new UnauthenticatedUser(); + + private UnauthenticatedUser() { + } + @Override public boolean isAuthenticated() { return false; diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java index b4dbb2feb..e79a97201 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java @@ -16,6 +16,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; +import io.a2a.server.ServerCallContext; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; import io.a2a.server.agentexecution.SimpleRequestContextBuilder; @@ -87,7 +88,7 @@ public DefaultRequestHandler(AgentExecutor agentExecutor, TaskStore taskStore, } @Override - public Task onGetTask(TaskQueryParams params) throws JSONRPCError { + public Task onGetTask(TaskQueryParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onGetTask {}", params.id()); Task task = taskStore.get(params.id()); if (task == null) { @@ -114,7 +115,7 @@ public Task onGetTask(TaskQueryParams params) throws JSONRPCError { } @Override - public Task onCancelTask(TaskIdParams params) throws JSONRPCError { + public Task onCancelTask(TaskIdParams params, ServerCallContext context) throws JSONRPCError { Task task = taskStore.get(params.id()); if (task == null) { throw new TaskNotFoundError(); @@ -136,6 +137,7 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError { .setTaskId(task.getId()) .setContextId(task.getContextId()) .setTask(task) + .setServerCallContext(context) .build(), queue); @@ -152,9 +154,9 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError { } @Override - public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError { + public EventKind onMessageSend(MessageSendParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onMessageSend - task: {}; context {}", params.message().getTaskId(), params.message().getContextId()); - MessageSendSetup mss = initMessageSend(params); + MessageSendSetup mss = initMessageSend(params, context); String taskId = mss.requestContext.getTaskId(); LOGGER.debug("Request context taskId: {}", taskId); @@ -200,9 +202,10 @@ public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError { } @Override - public Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError { + public Flow.Publisher onMessageSendStream( + MessageSendParams params, ServerCallContext context) throws JSONRPCError { LOGGER.debug("onMessageSendStream - task: {}; context {}", params.message().getTaskId(), params.message().getContextId()); - MessageSendSetup mss = initMessageSend(params); + MessageSendSetup mss = initMessageSend(params, context); AtomicReference taskId = new AtomicReference<>(mss.requestContext.getTaskId()); EventQueue queue = queueManager.createOrTap(taskId.get()); @@ -260,7 +263,8 @@ public Flow.Publisher onMessageSendStream(MessageSendParams } @Override - public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError { + public TaskPushNotificationConfig onSetTaskPushNotificationConfig( + TaskPushNotificationConfig params, ServerCallContext context) throws JSONRPCError { if (pushConfigStore == null) { throw new UnsupportedOperationError(); } @@ -275,7 +279,8 @@ public TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotifi } @Override - public TaskPushNotificationConfig onGetTaskPushNotificationConfig(GetTaskPushNotificationConfigParams params) throws JSONRPCError { + public TaskPushNotificationConfig onGetTaskPushNotificationConfig( + GetTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError { if (pushConfigStore == null) { throw new UnsupportedOperationError(); } @@ -305,7 +310,8 @@ private PushNotificationConfig getPushNotificationConfig(List onResubscribeToTask(TaskIdParams params) throws JSONRPCError { + public Flow.Publisher onResubscribeToTask( + TaskIdParams params, ServerCallContext context) throws JSONRPCError { Task task = taskStore.get(params.id()); if (task == null) { throw new TaskNotFoundError(); @@ -325,7 +331,8 @@ public Flow.Publisher onResubscribeToTask(TaskIdParams param } @Override - public List onListTaskPushNotificationConfig(ListTaskPushNotificationConfigParams params) throws JSONRPCError { + public List onListTaskPushNotificationConfig( + ListTaskPushNotificationConfigParams params, ServerCallContext context) throws JSONRPCError { if (pushConfigStore == null) { throw new UnsupportedOperationError(); } @@ -347,7 +354,8 @@ public List onListTaskPushNotificationConfig(ListTas } @Override - public void onDeleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams params) { + public void onDeleteTaskPushNotificationConfig( + DeleteTaskPushNotificationConfigParams params, ServerCallContext context) { if (pushConfigStore == null) { throw new UnsupportedOperationError(); } @@ -398,7 +406,7 @@ private void cleanupProducer(String taskId) { }); } - private MessageSendSetup initMessageSend(MessageSendParams params) { + private MessageSendSetup initMessageSend(MessageSendParams params, ServerCallContext context) { TaskManager taskManager = new TaskManager( params.message().getTaskId(), params.message().getContextId(), @@ -421,6 +429,7 @@ private MessageSendSetup initMessageSend(MessageSendParams params) { .setTaskId(task == null ? null : task.getId()) .setContextId(params.message().getContextId()) .setTask(task) + .setServerCallContext(context) .build(); return new MessageSendSetup(taskManager, task, requestContext); } diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java index 6a394e023..fb120f981 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/JSONRPCHandler.java @@ -2,12 +2,14 @@ import static io.a2a.server.util.async.AsyncUtils.createTubeConfig; import jakarta.enterprise.context.ApplicationScoped; +import jakarta.enterprise.inject.Instance; import jakarta.inject.Inject; import java.util.List; import java.util.concurrent.Flow; import io.a2a.server.PublicAgentCard; +import io.a2a.server.ServerCallContext; import io.a2a.spec.AgentCard; import io.a2a.spec.CancelTaskRequest; import io.a2a.spec.CancelTaskResponse; @@ -52,9 +54,9 @@ public JSONRPCHandler(@PublicAgentCard AgentCard agentCard, RequestHandler reque this.requestHandler = requestHandler; } - public SendMessageResponse onMessageSend(SendMessageRequest request) { + public SendMessageResponse onMessageSend(SendMessageRequest request, ServerCallContext context) { try { - EventKind taskOrMessage = requestHandler.onMessageSend(request.getParams()); + EventKind taskOrMessage = requestHandler.onMessageSend(request.getParams(), context); return new SendMessageResponse(request.getId(), taskOrMessage); } catch (JSONRPCError e) { return new SendMessageResponse(request.getId(), e); @@ -64,7 +66,8 @@ public SendMessageResponse onMessageSend(SendMessageRequest request) { } - public Flow.Publisher onMessageSendStream(SendStreamingMessageRequest request) { + public Flow.Publisher onMessageSendStream( + SendStreamingMessageRequest request, ServerCallContext context) { if (!agentCard.capabilities().streaming()) { return ZeroPublisher.fromItems( new SendStreamingMessageResponse( @@ -73,7 +76,8 @@ public Flow.Publisher onMessageSendStream(SendStre } try { - Flow.Publisher publisher = requestHandler.onMessageSendStream(request.getParams()); + Flow.Publisher publisher = + requestHandler.onMessageSendStream(request.getParams(), context); // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload return convertToSendStreamingMessageResponse(request.getId(), publisher); @@ -84,9 +88,9 @@ public Flow.Publisher onMessageSendStream(SendStre } } - public CancelTaskResponse onCancelTask(CancelTaskRequest request) { + public CancelTaskResponse onCancelTask(CancelTaskRequest request, ServerCallContext context) { try { - Task task = requestHandler.onCancelTask(request.getParams()); + Task task = requestHandler.onCancelTask(request.getParams(), context); if (task != null) { return new CancelTaskResponse(request.getId(), task); } @@ -98,7 +102,8 @@ public CancelTaskResponse onCancelTask(CancelTaskRequest request) { } } - public Flow.Publisher onResubscribeToTask(TaskResubscriptionRequest request) { + public Flow.Publisher onResubscribeToTask( + TaskResubscriptionRequest request, ServerCallContext context) { if (!agentCard.capabilities().streaming()) { return ZeroPublisher.fromItems( new SendStreamingMessageResponse( @@ -107,7 +112,8 @@ public Flow.Publisher onResubscribeToTask(TaskResu } try { - Flow.Publisher publisher = requestHandler.onResubscribeToTask(request.getParams()); + Flow.Publisher publisher = + requestHandler.onResubscribeToTask(request.getParams(), context); // We can't use the convertingProcessor convenience method since that propagates any errors as an error handled // via Subscriber.onError() rather than as part of the SendStreamingResponse payload return convertToSendStreamingMessageResponse(request.getId(), publisher); @@ -118,13 +124,15 @@ public Flow.Publisher onResubscribeToTask(TaskResu } } - public GetTaskPushNotificationConfigResponse getPushNotificationConfig(GetTaskPushNotificationConfigRequest request) { + public GetTaskPushNotificationConfigResponse getPushNotificationConfig( + GetTaskPushNotificationConfigRequest request, ServerCallContext context) { if (!agentCard.capabilities().pushNotifications()) { return new GetTaskPushNotificationConfigResponse(request.getId(), new PushNotificationNotSupportedError()); } try { - TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(request.getParams()); + TaskPushNotificationConfig config = + requestHandler.onGetTaskPushNotificationConfig(request.getParams(), context); return new GetTaskPushNotificationConfigResponse(request.getId(), config); } catch (JSONRPCError e) { return new GetTaskPushNotificationConfigResponse(request.getId().toString(), e); @@ -133,13 +141,15 @@ public GetTaskPushNotificationConfigResponse getPushNotificationConfig(GetTaskPu } } - public SetTaskPushNotificationConfigResponse setPushNotificationConfig(SetTaskPushNotificationConfigRequest request) { + public SetTaskPushNotificationConfigResponse setPushNotificationConfig( + SetTaskPushNotificationConfigRequest request, ServerCallContext context) { if (!agentCard.capabilities().pushNotifications()) { return new SetTaskPushNotificationConfigResponse(request.getId(), new PushNotificationNotSupportedError()); } try { - TaskPushNotificationConfig config = requestHandler.onSetTaskPushNotificationConfig(request.getParams()); + TaskPushNotificationConfig config = + requestHandler.onSetTaskPushNotificationConfig(request.getParams(), context); return new SetTaskPushNotificationConfigResponse(request.getId().toString(), config); } catch (JSONRPCError e) { return new SetTaskPushNotificationConfigResponse(request.getId(), e); @@ -148,9 +158,9 @@ public SetTaskPushNotificationConfigResponse setPushNotificationConfig(SetTaskPu } } - public GetTaskResponse onGetTask(GetTaskRequest request) { + public GetTaskResponse onGetTask(GetTaskRequest request, ServerCallContext context) { try { - Task task = requestHandler.onGetTask(request.getParams()); + Task task = requestHandler.onGetTask(request.getParams(), context); return new GetTaskResponse(request.getId(), task); } catch (JSONRPCError e) { return new GetTaskResponse(request.getId(), e); @@ -159,13 +169,15 @@ public GetTaskResponse onGetTask(GetTaskRequest request) { } } - public ListTaskPushNotificationConfigResponse listPushNotificationConfig(ListTaskPushNotificationConfigRequest request) { + public ListTaskPushNotificationConfigResponse listPushNotificationConfig( + ListTaskPushNotificationConfigRequest request, ServerCallContext context) { if ( !agentCard.capabilities().pushNotifications()) { return new ListTaskPushNotificationConfigResponse(request.getId(), new PushNotificationNotSupportedError()); } try { - List pushNotificationConfigList = requestHandler.onListTaskPushNotificationConfig(request.getParams()); + List pushNotificationConfigList = + requestHandler.onListTaskPushNotificationConfig(request.getParams(), context); return new ListTaskPushNotificationConfigResponse(request.getId(), pushNotificationConfigList); } catch (JSONRPCError e) { return new ListTaskPushNotificationConfigResponse(request.getId(), e); @@ -174,13 +186,14 @@ public ListTaskPushNotificationConfigResponse listPushNotificationConfig(ListTas } } - public DeleteTaskPushNotificationConfigResponse deletePushNotificationConfig(DeleteTaskPushNotificationConfigRequest request) { + public DeleteTaskPushNotificationConfigResponse deletePushNotificationConfig( + DeleteTaskPushNotificationConfigRequest request, ServerCallContext context) { if ( !agentCard.capabilities().pushNotifications()) { return new DeleteTaskPushNotificationConfigResponse(request.getId(), new PushNotificationNotSupportedError()); } try { - requestHandler.onDeleteTaskPushNotificationConfig(request.getParams()); + requestHandler.onDeleteTaskPushNotificationConfig(request.getParams(), context); return new DeleteTaskPushNotificationConfigResponse(request.getId()); } catch (JSONRPCError e) { return new DeleteTaskPushNotificationConfigResponse(request.getId(), e); diff --git a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java index 66d07b55b..e45bc3c62 100644 --- a/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java +++ b/sdk-server-common/src/main/java/io/a2a/server/requesthandlers/RequestHandler.java @@ -3,6 +3,7 @@ import java.util.List; import java.util.concurrent.Flow; +import io.a2a.server.ServerCallContext; import io.a2a.spec.DeleteTaskPushNotificationConfigParams; import io.a2a.spec.EventKind; import io.a2a.spec.GetTaskPushNotificationConfigParams; @@ -16,21 +17,39 @@ import io.a2a.spec.TaskQueryParams; public interface RequestHandler { - Task onGetTask(TaskQueryParams params) throws JSONRPCError; + Task onGetTask( + TaskQueryParams params, + ServerCallContext context) throws JSONRPCError; - Task onCancelTask(TaskIdParams params) throws JSONRPCError; + Task onCancelTask( + TaskIdParams params, + ServerCallContext context) throws JSONRPCError; - EventKind onMessageSend(MessageSendParams params) throws JSONRPCError; + EventKind onMessageSend( + MessageSendParams params, + ServerCallContext context) throws JSONRPCError; - Flow.Publisher onMessageSendStream(MessageSendParams params) throws JSONRPCError; + Flow.Publisher onMessageSendStream( + MessageSendParams params, + ServerCallContext context) throws JSONRPCError; - TaskPushNotificationConfig onSetTaskPushNotificationConfig(TaskPushNotificationConfig params) throws JSONRPCError; + TaskPushNotificationConfig onSetTaskPushNotificationConfig( + TaskPushNotificationConfig params, + ServerCallContext context) throws JSONRPCError; - TaskPushNotificationConfig onGetTaskPushNotificationConfig(GetTaskPushNotificationConfigParams params) throws JSONRPCError; + TaskPushNotificationConfig onGetTaskPushNotificationConfig( + GetTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; - Flow.Publisher onResubscribeToTask(TaskIdParams params) throws JSONRPCError; + Flow.Publisher onResubscribeToTask( + TaskIdParams params, + ServerCallContext context) throws JSONRPCError; - List onListTaskPushNotificationConfig(ListTaskPushNotificationConfigParams params) throws JSONRPCError; + List onListTaskPushNotificationConfig( + ListTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; - void onDeleteTaskPushNotificationConfig(DeleteTaskPushNotificationConfigParams params) throws JSONRPCError; + void onDeleteTaskPushNotificationConfig( + DeleteTaskPushNotificationConfigParams params, + ServerCallContext context) throws JSONRPCError; } diff --git a/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java index c9bb79061..081cc873a 100644 --- a/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java +++ b/sdk-server-common/src/test/java/io/a2a/server/agentexecution/RequestContextTest.java @@ -25,7 +25,7 @@ public class RequestContextTest { @Test public void testInitWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertNull(context.getMessage()); assertNull(context.getTaskId()); assertNull(context.getContextId()); @@ -46,7 +46,7 @@ public void testInitWithParamsNoIds() { .thenReturn(taskId) .thenReturn(contextId); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(mockParams.message(), context.getMessage()); assertEquals(taskId.toString(), context.getTaskId()); @@ -62,7 +62,7 @@ public void testInitWithTaskId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, taskId, null, null, null); + RequestContext context = new RequestContext(mockParams, taskId, null, null, null, null); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().getTaskId()); @@ -73,7 +73,7 @@ public void testInitWithContextId() { String contextId = "context-456"; var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(contextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, contextId, null, null); + RequestContext context = new RequestContext(mockParams, null, contextId, null, null, null); assertEquals(contextId, context.getContextId()); assertEquals(contextId, mockParams.message().getContextId()); @@ -85,7 +85,7 @@ public void testInitWithBothIds() { String contextId = "context-456"; var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).contextId(contextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null); + RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null, null); assertEquals(taskId, context.getTaskId()); assertEquals(taskId, mockParams.message().getTaskId()); @@ -99,14 +99,14 @@ public void testInitWithTask() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, mockTask, null); + RequestContext context = new RequestContext(mockParams, null, null, mockTask, null, null); assertEquals(mockTask, context.getTask()); } @Test public void testGetUserInputNoParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertEquals("", context.getUserInput(null)); } @@ -114,7 +114,7 @@ public void testGetUserInputNoParams() { public void testAttachRelatedTask() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertEquals(0, context.getRelatedTasks().size()); context.attachRelatedTask(mockTask); @@ -133,7 +133,7 @@ public void testCheckOrGenerateTaskIdWithExistingTaskId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(existingId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingId, context.getTaskId()); assertEquals(existingId, mockParams.message().getTaskId()); @@ -146,7 +146,7 @@ public void testCheckOrGenerateContextIdWithExistingContextId() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(existingId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingId, context.getContextId()); assertEquals(existingId, mockParams.message().getContextId()); @@ -159,7 +159,7 @@ public void testInitRaisesErrorOnTaskIdMismatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, "wrong-task-id", null, mockTask, null)); + new RequestContext(mockParams, "wrong-task-id", null, mockTask, null, null)); assertTrue(error.getMessage().contains("bad task id")); } @@ -171,7 +171,7 @@ public void testInitRaisesErrorOnContextIdMismatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); InvalidParamsError error = assertThrows(InvalidParamsError.class, () -> - new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null)); + new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null, null)); assertTrue(error.getMessage().contains("bad context id")); } @@ -184,7 +184,7 @@ public void testWithRelatedTasksProvided() { relatedTasks.add(mockTask); relatedTasks.add(mock(Task.class)); - RequestContext context = new RequestContext(null, null, null, null, relatedTasks); + RequestContext context = new RequestContext(null, null, null, null, relatedTasks, null); assertEquals(relatedTasks, context.getRelatedTasks()); assertEquals(2, context.getRelatedTasks().size()); @@ -192,7 +192,7 @@ public void testWithRelatedTasksProvided() { @Test public void testMessagePropertyWithoutParams() { - RequestContext context = new RequestContext(null, null, null, null, null); + RequestContext context = new RequestContext(null, null, null, null, null, null); assertNull(context.getMessage()); } @@ -201,7 +201,7 @@ public void testMessagePropertyWithParams() { var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(mockParams.message(), context.getMessage()); } @@ -214,7 +214,7 @@ public void testInitWithExistingIdsInMessage() { .taskId(existingTaskId).contextId(existingContextId).build(); var mockParams = new MessageSendParams.Builder().message(mockMessage).build(); - RequestContext context = new RequestContext(mockParams, null, null, null, null); + RequestContext context = new RequestContext(mockParams, null, null, null, null, null); assertEquals(existingTaskId, context.getTaskId()); assertEquals(existingContextId, context.getContextId()); @@ -227,7 +227,7 @@ public void testInitWithTaskIdAndExistingTaskIdMatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null); + RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null, null); assertEquals(mockTask.getId(), context.getTaskId()); assertEquals(mockTask, context.getTask()); @@ -240,7 +240,7 @@ public void testInitWithContextIdAndExistingContextIdMatch() { var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build(); - RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null); + RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null, null); assertEquals(mockTask.getContextId(), context.getContextId()); assertEquals(mockTask, context.getTask()); diff --git a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java index 80eaacf9e..f67fa87cf 100644 --- a/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java +++ b/sdk-server-common/src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java @@ -7,12 +7,12 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import jakarta.enterprise.context.Dependent; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -22,10 +22,14 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import jakarta.enterprise.context.Dependent; + import io.a2a.http.A2AHttpClient; import io.a2a.http.A2AHttpResponse; +import io.a2a.server.ServerCallContext; import io.a2a.server.agentexecution.AgentExecutor; import io.a2a.server.agentexecution.RequestContext; +import io.a2a.server.auth.UnauthenticatedUser; import io.a2a.server.events.EventConsumer; import io.a2a.server.events.EventQueue; import io.a2a.server.events.InMemoryQueueManager; @@ -83,7 +87,6 @@ import io.a2a.util.Utils; import io.quarkus.arc.profile.IfBuildProfile; import mutiny.zero.ZeroPublisher; - import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; @@ -117,6 +120,8 @@ public class JSONRPCHandlerTest { private final Executor internalExecutor = Executors.newCachedThreadPool(); + private final ServerCallContext callContext = new ServerCallContext(UnauthenticatedUser.INSTANCE, Map.of("foo", "bar")); + @BeforeEach public void init() { @@ -156,7 +161,7 @@ public void testOnGetTaskSuccess() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); taskStore.save(MINIMAL_TASK); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); - GetTaskResponse response = handler.onGetTask(request); + GetTaskResponse response = handler.onGetTask(request, callContext); assertEquals(request.getId(), response.getId()); assertSame(MINIMAL_TASK, response.getResult()); assertNull(response.getError()); @@ -166,7 +171,7 @@ public void testOnGetTaskSuccess() throws Exception { public void testOnGetTaskNotFound() throws Exception { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); GetTaskRequest request = new GetTaskRequest("1", new TaskQueryParams(MINIMAL_TASK.getId())); - GetTaskResponse response = handler.onGetTask(request); + GetTaskResponse response = handler.onGetTask(request, callContext); assertEquals(request.getId(), response.getId()); assertInstanceOf(TaskNotFoundError.class, response.getError()); assertNull(response.getResult()); @@ -187,7 +192,7 @@ public void testOnCancelTaskSuccess() throws Exception { }; CancelTaskRequest request = new CancelTaskRequest("111", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertNull(response.getError()); assertEquals(request.getId(), response.getId()); @@ -207,7 +212,7 @@ public void testOnCancelTaskNotSupported() { }; CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertEquals(request.getId(), response.getId()); assertNull(response.getResult()); assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -217,7 +222,7 @@ public void testOnCancelTaskNotSupported() { public void testOnCancelTaskNotFound() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); CancelTaskRequest request = new CancelTaskRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - CancelTaskResponse response = handler.onCancelTask(request); + CancelTaskResponse response = handler.onCancelTask(request, callContext); assertEquals(request.getId(), response.getId()); assertNull(response.getResult()); assertInstanceOf(TaskNotFoundError.class, response.getError()); @@ -234,7 +239,7 @@ public void testOnMessageNewMessageSuccess() { .contextId(MINIMAL_TASK.getContextId()) .build(); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); // The Python implementation returns a Task here, but then again they are using hardcoded mocks and // bypassing the whole EventQueue. @@ -259,7 +264,7 @@ public void testOnMessageNewMessageSuccessMocks() { try (MockedConstruction mocked = Mockito.mockConstruction( EventConsumer.class, (mock, context) -> {Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertNull(response.getError()); assertSame(MINIMAL_TASK, response.getResult()); @@ -277,7 +282,7 @@ public void testOnMessageNewMessageWithExistingTaskSuccess() { .contextId(MINIMAL_TASK.getContextId()) .build(); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertNull(response.getError()); // The Python implementation returns a Task here, but then again they are using hardcoded mocks and // bypassing the whole EventQueue. @@ -303,7 +308,7 @@ public void testOnMessageNewMessageWithExistingTaskSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromItems(MINIMAL_TASK)).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertNull(response.getError()); assertSame(MINIMAL_TASK, response.getResult()); @@ -324,7 +329,7 @@ public void testOnMessageError() { .build(); SendMessageRequest request = new SendMessageRequest( "1", new MessageSendParams(message, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(UnsupportedOperationError.class, response.getError()); assertNull(response.getResult()); } @@ -343,7 +348,7 @@ public void testOnMessageErrorMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromItems(new UnsupportedOperationError())).when(mock).consumeAll();})){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -364,7 +369,7 @@ public void testOnMessageStreamNewMessageSuccess() { SendStreamingMessageRequest request = new SendStreamingMessageRequest( "1", new MessageSendParams(message, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); CountDownLatch latch = new CountDownLatch(1); @@ -438,7 +443,7 @@ public void testOnMessageStreamNewMessageSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onMessageSendStream(request); + response = handler.onMessageSendStream(request, callContext); } List results = new ArrayList<>(); @@ -492,7 +497,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccess() throws Exception SendStreamingMessageRequest request = new SendStreamingMessageRequest( "1", new MessageSendParams(message, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); // This Publisher never completes so we subscribe in a new thread. // I _think_ that is as expected, and testOnMessageStreamNewMessageSendPushNotificationSuccess seems @@ -583,7 +588,7 @@ public void testOnMessageStreamNewMessageExistingTaskSuccessMocks() { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onMessageSendStream(request); + response = handler.onMessageSendStream(request, callContext); } List results = new ArrayList<>(); @@ -630,7 +635,7 @@ public void testSetPushNotificationConfigSuccess() { new TaskPushNotificationConfig( MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); assertSame(taskPushConfig, response.getResult()); } @@ -648,11 +653,11 @@ public void testGetPushNotificationConfigSuccess() { MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); GetTaskPushNotificationConfigRequest getRequest = new GetTaskPushNotificationConfigRequest("111", new GetTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - GetTaskPushNotificationConfigResponse getResponse = handler.getPushNotificationConfig(getRequest); + GetTaskPushNotificationConfigResponse getResponse = handler.getPushNotificationConfig(getRequest, callContext); TaskPushNotificationConfig expectedConfig = new TaskPushNotificationConfig(MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().id(MINIMAL_TASK.getId()).url("http://example.com").build()); @@ -693,14 +698,14 @@ public void testOnMessageStreamNewMessageSendPushNotificationSuccess() throws Ex MINIMAL_TASK.getId(), new PushNotificationConfig.Builder().url("http://example.com").build()); SetTaskPushNotificationConfigRequest stpnRequest = new SetTaskPushNotificationConfigRequest("1", config); - SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest); + SetTaskPushNotificationConfigResponse stpnResponse = handler.setPushNotificationConfig(stpnRequest, callContext); assertNull(stpnResponse.getError()); Message msg = new Message.Builder(MESSAGE) .taskId(MINIMAL_TASK.getId()) .build(); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(msg, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); final List results = Collections.synchronizedList(new ArrayList<>()); final AtomicReference subscriptionRef = new AtomicReference<>(); @@ -777,7 +782,7 @@ public void testOnResubscribeExistingTaskSuccess() { }; TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); // We need to send some events in order for those to end up in the queue Message message = new Message.Builder() @@ -787,7 +792,9 @@ public void testOnResubscribeExistingTaskSuccess() { .parts(new TextPart("text")) .build(); SendMessageResponse smr = - handler.onMessageSend(new SendMessageRequest("1", new MessageSendParams(message, null, null))); + handler.onMessageSend( + new SendMessageRequest("1", new MessageSendParams(message, null, null)), + callContext); assertNull(smr.getError()); @@ -853,7 +860,7 @@ public void testOnResubscribeExistingTaskSuccessMocks() throws Exception { EventConsumer.class, (mock, context) -> { Mockito.doReturn(ZeroPublisher.fromIterable(events)).when(mock).consumeAll();})){ - response = handler.onResubscribeToTask(request); + response = handler.onResubscribeToTask(request, callContext); } List results = new ArrayList<>(); @@ -898,7 +905,7 @@ public void testOnResubscribeNoExistingTaskError() { JSONRPCHandler handler = new JSONRPCHandler(CARD, requestHandler); TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -946,7 +953,7 @@ public void testStreamingNotSupportedError() { .message(MESSAGE) .build()) .build(); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -992,7 +999,7 @@ public void testStreamingNotSupportedErrorOnResubscribeToTask() { JSONRPCHandler handler = new JSONRPCHandler(card, requestHandler); TaskResubscriptionRequest request = new TaskResubscriptionRequest("1", new TaskIdParams(MINIMAL_TASK.getId())); - Flow.Publisher response = handler.onResubscribeToTask(request); + Flow.Publisher response = handler.onResubscribeToTask(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -1048,7 +1055,7 @@ public void testPushNotificationsNotSupportedError() { SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() .params(config) .build(); - SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); assertInstanceOf(PushNotificationNotSupportedError.class, response.getError()); } @@ -1064,7 +1071,7 @@ public void testOnGetPushNotificationNoPushNotifierConfig() { GetTaskPushNotificationConfigRequest request = new GetTaskPushNotificationConfigRequest("id", new GetTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - GetTaskPushNotificationConfigResponse response = handler.getPushNotificationConfig(request); + GetTaskPushNotificationConfigResponse response = handler.getPushNotificationConfig(request, callContext); assertNotNull(response.getError()); assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -1091,7 +1098,7 @@ public void testOnSetPushNotificationNoPushNotifierConfig() { SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest.Builder() .params(config) .build(); - SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request); + SetTaskPushNotificationConfigResponse response = handler.setPushNotificationConfig(request, callContext); assertInstanceOf(UnsupportedOperationError.class, response.getError()); assertEquals("This operation is not supported", response.getError().getMessage()); @@ -1100,12 +1107,13 @@ public void testOnSetPushNotificationNoPushNotifierConfig() { @Test public void testOnMessageSendInternalError() { DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); - Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSend(Mockito.any(MessageSendParams.class)); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked) + .onMessageSend(Mockito.any(MessageSendParams.class), Mockito.any(ServerCallContext.class)); JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(InternalError.class, response.getError()); } @@ -1113,12 +1121,13 @@ public void testOnMessageSendInternalError() { @Test public void testOnMessageStreamInternalError() { DefaultRequestHandler mocked = Mockito.mock(DefaultRequestHandler.class); - Mockito.doThrow(new InternalError("Internal Error")).when(mocked).onMessageSendStream(Mockito.any(MessageSendParams.class)); + Mockito.doThrow(new InternalError("Internal Error")).when(mocked) + .onMessageSendStream(Mockito.any(MessageSendParams.class), Mockito.any(ServerCallContext.class)); JSONRPCHandler handler = new JSONRPCHandler(CARD, mocked); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); @@ -1184,7 +1193,7 @@ public void testOnMessageSendErrorHandling() { Mockito.doThrow( new UnsupportedOperationError()) .when(mock).consumeAndBreakOnInterrupt(Mockito.any(EventConsumer.class)))){ - response = handler.onMessageSend(request); + response = handler.onMessageSend(request, callContext); } assertInstanceOf(UnsupportedOperationError.class, response.getError()); @@ -1201,7 +1210,7 @@ public void testOnMessageSendTaskIdMismatch() { }); SendMessageRequest request = new SendMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - SendMessageResponse response = handler.onMessageSend(request); + SendMessageResponse response = handler.onMessageSend(request, callContext); assertInstanceOf(InternalError.class, response.getError()); } @@ -1216,7 +1225,7 @@ public void testOnMessageStreamTaskIdMismatch() { }); SendStreamingMessageRequest request = new SendStreamingMessageRequest("1", new MessageSendParams(MESSAGE, null, null)); - Flow.Publisher response = handler.onMessageSendStream(request); + Flow.Publisher response = handler.onMessageSendStream(request, callContext); List results = new ArrayList<>(); AtomicReference error = new AtomicReference<>(); @@ -1267,11 +1276,11 @@ public void testListPushNotificationConfig() { .id(MINIMAL_TASK.getId()) .build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); ListTaskPushNotificationConfigRequest listRequest = new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest); + ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest, callContext); assertEquals("111", listResponse.getId()); assertEquals(1, listResponse.getResult().size()); @@ -1294,11 +1303,12 @@ public void testListPushNotificationConfigNotSupported() { .id(MINIMAL_TASK.getId()) .build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); ListTaskPushNotificationConfigRequest listRequest = new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); assertEquals("111", listResponse.getId()); assertNull(listResponse.getResult()); @@ -1317,7 +1327,8 @@ public void testListPushNotificationConfigNoPushConfigStore() { ListTaskPushNotificationConfigRequest listRequest = new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); assertEquals("111", listResponse.getId()); assertNull(listResponse.getResult()); @@ -1333,7 +1344,8 @@ public void testListPushNotificationConfigTaskNotFound() { ListTaskPushNotificationConfigRequest listRequest = new ListTaskPushNotificationConfigRequest("111", new ListTaskPushNotificationConfigParams(MINIMAL_TASK.getId())); - ListTaskPushNotificationConfigResponse listResponse = handler.listPushNotificationConfig(listRequest); + ListTaskPushNotificationConfigResponse listResponse = + handler.listPushNotificationConfig(listRequest, callContext); assertEquals("111", listResponse.getId()); assertNull(listResponse.getResult()); @@ -1355,11 +1367,12 @@ public void testDeletePushNotificationConfig() { .id(MINIMAL_TASK.getId()) .build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); DeleteTaskPushNotificationConfigRequest deleteRequest = new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); - DeleteTaskPushNotificationConfigResponse deleteResponse = handler.deletePushNotificationConfig(deleteRequest); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); assertEquals("111", deleteResponse.getId()); assertNull(deleteResponse.getError()); @@ -1382,11 +1395,12 @@ public void testDeletePushNotificationConfigNotSupported() { .id(MINIMAL_TASK.getId()) .build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); DeleteTaskPushNotificationConfigRequest deleteRequest = new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); - DeleteTaskPushNotificationConfigResponse deleteResponse = handler.deletePushNotificationConfig(deleteRequest); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); assertEquals("111", deleteResponse.getId()); assertNull(deleteResponse.getResult()); @@ -1410,11 +1424,12 @@ public void testDeletePushNotificationConfigNoPushConfigStore() { .id(MINIMAL_TASK.getId()) .build()); SetTaskPushNotificationConfigRequest request = new SetTaskPushNotificationConfigRequest("1", taskPushConfig); - handler.setPushNotificationConfig(request); + handler.setPushNotificationConfig(request, callContext); DeleteTaskPushNotificationConfigRequest deleteRequest = new DeleteTaskPushNotificationConfigRequest("111", new DeleteTaskPushNotificationConfigParams(MINIMAL_TASK.getId(), MINIMAL_TASK.getId())); - DeleteTaskPushNotificationConfigResponse deleteResponse = handler.deletePushNotificationConfig(deleteRequest); + DeleteTaskPushNotificationConfigResponse deleteResponse = + handler.deletePushNotificationConfig(deleteRequest, callContext); assertEquals("111", deleteResponse.getId()); assertNull(deleteResponse.getResult());