Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,18 +1,43 @@
package io.a2a.server.grpc.quarkus;

import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import io.a2a.server.PublicAgentCard;
import io.a2a.grpc.handler.GrpcHandler;
import io.a2a.server.PublicAgentCard;
import io.a2a.server.requesthandlers.CallContextFactory;
import io.a2a.server.requesthandlers.RequestHandler;
import io.a2a.spec.AgentCard;
import io.quarkus.grpc.GrpcService;

@GrpcService
public class QuarkusGrpcHandler extends GrpcHandler {

private final AgentCard agentCard;
private final RequestHandler requestHandler;
private final Instance<CallContextFactory> callContextFactoryInstance;

@Inject
public QuarkusGrpcHandler(@PublicAgentCard AgentCard agentCard, RequestHandler requestHandler) {
super(agentCard, requestHandler);
public QuarkusGrpcHandler(@PublicAgentCard AgentCard agentCard,
RequestHandler requestHandler,
Instance<CallContextFactory> callContextFactoryInstance) {
this.agentCard = agentCard;
this.requestHandler = requestHandler;
this.callContextFactoryInstance = callContextFactoryInstance;
}

@Override
protected RequestHandler getRequestHandler() {
return requestHandler;
}

@Override
protected AgentCard getAgentCard() {
return agentCard;
}

@Override
protected CallContextFactory getCallContextFactory() {
return callContextFactoryInstance.isUnsatisfied() ? null : callContextFactoryInstance.get();
}
}
5 changes: 1 addition & 4 deletions reference/grpc/src/test/resources/application.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
quarkus.grpc.clients.a2a-service.host=localhost
quarkus.grpc.clients.a2a-service.port=9001

# The GrpcHandler @ApplicationScoped annotation is not compatible with Quarkus
quarkus.arc.exclude-types=io.a2a.grpc.handler.GrpcHandler
quarkus.grpc.clients.a2a-service.port=9001
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@ public class AgentCardProducer {
@Produces
@PublicAgentCard
public AgentCard agentCard() {
String port = System.getProperty("test.agent.card.port", "8081");
return new AgentCard.Builder()
.name("test-card")
.description("A test agent card")
.url("http://localhost:8081")
.url("http://localhost:" + port)
.version("1.0")
.documentationUrl("http://example.com/docs")
.capabilities(new AgentCapabilities.Builder()
Expand Down
67 changes: 27 additions & 40 deletions transport/grpc/src/main/java/io/a2a/grpc/handler/GrpcHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@
import static io.a2a.grpc.utils.ProtoUtils.FromProto;
import static io.a2a.grpc.utils.ProtoUtils.ToProto;

import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.inject.Instance;
import jakarta.inject.Inject;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Flow;

import com.google.protobuf.Empty;
import io.a2a.grpc.A2AServiceGrpc;
import io.a2a.grpc.StreamResponse;
import io.a2a.server.PublicAgentCard;
import io.a2a.server.ServerCallContext;
import io.a2a.server.auth.UnauthenticatedUser;
import io.a2a.server.auth.User;
Expand Down Expand Up @@ -46,29 +43,14 @@
import io.grpc.Status;
import io.grpc.stub.StreamObserver;

import java.util.HashMap;
import java.util.Map;

@ApplicationScoped
public class GrpcHandler extends A2AServiceGrpc.A2AServiceImplBase {

private AgentCard agentCard;
private RequestHandler requestHandler;
public abstract class GrpcHandler extends A2AServiceGrpc.A2AServiceImplBase {

// Hook so testing can wait until streaming subscriptions are established.
// Without this we get intermittent failures
private static volatile Runnable streamingSubscribedRunnable;

@Inject
Instance<CallContextFactory> callContextFactory;

protected GrpcHandler() {
}

@Inject
public GrpcHandler(@PublicAgentCard AgentCard agentCard, RequestHandler requestHandler) {
this.agentCard = agentCard;
this.requestHandler = requestHandler;
public GrpcHandler() {
}

@Override
Expand All @@ -77,7 +59,7 @@ public void sendMessage(io.a2a.grpc.SendMessageRequest request,
try {
ServerCallContext context = createCallContext(responseObserver);
MessageSendParams params = FromProto.messageSendParams(request);
EventKind taskOrMessage = requestHandler.onMessageSend(params, context);
EventKind taskOrMessage = getRequestHandler().onMessageSend(params, context);
io.a2a.grpc.SendMessageResponse response = ToProto.taskOrMessage(taskOrMessage);
responseObserver.onNext(response);
responseObserver.onCompleted();
Expand All @@ -94,7 +76,7 @@ public void getTask(io.a2a.grpc.GetTaskRequest request,
try {
ServerCallContext context = createCallContext(responseObserver);
TaskQueryParams params = FromProto.taskQueryParams(request);
Task task = requestHandler.onGetTask(params, context);
Task task = getRequestHandler().onGetTask(params, context);
if (task != null) {
responseObserver.onNext(ToProto.task(task));
responseObserver.onCompleted();
Expand All @@ -114,7 +96,7 @@ public void cancelTask(io.a2a.grpc.CancelTaskRequest request,
try {
ServerCallContext context = createCallContext(responseObserver);
TaskIdParams params = FromProto.taskIdParams(request);
Task task = requestHandler.onCancelTask(params, context);
Task task = getRequestHandler().onCancelTask(params, context);
if (task != null) {
responseObserver.onNext(ToProto.task(task));
responseObserver.onCompleted();
Expand All @@ -131,15 +113,15 @@ public void cancelTask(io.a2a.grpc.CancelTaskRequest request,
@Override
public void createTaskPushNotificationConfig(io.a2a.grpc.CreateTaskPushNotificationConfigRequest request,
StreamObserver<io.a2a.grpc.TaskPushNotificationConfig> responseObserver) {
if (! agentCard.capabilities().pushNotifications()) {
if (!getAgentCard().capabilities().pushNotifications()) {
handleError(responseObserver, new PushNotificationNotSupportedError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
TaskPushNotificationConfig config = FromProto.taskPushNotificationConfig(request);
TaskPushNotificationConfig responseConfig = requestHandler.onSetTaskPushNotificationConfig(config, context);
TaskPushNotificationConfig responseConfig = getRequestHandler().onSetTaskPushNotificationConfig(config, context);
responseObserver.onNext(ToProto.taskPushNotificationConfig(responseConfig));
responseObserver.onCompleted();
} catch (JSONRPCError e) {
Expand All @@ -152,15 +134,15 @@ public void createTaskPushNotificationConfig(io.a2a.grpc.CreateTaskPushNotificat
@Override
public void getTaskPushNotificationConfig(io.a2a.grpc.GetTaskPushNotificationConfigRequest request,
StreamObserver<io.a2a.grpc.TaskPushNotificationConfig> responseObserver) {
if (! agentCard.capabilities().pushNotifications()) {
if (!getAgentCard().capabilities().pushNotifications()) {
handleError(responseObserver, new PushNotificationNotSupportedError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
GetTaskPushNotificationConfigParams params = FromProto.getTaskPushNotificationConfigParams(request);
TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(params, context);
TaskPushNotificationConfig config = getRequestHandler().onGetTaskPushNotificationConfig(params, context);
responseObserver.onNext(ToProto.taskPushNotificationConfig(config));
responseObserver.onCompleted();
} catch (JSONRPCError e) {
Expand All @@ -173,15 +155,15 @@ public void getTaskPushNotificationConfig(io.a2a.grpc.GetTaskPushNotificationCon
@Override
public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationConfigRequest request,
StreamObserver<io.a2a.grpc.ListTaskPushNotificationConfigResponse> responseObserver) {
if (! agentCard.capabilities().pushNotifications()) {
if (!getAgentCard().capabilities().pushNotifications()) {
handleError(responseObserver, new PushNotificationNotSupportedError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
ListTaskPushNotificationConfigParams params = FromProto.listTaskPushNotificationConfigParams(request);
List<TaskPushNotificationConfig> configList = requestHandler.onListTaskPushNotificationConfig(params, context);
List<TaskPushNotificationConfig> configList = getRequestHandler().onListTaskPushNotificationConfig(params, context);
io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder =
io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder();
for (TaskPushNotificationConfig config : configList) {
Expand All @@ -199,15 +181,15 @@ public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationC
@Override
public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request,
StreamObserver<io.a2a.grpc.StreamResponse> responseObserver) {
if (! agentCard.capabilities().streaming()) {
if (!getAgentCard().capabilities().streaming()) {
handleError(responseObserver, new InvalidRequestError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
MessageSendParams params = FromProto.messageSendParams(request);
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onMessageSendStream(params, context);
Flow.Publisher<StreamingEventKind> publisher = getRequestHandler().onMessageSendStream(params, context);
convertToStreamResponse(publisher, responseObserver);
} catch (JSONRPCError e) {
handleError(responseObserver, e);
Expand All @@ -219,15 +201,15 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request,
@Override
public void taskSubscription(io.a2a.grpc.TaskSubscriptionRequest request,
StreamObserver<io.a2a.grpc.StreamResponse> responseObserver) {
if (! agentCard.capabilities().streaming()) {
if (!getAgentCard().capabilities().streaming()) {
handleError(responseObserver, new InvalidRequestError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
TaskIdParams params = FromProto.taskIdParams(request);
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onResubscribeToTask(params, context);
Flow.Publisher<StreamingEventKind> publisher = getRequestHandler().onResubscribeToTask(params, context);
convertToStreamResponse(publisher, responseObserver);
} catch (JSONRPCError e) {
handleError(responseObserver, e);
Expand Down Expand Up @@ -287,7 +269,7 @@ public void onComplete() {
public void getAgentCard(io.a2a.grpc.GetAgentCardRequest request,
StreamObserver<io.a2a.grpc.AgentCard> responseObserver) {
try {
responseObserver.onNext(ToProto.agentCard(agentCard));
responseObserver.onNext(ToProto.agentCard(getAgentCard()));
responseObserver.onCompleted();
} catch (Throwable t) {
handleInternalError(responseObserver, t);
Expand All @@ -297,15 +279,15 @@ public void getAgentCard(io.a2a.grpc.GetAgentCardRequest request,
@Override
public void deleteTaskPushNotificationConfig(io.a2a.grpc.DeleteTaskPushNotificationConfigRequest request,
StreamObserver<Empty> responseObserver) {
if (! agentCard.capabilities().pushNotifications()) {
if (!getAgentCard().capabilities().pushNotifications()) {
handleError(responseObserver, new PushNotificationNotSupportedError());
return;
}

try {
ServerCallContext context = createCallContext(responseObserver);
DeleteTaskPushNotificationConfigParams params = FromProto.deleteTaskPushNotificationConfigParams(request);
requestHandler.onDeleteTaskPushNotificationConfig(params, context);
getRequestHandler().onDeleteTaskPushNotificationConfig(params, context);
// void response
responseObserver.onNext(Empty.getDefaultInstance());
responseObserver.onCompleted();
Expand All @@ -317,7 +299,8 @@ public void deleteTaskPushNotificationConfig(io.a2a.grpc.DeleteTaskPushNotificat
}

private <V> ServerCallContext createCallContext(StreamObserver<V> responseObserver) {
if (callContextFactory == null || callContextFactory.isUnsatisfied()) {
CallContextFactory factory = getCallContextFactory();
if (factory == null) {
// Default implementation when no custom CallContextFactory is provided
// This handles both CDI injection scenarios and test scenarios where callContextFactory is null
User user = UnauthenticatedUser.INSTANCE;
Expand All @@ -335,7 +318,6 @@ private <V> ServerCallContext createCallContext(StreamObserver<V> responseObserv

return new ServerCallContext(user, state);
} else {
CallContextFactory factory = callContextFactory.get();
// TODO: CallContextFactory interface expects ServerCall + Metadata, but we only have StreamObserver
// This is another manifestation of the architectural limitation mentioned above
return factory.create(responseObserver); // Fall back to basic create() method for now
Expand Down Expand Up @@ -393,4 +375,9 @@ public static void setStreamingSubscribedRunnable(Runnable runnable) {
streamingSubscribedRunnable = runnable;
}

protected abstract RequestHandler getRequestHandler();

protected abstract AgentCard getAgentCard();

protected abstract CallContextFactory getCallContextFactory();
}
Loading