Skip to content

Commit 1cf3660

Browse files
authored
chore: sync changes to existing tasks tests from Python (#299)
1 parent 655e1d3 commit 1cf3660

File tree

8 files changed

+1094
-14
lines changed

8 files changed

+1094
-14
lines changed

server-common/src/main/java/io/a2a/server/events/EventQueue.java

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,45 @@ public abstract class EventQueue implements AutoCloseable {
1717

1818
private static final Logger LOGGER = LoggerFactory.getLogger(EventQueue.class);
1919

20-
// TODO decide on a capacity
21-
private static final int queueSize = 1000;
20+
public static final int DEFAULT_QUEUE_SIZE = 1000;
2221

22+
private final int queueSize;
2323
private final BlockingQueue<Event> queue = new LinkedBlockingDeque<>();
24-
private final Semaphore semaphore = new Semaphore(queueSize, true);
24+
private final Semaphore semaphore;
2525
private volatile boolean closed = false;
2626

2727

2828

2929
protected EventQueue() {
30-
this(null);
30+
this(DEFAULT_QUEUE_SIZE);
31+
}
32+
33+
protected EventQueue(int queueSize) {
34+
if (queueSize <= 0) {
35+
throw new IllegalArgumentException("Queue size must be greater than 0");
36+
}
37+
this.queueSize = queueSize;
38+
this.semaphore = new Semaphore(queueSize, true);
39+
LOGGER.trace("Creating {} with queue size: {}", this, queueSize);
3140
}
3241

3342
protected EventQueue(EventQueue parent) {
43+
this(DEFAULT_QUEUE_SIZE);
3444
LOGGER.trace("Creating {}, parent: {}", this, parent);
3545
}
3646

3747
public static EventQueue create() {
38-
3948
return new MainQueue();
4049
}
4150

51+
public static EventQueue create(int queueSize) {
52+
return new MainQueue(queueSize);
53+
}
54+
55+
public int getQueueSize() {
56+
return queueSize;
57+
}
58+
4259
public abstract void awaitQueuePollerStart() throws InterruptedException ;
4360

4461
abstract void signalQueuePollerStarted();
@@ -132,6 +149,14 @@ static class MainQueue extends EventQueue {
132149
private final CountDownLatch pollingStartedLatch = new CountDownLatch(1);
133150
private final AtomicBoolean pollingStarted = new AtomicBoolean(false);
134151

152+
MainQueue() {
153+
super();
154+
}
155+
156+
MainQueue(int queueSize) {
157+
super(queueSize);
158+
}
159+
135160
EventQueue tap() {
136161
ChildQueue child = new ChildQueue(this);
137162
children.add(child);

server-common/src/main/java/io/a2a/server/tasks/TaskManager.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException {
7676
builder.history(newHistory);
7777
}
7878

79+
// Handle metadata from the event
80+
if (event.getMetadata() != null) {
81+
builder.metadata(event.getMetadata());
82+
}
83+
7984
task = builder.build();
8085
return saveTask(task);
8186
}

server-common/src/main/java/io/a2a/server/tasks/TaskUpdater.java

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import java.util.List;
44
import java.util.Map;
55
import java.util.UUID;
6+
import java.util.concurrent.atomic.AtomicBoolean;
67

78
import io.a2a.server.agentexecution.RequestContext;
89
import io.a2a.server.events.EventQueue;
@@ -18,6 +19,8 @@ public class TaskUpdater {
1819
private final EventQueue eventQueue;
1920
private final String taskId;
2021
private final String contextId;
22+
private final AtomicBoolean terminalStateReached = new AtomicBoolean(false);
23+
private final Object stateLock = new Object();
2124

2225
public TaskUpdater(RequestContext context, EventQueue eventQueue) {
2326
this.eventQueue = eventQueue;
@@ -26,20 +29,41 @@ public TaskUpdater(RequestContext context, EventQueue eventQueue) {
2629
}
2730

2831
private void updateStatus(TaskState taskState) {
29-
updateStatus(taskState, null);
32+
updateStatus(taskState, null, taskState.isFinal());
3033
}
3134

32-
private void updateStatus(TaskState state, Message message) {
33-
TaskStatusUpdateEvent event = new TaskStatusUpdateEvent.Builder()
34-
.taskId(taskId)
35-
.contextId(contextId)
36-
.isFinal(state.isFinal())
37-
.status(new TaskStatus(state, message, null))
38-
.build();
39-
eventQueue.enqueueEvent(event);
35+
private void updateStatus(TaskState taskState, Message message) {
36+
updateStatus(taskState, message, taskState.isFinal());
37+
}
38+
39+
private void updateStatus(TaskState state, Message message, boolean isFinal) {
40+
synchronized (stateLock) {
41+
// Check if we're already in a terminal state
42+
if (terminalStateReached.get()) {
43+
throw new IllegalStateException("Cannot update task status - terminal state already reached");
44+
}
45+
46+
// If this is a final state, set the flag
47+
if (isFinal) {
48+
terminalStateReached.set(true);
49+
}
50+
51+
TaskStatusUpdateEvent event = new TaskStatusUpdateEvent.Builder()
52+
.taskId(taskId)
53+
.contextId(contextId)
54+
.isFinal(isFinal)
55+
.status(new TaskStatus(state, message, null))
56+
.build();
57+
eventQueue.enqueueEvent(event);
58+
}
4059
}
4160

4261
public void addArtifact(List<Part<?>> parts, String artifactId, String name, Map<String, Object> metadata) {
62+
addArtifact(parts, artifactId, name, metadata, null, null);
63+
}
64+
65+
public void addArtifact(List<Part<?>> parts, String artifactId, String name, Map<String, Object> metadata,
66+
Boolean append, Boolean lastChunk) {
4367
if (artifactId == null) {
4468
artifactId = UUID.randomUUID().toString();
4569
}
@@ -54,6 +78,8 @@ public void addArtifact(List<Part<?>> parts, String artifactId, String name, Map
5478
.metadata(metadata)
5579
.build()
5680
)
81+
.append(append)
82+
.lastChunk(lastChunk)
5783
.build();
5884
eventQueue.enqueueEvent(event);
5985
}
@@ -98,6 +124,46 @@ public void cancel(Message message) {
98124
updateStatus(TaskState.CANCELED, message);
99125
}
100126

127+
public void reject() {
128+
reject(null);
129+
}
130+
131+
public void reject(Message message) {
132+
updateStatus(TaskState.REJECTED, message);
133+
}
134+
135+
public void requiresInput() {
136+
requiresInput(null, false);
137+
}
138+
139+
public void requiresInput(Message message) {
140+
requiresInput(message, false);
141+
}
142+
143+
public void requiresInput(boolean isFinal) {
144+
requiresInput(null, isFinal);
145+
}
146+
147+
public void requiresInput(Message message, boolean isFinal) {
148+
updateStatus(TaskState.INPUT_REQUIRED, message, isFinal);
149+
}
150+
151+
public void requiresAuth() {
152+
requiresAuth(null, false);
153+
}
154+
155+
public void requiresAuth(Message message) {
156+
requiresAuth(message, false);
157+
}
158+
159+
public void requiresAuth(boolean isFinal) {
160+
requiresAuth(null, isFinal);
161+
}
162+
163+
public void requiresAuth(Message message, boolean isFinal) {
164+
updateStatus(TaskState.AUTH_REQUIRED, message, isFinal);
165+
}
166+
101167
public Message newAgentMessage(List<Part<?>> parts, Map<String, Object> metadata) {
102168
return new Message.Builder()
103169
.role(Message.Role.AGENT)

0 commit comments

Comments
 (0)