Skip to content

Commit 0f09faf

Browse files
authored
Merge pull request #109 from kabir/cleanup
Cleanup the AgentExecutors and use TaskUpdater where possible
2 parents 6edaabf + f7a8acb commit 0f09faf

File tree

6 files changed

+67
-144
lines changed

6 files changed

+67
-144
lines changed

src/main/java/io/a2a/server/tasks/TaskUpdater.java

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import java.util.Map;
55
import java.util.UUID;
66

7+
import io.a2a.server.agentexecution.RequestContext;
78
import io.a2a.server.events.EventQueue;
89
import io.a2a.spec.Artifact;
910
import io.a2a.spec.Message;
@@ -18,26 +19,21 @@ public class TaskUpdater {
1819
private final String taskId;
1920
private final String contextId;
2021

21-
public TaskUpdater(EventQueue eventQueue, String taskId, String contextId) {
22+
public TaskUpdater(RequestContext context, EventQueue eventQueue) {
2223
this.eventQueue = eventQueue;
23-
this.taskId = taskId;
24-
this.contextId = contextId;
24+
this.taskId = context.getTaskId();
25+
this.contextId = context.getContextId();
2526
}
2627

27-
public void updateStatus(TaskState taskState) {
28+
private void updateStatus(TaskState taskState) {
2829
updateStatus(taskState, null);
2930
}
3031

31-
public void updateStatus(TaskState state, Message message) {
32-
updateStatus(state, message, false);
33-
}
34-
35-
public void updateStatus(TaskState state, Message message, boolean isFinal) {
36-
32+
private void updateStatus(TaskState state, Message message) {
3733
TaskStatusUpdateEvent event = new TaskStatusUpdateEvent.Builder()
3834
.taskId(taskId)
3935
.contextId(contextId)
40-
.isFinal(isFinal)
36+
.isFinal(state.isFinal())
4137
.status(new TaskStatus(state, message, null))
4238
.build();
4339
eventQueue.enqueueEvent(event);
@@ -67,15 +63,15 @@ public void complete() {
6763
}
6864

6965
public void complete(Message message) {
70-
updateStatus(TaskState.COMPLETED, message, true);
66+
updateStatus(TaskState.COMPLETED, message);
7167
}
7268

73-
public void failed() {
74-
failed(null);
69+
public void fail() {
70+
fail(null);
7571
}
7672

77-
public void failed(Message message) {
78-
updateStatus(TaskState.FAILED, message, true);
73+
public void fail(Message message) {
74+
updateStatus(TaskState.FAILED, message);
7975
}
8076

8177
public void submit() {
@@ -94,6 +90,13 @@ public void startWork(Message message) {
9490
updateStatus(TaskState.WORKING, message);
9591
}
9692

93+
public void cancel() {
94+
cancel(null);
95+
}
96+
97+
public void cancel(Message message) {
98+
updateStatus(TaskState.CANCELED, message);
99+
}
97100

98101
public Message newAgentMessage(List<Part<?>> parts, Map<String, Object> metadata) {
99102
return new Message.Builder()

src/main/java/io/a2a/spec/TaskState.java

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,33 @@ public enum TaskState {
1111
WORKING("working"),
1212
INPUT_REQUIRED("input-required"),
1313
AUTH_REQUIRED("auth-required"),
14-
COMPLETED("completed"),
15-
CANCELED("canceled"),
16-
FAILED("failed"),
17-
REJECTED("rejected"),
18-
UNKNOWN("unknown");
14+
COMPLETED("completed", true),
15+
CANCELED("canceled", true),
16+
FAILED("failed", true),
17+
REJECTED("rejected", true),
18+
UNKNOWN("unknown", true);
1919

2020
private final String state;
21+
private final boolean isFinal;
2122

2223
TaskState(String state) {
24+
this(state, false);
25+
}
26+
27+
TaskState(String state, boolean isFinal) {
2328
this.state = state;
29+
this.isFinal = isFinal;
2430
}
2531

2632
@JsonValue
2733
public String asString() {
2834
return state;
2935
}
3036

37+
public boolean isFinal(){
38+
return isFinal;
39+
}
40+
3141
@JsonCreator
3242
public static TaskState fromString(String state) {
3343
switch (state) {

src/test/java/io/a2a/server/apps/AgentExecutorProducer.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import io.a2a.server.agentexecution.AgentExecutor;
77
import io.a2a.server.agentexecution.RequestContext;
88
import io.a2a.server.events.EventQueue;
9+
import io.a2a.server.tasks.TaskUpdater;
910
import io.a2a.spec.JSONRPCError;
1011
import io.a2a.spec.Task;
1112
import io.a2a.spec.TaskState;
@@ -29,12 +30,8 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
2930
@Override
3031
public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
3132
if (context.getTask().getId().equals("cancel-task-123")) {
32-
Task task = context.getTask();
33-
Task updated = new Task.Builder(task)
34-
.status(new TaskStatus(TaskState.CANCELED))
35-
.build();
36-
37-
eventQueue.enqueueEvent(updated);
33+
TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
34+
taskUpdater.cancel();
3835
} else if (context.getTask().getId().equals("cancel-task-not-supported-123")) {
3936
throw new UnsupportedOperationError();
4037
}

src/test/java/io/a2a/server/requesthandlers/JSONRPCHandlerTest.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import io.a2a.server.tasks.PushNotifier;
3636
import io.a2a.server.tasks.ResultAggregator;
3737
import io.a2a.server.tasks.TaskStore;
38+
import io.a2a.server.tasks.TaskUpdater;
3839
import io.a2a.spec.AgentCapabilities;
3940
import io.a2a.spec.AgentCard;
4041
import io.a2a.spec.Artifact;
@@ -166,11 +167,8 @@ public void testOnCancelTaskSuccess() throws Exception {
166167
// Looking at the Python implementation, they typically use AgentExecutors that
167168
// don't support cancellation. So my theory is the Agent updates the task to the CANCEL status
168169
Task task = context.getTask();
169-
Task updated = new Task.Builder(task)
170-
.status(new TaskStatus(TaskState.CANCELED))
171-
.build();
172-
173-
eventQueue.enqueueEvent(updated);
170+
TaskUpdater taskUpdater = new TaskUpdater(context, eventQueue);
171+
taskUpdater.cancel();
174172
};
175173

176174
CancelTaskRequest request = new CancelTaskRequest("111", new TaskIdParams(MINIMAL_TASK.getId()));

src/test/java/io/a2a/server/tasks/TaskUpdaterTest.java

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import static org.junit.jupiter.api.Assertions.assertNotNull;
77
import static org.junit.jupiter.api.Assertions.assertNull;
88
import static org.junit.jupiter.api.Assertions.assertSame;
9-
import static org.junit.jupiter.api.Assertions.assertTrue;
109

1110
import java.util.List;
1211
import java.util.Map;
1312

13+
import io.a2a.server.agentexecution.RequestContext;
1414
import io.a2a.server.events.Event;
1515
import io.a2a.server.events.EventQueue;
1616
import io.a2a.spec.Message;
@@ -20,9 +20,7 @@
2020
import io.a2a.spec.TaskStatusUpdateEvent;
2121
import io.a2a.spec.TextPart;
2222
import org.junit.jupiter.api.BeforeEach;
23-
import org.junit.jupiter.api.Disabled;
2423
import org.junit.jupiter.api.Test;
25-
import org.mockito.Mockito;
2624

2725
public class TaskUpdaterTest {
2826
public static final String TEST_TASK_ID = "test-task-id";
@@ -45,31 +43,11 @@ public class TaskUpdaterTest {
4543
@BeforeEach
4644
public void init() {
4745
eventQueue = EventQueue.create();
48-
taskUpdater = new TaskUpdater(eventQueue, TEST_TASK_ID, TEST_TASK_CONTEXT_ID);
49-
}
50-
51-
//@Test
52-
//public void testInit() {
53-
// // Python has a unit test testing that the constructor works. Not really relevant
54-
//}
55-
56-
@Test
57-
public void testUpdateStatusWithoutMessage() throws Exception {
58-
taskUpdater.updateStatus(TaskState.WORKING);
59-
checkTaskStatusUpdateEventOnQueue(false, TaskState.WORKING, null);
60-
61-
}
62-
63-
@Test
64-
public void testUpdateStatusWithMessage() throws Exception {
65-
taskUpdater.updateStatus(TaskState.WORKING, SAMPLE_MESSAGE);
66-
checkTaskStatusUpdateEventOnQueue(false, TaskState.WORKING, SAMPLE_MESSAGE);
67-
}
68-
69-
@Test
70-
public void testUpdateStatusFinal() throws Exception {
71-
taskUpdater.updateStatus(TaskState.COMPLETED, null, true);
72-
checkTaskStatusUpdateEventOnQueue(true, TaskState.COMPLETED, null);
46+
RequestContext context = new RequestContext.Builder()
47+
.setTaskId(TEST_TASK_ID)
48+
.setContextId(TEST_TASK_CONTEXT_ID)
49+
.build();
50+
taskUpdater = new TaskUpdater(context, eventQueue);
7351
}
7452

7553
@Test
@@ -128,16 +106,28 @@ public void testStartWorkWithMessage() throws Exception {
128106

129107
@Test
130108
public void testFailedWithoutMessage() throws Exception {
131-
taskUpdater.failed();
109+
taskUpdater.fail();
132110
checkTaskStatusUpdateEventOnQueue(true, TaskState.FAILED, null);
133111
}
134112

135113
@Test
136114
public void testFailedWithMessage() throws Exception {
137-
taskUpdater.failed(SAMPLE_MESSAGE);
115+
taskUpdater.fail(SAMPLE_MESSAGE);
138116
checkTaskStatusUpdateEventOnQueue(true, TaskState.FAILED, SAMPLE_MESSAGE);
139117
}
140118

119+
@Test
120+
public void testCanceledWithoutMessage() throws Exception {
121+
taskUpdater.cancel();
122+
checkTaskStatusUpdateEventOnQueue(true, TaskState.CANCELED, null);
123+
}
124+
125+
@Test
126+
public void testCanceledWithMessage() throws Exception {
127+
taskUpdater.cancel(SAMPLE_MESSAGE);
128+
checkTaskStatusUpdateEventOnQueue(true, TaskState.CANCELED, SAMPLE_MESSAGE);
129+
}
130+
141131
@Test
142132
public void testNewAgentMessage() throws Exception {
143133
Message message = taskUpdater.newAgentMessage(SAMPLE_PARTS, null);

tck/src/main/java/io/a2a/tck/server/AgentExecutorProducer.java

Lines changed: 7 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import io.a2a.server.agentexecution.AgentExecutor;
88
import io.a2a.server.agentexecution.RequestContext;
99
import io.a2a.server.events.EventQueue;
10+
import io.a2a.server.tasks.TaskUpdater;
1011
import io.a2a.spec.JSONRPCError;
1112
import io.a2a.spec.Task;
1213
import io.a2a.spec.TaskNotCancelableError;
@@ -42,22 +43,12 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
4243
eventQueue.enqueueEvent(task);
4344
}
4445

46+
TaskUpdater updater = new TaskUpdater(context, eventQueue);
47+
4548
// Immediately set to WORKING state
46-
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
47-
.taskId(context.getTaskId())
48-
.contextId(context.getContextId())
49-
.status(new TaskStatus(TaskState.WORKING))
50-
.build());
51-
49+
updater.startWork();
5250
System.out.println("====> task set to WORKING, starting background execution");
5351

54-
// // Fire and forget - start the task but don't wait for it
55-
// CompletableFuture<Void> taskFuture = CompletableFuture
56-
// .runAsync(() -> executeTaskInBackground(context, eventQueue), taskExecutor);
57-
58-
// // Store the future for potential cancellation
59-
// runningTasks.put(context.getTaskId(), taskFuture);
60-
6152
// Method returns immediately - task continues in background
6253
System.out.println("====> execute() method returning immediately, task running in background");
6354
}
@@ -66,12 +57,7 @@ public void execute(RequestContext context, EventQueue eventQueue) throws JSONRP
6657
public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPCError {
6758
System.out.println("====> task cancel request received");
6859
Task task = context.getTask();
69-
70-
if (task == null) {
71-
System.out.println("====> task not found");
72-
throw new TaskNotFoundError();
73-
}
74-
60+
7561
if (task.getStatus().state() == TaskState.CANCELED) {
7662
System.out.println("====> task already canceled");
7763
throw new TaskNotCancelableError();
@@ -82,6 +68,8 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC
8268
throw new TaskNotCancelableError();
8369
}
8470

71+
TaskUpdater updater = new TaskUpdater(context, eventQueue);
72+
updater.cancel();
8573
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
8674
.taskId(task.getId())
8775
.contextId(task.getContextId())
@@ -92,69 +80,6 @@ public void cancel(RequestContext context, EventQueue eventQueue) throws JSONRPC
9280
System.out.println("====> task canceled");
9381
}
9482

95-
/**
96-
* This method runs completely in the background.
97-
* The main execute() method has already returned.
98-
*/
99-
private void executeTaskInBackground(RequestContext context, EventQueue eventQueue) {
100-
String taskId = context.getTaskId();
101-
102-
try {
103-
System.out.println("====> background execution started for task: " + taskId);
104-
105-
// Perform the actual work
106-
Object result = performActualWork(context);
107-
108-
// Task completed successfully
109-
eventQueue.enqueueEvent(new TaskStatusUpdateEvent.Builder()
110-
.taskId(taskId)
111-
.contextId(context.getContextId())
112-
.status(new TaskStatus(TaskState.COMPLETED))
113-
.isFinal(true)
114-
.build());
115-
116-
System.out.println("====> background task completed successfully: " + taskId);
117-
118-
} catch (InterruptedException e) {
119-
// Task was interrupted (cancelled)
120-
System.out.println("====> background task was interrupted: " + taskId);
121-
Thread.currentThread().interrupt();
122-
123-
} catch (Exception e) {
124-
// Task failed
125-
System.err.println("====> background task failed: " + taskId);
126-
e.printStackTrace();
127-
128-
} finally {
129-
// Always clean up
130-
System.out.println("====> background task cleanup completed: " + taskId);
131-
}
132-
}
133-
134-
/**
135-
* This method represents the actual work that needs to be done.
136-
* Replace this with your real business logic.
137-
*/
138-
private Object performActualWork(RequestContext context) throws InterruptedException {
139-
140-
141-
System.out.println("====> starting actual work for task: " + context.getTaskId());
142-
143-
// Simulate work that can be interrupted
144-
for (int i = 0; i < 10; i++) {
145-
// Check for interruption regularly during long-running work
146-
if (Thread.currentThread().isInterrupted()) {
147-
throw new InterruptedException("Task was cancelled during execution");
148-
}
149-
150-
Thread.sleep(200); // Simulate work chunks
151-
System.out.println("====> work progress for task " + context.getTaskId() + ": " + ((i + 1) * 10) + "%");
152-
}
153-
154-
System.out.println("====> finished actual work for task: " + context.getTaskId());
155-
return "Task completed successfully";
156-
}
157-
15883
/**
15984
* Cleanup method for proper resource management
16085
*/

0 commit comments

Comments
 (0)