Skip to content

Commit 7e35e86

Browse files
authored
Merge pull request #106 from kabir/tck
Remove concurrent executor from TCK AgentExecutor
2 parents 18979cc + 0968c95 commit 7e35e86

File tree

3 files changed

+43
-111
lines changed

3 files changed

+43
-111
lines changed

src/main/java/io/a2a/server/events/EnhancedRunnable.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ public void addDoneCallback(DoneCallback doneCallback) {
2121
}
2222
}
2323

24-
protected void invokeDoneCallbacks() {
24+
public void invokeDoneCallbacks() {
2525
synchronized (doneCallbacks) {
2626
for (DoneCallback doneCallback : doneCallbacks) {
2727
doneCallback.done(this);

src/main/java/io/a2a/server/requesthandlers/DefaultRequestHandler.java

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.List;
1111
import java.util.Map;
1212
import java.util.Objects;
13+
import java.util.concurrent.CompletableFuture;
1314
import java.util.concurrent.Executor;
1415
import java.util.concurrent.Executors;
1516
import java.util.concurrent.Flow;
@@ -32,7 +33,6 @@
3233
import io.a2a.server.tasks.ResultAggregator;
3334
import io.a2a.server.tasks.TaskManager;
3435
import io.a2a.server.tasks.TaskStore;
35-
import io.a2a.spec.Artifact;
3636
import io.a2a.spec.EventKind;
3737
import io.a2a.spec.InternalError;
3838
import io.a2a.spec.JSONRPCError;
@@ -47,7 +47,6 @@
4747
import io.a2a.spec.TaskQueryParams;
4848
import io.a2a.spec.UnsupportedOperationError;
4949
import io.a2a.util.TempLoggerWrapper;
50-
5150
import org.slf4j.Logger;
5251
import org.slf4j.LoggerFactory;
5352

@@ -62,8 +61,7 @@ public class DefaultRequestHandler implements RequestHandler {
6261
private final PushNotifier pushNotifier;
6362
private final Supplier<RequestContext.Builder> requestContextBuilder;
6463

65-
// TODO the value upstream is asyncio.Task. Trying a Runnable
66-
private final Map<String, EnhancedRunnable> runningAgents = Collections.synchronizedMap(new HashMap<>());
64+
private final Map<String, CompletableFuture<Void>> runningAgents = Collections.synchronizedMap(new HashMap<>());
6765

6866
private final Executor executor = Executors.newCachedThreadPool();
6967

@@ -134,7 +132,10 @@ public Task onCancelTask(TaskIdParams params) throws JSONRPCError {
134132
.build(),
135133
queue);
136134

137-
// TODO need to cancel the asyncio.Task looked up from runningAgents
135+
CompletableFuture<Void> cf = runningAgents.get(task.getId());
136+
if (cf != null) {
137+
cf.cancel(true);
138+
}
138139

139140
EventConsumer consumer = new EventConsumer(queue);
140141
EventKind type = resultAggregator.consumeAll(consumer);
@@ -203,9 +204,9 @@ public EventKind onMessageSend(MessageSendParams params) throws JSONRPCError {
203204
} finally {
204205
if (interrupted) {
205206
// TODO Make this async
206-
cleanupProducer(producerRunnable, taskId);
207+
cleanupProducer(taskId);
207208
} else {
208-
cleanupProducer(producerRunnable, taskId);
209+
cleanupProducer(taskId);
209210
}
210211
}
211212

@@ -244,7 +245,8 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams
244245
EnhancedRunnable producerRunnable = registerAndExecuteAgentAsync(taskId.get(), requestContext, queue);
245246

246247
EventConsumer consumer = new EventConsumer(queue);
247-
// TODO https://github.com/fjuma/a2a-java-sdk/issues/62 Add this callback
248+
249+
producerRunnable.addDoneCallback(consumer.createAgentRunnableDoneCallback());
248250

249251
try {
250252
Flow.Publisher<Event> results = resultAggregator.consumeAndEmit(consumer);
@@ -286,7 +288,7 @@ public Flow.Publisher<StreamingEventKind> onMessageSendStream(MessageSendParams
286288

287289
return convertingProcessor(eventPublisher, event -> (StreamingEventKind) event);
288290
} finally {
289-
cleanupProducer(producerRunnable, taskId.get());
291+
cleanupProducer(taskId.get());
290292
}
291293
}
292294

@@ -347,42 +349,40 @@ private boolean shouldAddPushInfo(MessageSendParams params) {
347349
return pushNotifier != null && params.configuration() != null && params.configuration().pushNotification() != null;
348350
}
349351

350-
private void runEventStream(RequestContext requestContext, EventQueue queue) throws JSONRPCError {
351-
agentExecutor.execute(requestContext, queue);
352-
// TODO this is in the Python implementation, but enabling it causes test hangs
353-
try {
354-
queueManager.signalPollingStarted(queue);
355-
} catch (InterruptedException e) {
356-
Thread.currentThread().interrupt();
357-
} finally {
358-
queue.close();
359-
}
360-
}
361-
362-
private EnhancedRunnable registerAndExecuteAgentAsync(String taskId, RequestContext requestContext, EventQueue eventQueue) {
352+
private EnhancedRunnable registerAndExecuteAgentAsync(String taskId, RequestContext requestContext, EventQueue queue) {
363353
EnhancedRunnable runnable = new EnhancedRunnable() {
364354
@Override
365355
public void run() {
356+
agentExecutor.execute(requestContext, queue);
366357
try {
367-
runEventStream(requestContext, eventQueue);
368-
} catch (Throwable throwable) {
369-
setError(throwable);
370-
} finally {
371-
invokeDoneCallbacks();
358+
queueManager.signalPollingStarted(queue);
359+
} catch (InterruptedException e) {
360+
Thread.currentThread().interrupt();
372361
}
373-
374362
}
375363
};
376-
runningAgents.put(taskId, runnable);
377-
executor.execute(runnable);
364+
365+
CompletableFuture<Void> cf = CompletableFuture.runAsync(runnable, executor)
366+
.whenComplete((v, err) -> {
367+
if (err != null) {
368+
runnable.setError(err);
369+
}
370+
queue.close();
371+
runnable.invokeDoneCallbacks();
372+
});
373+
runningAgents.put(taskId, cf);
378374
return runnable;
379375
}
380376

381-
private void cleanupProducer(Runnable producerRunnable, String taskId) {
377+
private void cleanupProducer(String taskId) {
382378
// TODO the Python implementation waits for the producerRunnable
383-
384-
queueManager.close(taskId);
385-
runningAgents.remove(taskId);
379+
CompletableFuture<Void> cf = runningAgents.get(taskId);
380+
if (cf != null) {
381+
cf.whenComplete((v, t) -> {
382+
queueManager.close(taskId);
383+
runningAgents.remove(taskId);
384+
});
385+
}
386386
}
387387

388388
}

tck/src/main/java/io/a2a/examples/helloworld/server/AgentExecutorProducer.java

Lines changed: 8 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
package io.a2a.examples.helloworld.server;
22

3-
import java.util.Map;
4-
import java.util.concurrent.CompletableFuture;
5-
import java.util.concurrent.ConcurrentHashMap;
6-
import java.util.concurrent.ExecutorService;
7-
import java.util.concurrent.Executors;
8-
import java.util.concurrent.TimeUnit;
9-
3+
import jakarta.annotation.PreDestroy;
104
import jakarta.enterprise.context.ApplicationScoped;
115
import jakarta.enterprise.inject.Produces;
12-
import jakarta.annotation.PreDestroy;
136

147
import io.a2a.server.agentexecution.AgentExecutor;
158
import io.a2a.server.agentexecution.RequestContext;
@@ -31,16 +24,6 @@ public AgentExecutor agentExecutor() {
3124
}
3225

3326
private static class FireAndForgetAgentExecutor implements AgentExecutor {
34-
// Dedicated thread pool for background task execution
35-
private final ExecutorService taskExecutor = Executors.newCachedThreadPool(r -> {
36-
Thread t = new Thread(r, "AgentTask-" + System.currentTimeMillis());
37-
t.setDaemon(true); // Don't prevent JVM shutdown
38-
return t;
39-
});
40-
41-
// Track running tasks for cancellation - store the future reference
42-
private final Map<String, CompletableFuture<Void>> runningTasks = new ConcurrentHashMap<>();
43-
4427
@Override
4528
public void execute(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
4629
Task task = context.getTask();
@@ -68,12 +51,12 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
6851

6952
System.out.println("====> task set to WORKING, starting background execution");
7053

71-
// Fire and forget - start the task but don't wait for it
72-
CompletableFuture<Void> taskFuture = CompletableFuture
73-
.runAsync(() -> executeTaskInBackground(context, eventQueue), taskExecutor);
54+
// // Fire and forget - start the task but don't wait for it
55+
// CompletableFuture<Void> taskFuture = CompletableFuture
56+
// .runAsync(() -> executeTaskInBackground(context, eventQueue), taskExecutor);
7457

75-
// Store the future for potential cancellation
76-
runningTasks.put(context.getTaskId(), taskFuture);
58+
// // Store the future for potential cancellation
59+
// runningTasks.put(context.getTaskId(), taskFuture);
7760

7861
// Method returns immediately - task continues in background
7962
System.out.println("====> execute() method returning immediately, task running in background");
@@ -99,16 +82,6 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC
9982
throw new TaskNotCancelableError();
10083
}
10184

102-
// Cancel the CompletableFuture
103-
CompletableFuture<Void> taskFuture = runningTasks.get(task.getId());
104-
if (taskFuture != null) {
105-
boolean cancelled = taskFuture.cancel(true); // mayInterruptIfRunning = true
106-
System.out.println("====> cancellation attempted, success: " + cancelled);
107-
}
108-
109-
// Remove from running tasks and update status
110-
runningTasks.remove(task.getId());
111-
11285
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
11386
.taskId(task.getId())
11487
.contextId(task.getContextId())
@@ -129,21 +102,9 @@ private void executeTaskInBackground(RequestContext context, EventQueue eventQue
129102
try {
130103
System.out.println("====> background execution started for task: " + taskId);
131104

132-
// Check if task was cancelled before we even started
133-
if (!runningTasks.containsKey(taskId)) {
134-
System.out.println("====> task was cancelled before background execution started");
135-
return;
136-
}
137-
138105
// Perform the actual work
139106
Object result = performActualWork(context);
140107

141-
// Check again if task was cancelled during execution
142-
if (!runningTasks.containsKey(taskId)) {
143-
System.out.println("====> task was cancelled during execution");
144-
return;
145-
}
146-
147108
// Task completed successfully
148109
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
149110
.taskId(taskId)
@@ -159,33 +120,13 @@ private void executeTaskInBackground(RequestContext context, EventQueue eventQue
159120
System.out.println("====> background task was interrupted: " + taskId);
160121
Thread.currentThread().interrupt();
161122

162-
// Only send CANCELED event if task is still tracked (not already cancelled)
163-
if (runningTasks.containsKey(taskId)) {
164-
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
165-
.taskId(taskId)
166-
.contextId(context.getContextId())
167-
.status(new TaskStatus(TaskState.CANCELED))
168-
.isFinal(true)
169-
.build());
170-
}
171-
172123
} catch (Exception e) {
173124
// Task failed
174125
System.err.println("====> background task failed: " + taskId);
175126
e.printStackTrace();
176127

177-
if (runningTasks.containsKey(taskId)) {
178-
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
179-
.taskId(taskId)
180-
.contextId(context.getContextId())
181-
.status(new TaskStatus(TaskState.FAILED))
182-
.isFinal(true)
183-
.build());
184-
}
185-
186128
} finally {
187-
// Always clean up - remove from running tasks
188-
runningTasks.remove(taskId);
129+
// Always clean up
189130
System.out.println("====> background task cleanup completed: " + taskId);
190131
}
191132
}
@@ -220,15 +161,6 @@ private Object performActualWork(RequestContext context) throws InterruptedExcep
220161
@PreDestroy
221162
public void cleanup() {
222163
System.out.println("====> shutting down task executor");
223-
taskExecutor.shutdown();
224-
try {
225-
if (!taskExecutor.awaitTermination(5, TimeUnit.SECONDS)) {
226-
taskExecutor.shutdownNow();
227-
}
228-
} catch (InterruptedException e) {
229-
taskExecutor.shutdownNow();
230-
Thread.currentThread().interrupt();
231-
}
232-
}
164+
}
233165
}
234166
}

0 commit comments

Comments
 (0)