Skip to content

Commit 81b4a6b

Browse files
authored
feat: Refactor GrpcHandler to not use CDI and be more a utility (#219)
1 parent dd29e9b commit 81b4a6b

File tree

5 files changed

+115
-78
lines changed

5 files changed

+115
-78
lines changed
Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,43 @@
11
package io.a2a.server.grpc.quarkus;
22

3+
import jakarta.enterprise.inject.Instance;
34
import jakarta.inject.Inject;
45

5-
import io.a2a.server.PublicAgentCard;
66
import io.a2a.grpc.handler.GrpcHandler;
7+
import io.a2a.server.PublicAgentCard;
8+
import io.a2a.server.requesthandlers.CallContextFactory;
79
import io.a2a.server.requesthandlers.RequestHandler;
810
import io.a2a.spec.AgentCard;
911
import io.quarkus.grpc.GrpcService;
1012

1113
@GrpcService
1214
public class QuarkusGrpcHandler extends GrpcHandler {
1315

16+
private final AgentCard agentCard;
17+
private final RequestHandler requestHandler;
18+
private final Instance<CallContextFactory> callContextFactoryInstance;
19+
1420
@Inject
15-
public QuarkusGrpcHandler(@PublicAgentCard AgentCard agentCard, RequestHandler requestHandler) {
16-
super(agentCard, requestHandler);
21+
public QuarkusGrpcHandler(@PublicAgentCard AgentCard agentCard,
22+
RequestHandler requestHandler,
23+
Instance<CallContextFactory> callContextFactoryInstance) {
24+
this.agentCard = agentCard;
25+
this.requestHandler = requestHandler;
26+
this.callContextFactoryInstance = callContextFactoryInstance;
27+
}
28+
29+
@Override
30+
protected RequestHandler getRequestHandler() {
31+
return requestHandler;
32+
}
33+
34+
@Override
35+
protected AgentCard getAgentCard() {
36+
return agentCard;
37+
}
38+
39+
@Override
40+
protected CallContextFactory getCallContextFactory() {
41+
return callContextFactoryInstance.isUnsatisfied() ? null : callContextFactoryInstance.get();
1742
}
1843
}
Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,2 @@
11
quarkus.grpc.clients.a2a-service.host=localhost
2-
quarkus.grpc.clients.a2a-service.port=9001
3-
4-
# The GrpcHandler @ApplicationScoped annotation is not compatible with Quarkus
5-
quarkus.arc.exclude-types=io.a2a.grpc.handler.GrpcHandler
2+
quarkus.grpc.clients.a2a-service.port=9001

tests/server-common/src/test/java/io/a2a/server/apps/common/AgentCardProducer.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ public class AgentCardProducer {
1818
@Produces
1919
@PublicAgentCard
2020
public AgentCard agentCard() {
21+
String port = System.getProperty("test.agent.card.port", "8081");
2122
return new AgentCard.Builder()
2223
.name("test-card")
2324
.description("A test agent card")
24-
.url("http://localhost:8081")
25+
.url("http://localhost:" + port)
2526
.version("1.0")
2627
.documentationUrl("http://example.com/docs")
2728
.capabilities(new AgentCapabilities.Builder()

transport/grpc/src/main/java/io/a2a/grpc/handler/GrpcHandler.java

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,15 @@
33
import static io.a2a.grpc.utils.ProtoUtils.FromProto;
44
import static io.a2a.grpc.utils.ProtoUtils.ToProto;
55

6-
import jakarta.enterprise.context.ApplicationScoped;
7-
import jakarta.enterprise.inject.Instance;
8-
import jakarta.inject.Inject;
9-
6+
import java.util.HashMap;
107
import java.util.List;
8+
import java.util.Map;
119
import java.util.concurrent.CompletableFuture;
1210
import java.util.concurrent.Flow;
1311

1412
import com.google.protobuf.Empty;
1513
import io.a2a.grpc.A2AServiceGrpc;
1614
import io.a2a.grpc.StreamResponse;
17-
import io.a2a.server.PublicAgentCard;
1815
import io.a2a.server.ServerCallContext;
1916
import io.a2a.server.auth.UnauthenticatedUser;
2017
import io.a2a.server.auth.User;
@@ -46,29 +43,14 @@
4643
import io.grpc.Status;
4744
import io.grpc.stub.StreamObserver;
4845

49-
import java.util.HashMap;
50-
import java.util.Map;
51-
52-
@ApplicationScoped
53-
public class GrpcHandler extends A2AServiceGrpc.A2AServiceImplBase {
54-
55-
private AgentCard agentCard;
56-
private RequestHandler requestHandler;
46+
public abstract class GrpcHandler extends A2AServiceGrpc.A2AServiceImplBase {
5747

5848
// Hook so testing can wait until streaming subscriptions are established.
5949
// Without this we get intermittent failures
6050
private static volatile Runnable streamingSubscribedRunnable;
6151

62-
@Inject
63-
Instance<CallContextFactory> callContextFactory;
64-
65-
protected GrpcHandler() {
66-
}
6752

68-
@Inject
69-
public GrpcHandler(@PublicAgentCard AgentCard agentCard, RequestHandler requestHandler) {
70-
this.agentCard = agentCard;
71-
this.requestHandler = requestHandler;
53+
public GrpcHandler() {
7254
}
7355

7456
@Override
@@ -77,7 +59,7 @@ public void sendMessage(io.a2a.grpc.SendMessageRequest request,
7759
try {
7860
ServerCallContext context = createCallContext(responseObserver);
7961
MessageSendParams params = FromProto.messageSendParams(request);
80-
EventKind taskOrMessage = requestHandler.onMessageSend(params, context);
62+
EventKind taskOrMessage = getRequestHandler().onMessageSend(params, context);
8163
io.a2a.grpc.SendMessageResponse response = ToProto.taskOrMessage(taskOrMessage);
8264
responseObserver.onNext(response);
8365
responseObserver.onCompleted();
@@ -94,7 +76,7 @@ public void getTask(io.a2a.grpc.GetTaskRequest request,
9476
try {
9577
ServerCallContext context = createCallContext(responseObserver);
9678
TaskQueryParams params = FromProto.taskQueryParams(request);
97-
Task task = requestHandler.onGetTask(params, context);
79+
Task task = getRequestHandler().onGetTask(params, context);
9880
if (task != null) {
9981
responseObserver.onNext(ToProto.task(task));
10082
responseObserver.onCompleted();
@@ -114,7 +96,7 @@ public void cancelTask(io.a2a.grpc.CancelTaskRequest request,
11496
try {
11597
ServerCallContext context = createCallContext(responseObserver);
11698
TaskIdParams params = FromProto.taskIdParams(request);
117-
Task task = requestHandler.onCancelTask(params, context);
99+
Task task = getRequestHandler().onCancelTask(params, context);
118100
if (task != null) {
119101
responseObserver.onNext(ToProto.task(task));
120102
responseObserver.onCompleted();
@@ -131,15 +113,15 @@ public void cancelTask(io.a2a.grpc.CancelTaskRequest request,
131113
@Override
132114
public void createTaskPushNotificationConfig(io.a2a.grpc.CreateTaskPushNotificationConfigRequest request,
133115
StreamObserver<io.a2a.grpc.TaskPushNotificationConfig> responseObserver) {
134-
if (! agentCard.capabilities().pushNotifications()) {
116+
if (!getAgentCard().capabilities().pushNotifications()) {
135117
handleError(responseObserver, new PushNotificationNotSupportedError());
136118
return;
137119
}
138120

139121
try {
140122
ServerCallContext context = createCallContext(responseObserver);
141123
TaskPushNotificationConfig config = FromProto.taskPushNotificationConfig(request);
142-
TaskPushNotificationConfig responseConfig = requestHandler.onSetTaskPushNotificationConfig(config, context);
124+
TaskPushNotificationConfig responseConfig = getRequestHandler().onSetTaskPushNotificationConfig(config, context);
143125
responseObserver.onNext(ToProto.taskPushNotificationConfig(responseConfig));
144126
responseObserver.onCompleted();
145127
} catch (JSONRPCError e) {
@@ -152,15 +134,15 @@ public void createTaskPushNotificationConfig(io.a2a.grpc.CreateTaskPushNotificat
152134
@Override
153135
public void getTaskPushNotificationConfig(io.a2a.grpc.GetTaskPushNotificationConfigRequest request,
154136
StreamObserver<io.a2a.grpc.TaskPushNotificationConfig> responseObserver) {
155-
if (! agentCard.capabilities().pushNotifications()) {
137+
if (!getAgentCard().capabilities().pushNotifications()) {
156138
handleError(responseObserver, new PushNotificationNotSupportedError());
157139
return;
158140
}
159141

160142
try {
161143
ServerCallContext context = createCallContext(responseObserver);
162144
GetTaskPushNotificationConfigParams params = FromProto.getTaskPushNotificationConfigParams(request);
163-
TaskPushNotificationConfig config = requestHandler.onGetTaskPushNotificationConfig(params, context);
145+
TaskPushNotificationConfig config = getRequestHandler().onGetTaskPushNotificationConfig(params, context);
164146
responseObserver.onNext(ToProto.taskPushNotificationConfig(config));
165147
responseObserver.onCompleted();
166148
} catch (JSONRPCError e) {
@@ -173,15 +155,15 @@ public void getTaskPushNotificationConfig(io.a2a.grpc.GetTaskPushNotificationCon
173155
@Override
174156
public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationConfigRequest request,
175157
StreamObserver<io.a2a.grpc.ListTaskPushNotificationConfigResponse> responseObserver) {
176-
if (! agentCard.capabilities().pushNotifications()) {
158+
if (!getAgentCard().capabilities().pushNotifications()) {
177159
handleError(responseObserver, new PushNotificationNotSupportedError());
178160
return;
179161
}
180162

181163
try {
182164
ServerCallContext context = createCallContext(responseObserver);
183165
ListTaskPushNotificationConfigParams params = FromProto.listTaskPushNotificationConfigParams(request);
184-
List<TaskPushNotificationConfig> configList = requestHandler.onListTaskPushNotificationConfig(params, context);
166+
List<TaskPushNotificationConfig> configList = getRequestHandler().onListTaskPushNotificationConfig(params, context);
185167
io.a2a.grpc.ListTaskPushNotificationConfigResponse.Builder responseBuilder =
186168
io.a2a.grpc.ListTaskPushNotificationConfigResponse.newBuilder();
187169
for (TaskPushNotificationConfig config : configList) {
@@ -199,15 +181,15 @@ public void listTaskPushNotificationConfig(io.a2a.grpc.ListTaskPushNotificationC
199181
@Override
200182
public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request,
201183
StreamObserver<io.a2a.grpc.StreamResponse> responseObserver) {
202-
if (! agentCard.capabilities().streaming()) {
184+
if (!getAgentCard().capabilities().streaming()) {
203185
handleError(responseObserver, new InvalidRequestError());
204186
return;
205187
}
206188

207189
try {
208190
ServerCallContext context = createCallContext(responseObserver);
209191
MessageSendParams params = FromProto.messageSendParams(request);
210-
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onMessageSendStream(params, context);
192+
Flow.Publisher<StreamingEventKind> publisher = getRequestHandler().onMessageSendStream(params, context);
211193
convertToStreamResponse(publisher, responseObserver);
212194
} catch (JSONRPCError e) {
213195
handleError(responseObserver, e);
@@ -219,15 +201,15 @@ public void sendStreamingMessage(io.a2a.grpc.SendMessageRequest request,
219201
@Override
220202
public void taskSubscription(io.a2a.grpc.TaskSubscriptionRequest request,
221203
StreamObserver<io.a2a.grpc.StreamResponse> responseObserver) {
222-
if (! agentCard.capabilities().streaming()) {
204+
if (!getAgentCard().capabilities().streaming()) {
223205
handleError(responseObserver, new InvalidRequestError());
224206
return;
225207
}
226208

227209
try {
228210
ServerCallContext context = createCallContext(responseObserver);
229211
TaskIdParams params = FromProto.taskIdParams(request);
230-
Flow.Publisher<StreamingEventKind> publisher = requestHandler.onResubscribeToTask(params, context);
212+
Flow.Publisher<StreamingEventKind> publisher = getRequestHandler().onResubscribeToTask(params, context);
231213
convertToStreamResponse(publisher, responseObserver);
232214
} catch (JSONRPCError e) {
233215
handleError(responseObserver, e);
@@ -287,7 +269,7 @@ public void onComplete() {
287269
public void getAgentCard(io.a2a.grpc.GetAgentCardRequest request,
288270
StreamObserver<io.a2a.grpc.AgentCard> responseObserver) {
289271
try {
290-
responseObserver.onNext(ToProto.agentCard(agentCard));
272+
responseObserver.onNext(ToProto.agentCard(getAgentCard()));
291273
responseObserver.onCompleted();
292274
} catch (Throwable t) {
293275
handleInternalError(responseObserver, t);
@@ -297,15 +279,15 @@ public void getAgentCard(io.a2a.grpc.GetAgentCardRequest request,
297279
@Override
298280
public void deleteTaskPushNotificationConfig(io.a2a.grpc.DeleteTaskPushNotificationConfigRequest request,
299281
StreamObserver<Empty> responseObserver) {
300-
if (! agentCard.capabilities().pushNotifications()) {
282+
if (!getAgentCard().capabilities().pushNotifications()) {
301283
handleError(responseObserver, new PushNotificationNotSupportedError());
302284
return;
303285
}
304286

305287
try {
306288
ServerCallContext context = createCallContext(responseObserver);
307289
DeleteTaskPushNotificationConfigParams params = FromProto.deleteTaskPushNotificationConfigParams(request);
308-
requestHandler.onDeleteTaskPushNotificationConfig(params, context);
290+
getRequestHandler().onDeleteTaskPushNotificationConfig(params, context);
309291
// void response
310292
responseObserver.onNext(Empty.getDefaultInstance());
311293
responseObserver.onCompleted();
@@ -317,7 +299,8 @@ public void deleteTaskPushNotificationConfig(io.a2a.grpc.DeleteTaskPushNotificat
317299
}
318300

319301
private <V> ServerCallContext createCallContext(StreamObserver<V> responseObserver) {
320-
if (callContextFactory == null || callContextFactory.isUnsatisfied()) {
302+
CallContextFactory factory = getCallContextFactory();
303+
if (factory == null) {
321304
// Default implementation when no custom CallContextFactory is provided
322305
// This handles both CDI injection scenarios and test scenarios where callContextFactory is null
323306
User user = UnauthenticatedUser.INSTANCE;
@@ -335,7 +318,6 @@ private <V> ServerCallContext createCallContext(StreamObserver<V> responseObserv
335318

336319
return new ServerCallContext(user, state);
337320
} else {
338-
CallContextFactory factory = callContextFactory.get();
339321
// TODO: CallContextFactory interface expects ServerCall + Metadata, but we only have StreamObserver
340322
// This is another manifestation of the architectural limitation mentioned above
341323
return factory.create(responseObserver); // Fall back to basic create() method for now
@@ -393,4 +375,9 @@ public static void setStreamingSubscribedRunnable(Runnable runnable) {
393375
streamingSubscribedRunnable = runnable;
394376
}
395377

378+
protected abstract RequestHandler getRequestHandler();
379+
380+
protected abstract AgentCard getAgentCard();
381+
382+
protected abstract CallContextFactory getCallContextFactory();
396383
}

0 commit comments

Comments
 (0)