Skip to content

Commit d8fa0e6

Browse files
authored
fix: incorporate feedback from #138 (#146)
# Description Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [x] Follow the [`CONTRIBUTING` Guide](../CONTRIBUTING.md). - [x] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [x] Ensure the tests pass - [x] Appropriate READMEs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 33b8e77 commit d8fa0e6

File tree

7 files changed

+105
-117
lines changed

7 files changed

+105
-117
lines changed

sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.a2a.server.agentexecution;
22

33
import java.util.ArrayList;
4+
import java.util.Collections;
45
import java.util.List;
56
import java.util.UUID;
67
import java.util.stream.Collectors;
@@ -31,12 +32,12 @@ public RequestContext(MessageSendParams params, String taskId, String contextId,
3132

3233
// if the taskId and contextId were specified, they must match the params
3334
if (params != null) {
34-
if (taskId != null && ! params.message().getTaskId().equals(taskId)) {
35+
if (taskId != null && !taskId.equals(params.message().getTaskId())) {
3536
throw new InvalidParamsError("bad task id");
3637
} else {
3738
checkOrGenerateTaskId();
3839
}
39-
if (contextId != null && ! params.message().getContextId().equals(contextId)) {
40+
if (contextId != null && !contextId.equals(params.message().getContextId())) {
4041
throw new InvalidParamsError("bad context id");
4142
} else {
4243
checkOrGenerateContextId();
@@ -61,7 +62,7 @@ public Task getTask() {
6162
}
6263

6364
public List<Task> getRelatedTasks() {
64-
return relatedTasks;
65+
return Collections.unmodifiableList(relatedTasks);
6566
}
6667

6768
public Message getMessage() {

sdk-server-common/src/main/java/io/a2a/server/events/EnhancedRunnable.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import java.util.ArrayList;
44
import java.util.List;
5+
import java.util.concurrent.CopyOnWriteArrayList;
56

67
public abstract class EnhancedRunnable implements Runnable {
78
private volatile Throwable error;
8-
private final List<DoneCallback> doneCallbacks = new ArrayList<>();
9+
private final List<DoneCallback> doneCallbacks = new CopyOnWriteArrayList<>();
910

1011
public Throwable getError() {
1112
return error;
@@ -16,16 +17,12 @@ public void setError(Throwable error) {
1617
}
1718

1819
public void addDoneCallback(DoneCallback doneCallback) {
19-
synchronized (doneCallbacks) {
20-
doneCallbacks.add(doneCallback);
21-
}
20+
doneCallbacks.add(doneCallback);
2221
}
2322

2423
public void invokeDoneCallbacks() {
25-
synchronized (doneCallbacks) {
26-
for (DoneCallback doneCallback : doneCallbacks) {
27-
doneCallback.done(this);
28-
}
24+
for (DoneCallback doneCallback : doneCallbacks) {
25+
doneCallback.done(this);
2926
}
3027
}
3128

sdk-server-common/src/main/java/io/a2a/server/events/EventConsumer.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ public Flow.Publisher<Event> consumeAll() {
6767
completed = true;
6868
tube.complete();
6969
return;
70-
} catch (Exception e) {
71-
// Continue polling until there is a final event
72-
continue;
70+
} catch (Throwable t) {
71+
tube.fail(t);
72+
return;
7373
}
7474

7575
boolean isFinalEvent = false;

sdk-server-common/src/main/java/io/a2a/server/events/EventQueue.java

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import java.util.concurrent.BlockingQueue;
66
import java.util.concurrent.CopyOnWriteArrayList;
77
import java.util.concurrent.CountDownLatch;
8+
import java.util.concurrent.LinkedBlockingDeque;
9+
import java.util.concurrent.Semaphore;
810
import java.util.concurrent.TimeUnit;
911
import java.util.concurrent.atomic.AtomicBoolean;
1012

@@ -13,13 +15,17 @@
1315
import org.slf4j.Logger;
1416
import org.slf4j.LoggerFactory;
1517

16-
public abstract class EventQueue {
18+
public abstract class EventQueue implements AutoCloseable {
1719

1820
private static final Logger log = new TempLoggerWrapper(LoggerFactory.getLogger(EventQueue.class));
1921

22+
// TODO decide on a capacity
23+
private static final int queueSize = 1000;
24+
2025
private final EventQueue parent;
21-
// TODO decide on a capacity (or more appropriate queue data structures)
22-
private final BlockingQueue<Event> queue = new ArrayBlockingQueue<Event>(1000);
26+
27+
private final BlockingQueue<Event> queue = new LinkedBlockingDeque<>();
28+
private final Semaphore semaphore = new Semaphore(queueSize, true);
2329
private volatile boolean closed = false;
2430

2531

@@ -47,6 +53,12 @@ public void enqueueEvent(Event event) {
4753
return;
4854
}
4955
// Call toString() since for errors we don't really want the full stacktrace
56+
try {
57+
semaphore.acquire();
58+
} catch (InterruptedException e) {
59+
Thread.currentThread().interrupt();
60+
throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e);
61+
}
5062
queue.add(event);
5163
log.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this);
5264
}
@@ -64,6 +76,7 @@ public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException
6476
if (event != null) {
6577
// Call toString() since for errors we don't really want the full stacktrace
6678
log.debug("Dequeued event (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event);
79+
semaphore.release();
6780
}
6881
return event;
6982
}
@@ -72,6 +85,7 @@ public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException
7285
if (event != null) {
7386
// Call toString() since for errors we don't really want the full stacktrace
7487
log.debug("Dequeued event (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event);
88+
semaphore.release();
7589
}
7690
return event;
7791
} catch (InterruptedException e) {

sdk-server-common/src/main/java/io/a2a/server/events/InMemoryQueueManager.java

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
import java.util.Collections;
44
import java.util.HashMap;
55
import java.util.Map;
6+
import java.util.concurrent.ConcurrentHashMap;
7+
import java.util.concurrent.ConcurrentMap;
68

79
import jakarta.enterprise.context.ApplicationScoped;
810

911
@ApplicationScoped
1012
public class InMemoryQueueManager implements QueueManager {
11-
private final Map<String, EventQueue> queues = Collections.synchronizedMap(new HashMap<>());
13+
private final ConcurrentMap<String, EventQueue> queues = new ConcurrentHashMap<>();
1214

1315
@Override
1416
public void add(String taskId, EventQueue queue) {
15-
synchronized (queues) {
16-
if (queues.containsKey(taskId)) {
17-
throw new TaskQueueExistsException();
18-
}
19-
queues.put(taskId, queue);
17+
EventQueue existing = queues.putIfAbsent(taskId, queue);
18+
if (existing != null) {
19+
throw new TaskQueueExistsException();
2020
}
2121
}
2222

@@ -27,36 +27,29 @@ public EventQueue get(String taskId) {
2727

2828
@Override
2929
public EventQueue tap(String taskId) {
30-
synchronized (taskId) {
31-
EventQueue queue = queues.get(taskId);
32-
if (queue == null) {
33-
return queue;
34-
}
35-
return queue.tap();
36-
}
30+
EventQueue queue = queues.get(taskId);
31+
return queue == null ? null : queue.tap();
3732
}
3833

3934
@Override
4035
public void close(String taskId) {
41-
synchronized (queues) {
42-
EventQueue existing = queues.remove(taskId);
43-
if (existing == null) {
44-
throw new NoTaskQueueException();
45-
}
36+
EventQueue existing = queues.remove(taskId);
37+
if (existing == null) {
38+
throw new NoTaskQueueException();
4639
}
4740
}
4841

4942
@Override
5043
public EventQueue createOrTap(String taskId) {
51-
synchronized (queues) {
52-
EventQueue queue = queues.get(taskId);
53-
if (queue != null) {
54-
return queue.tap();
55-
}
56-
queue = EventQueue.create();
57-
queues.put(taskId, queue);
58-
return queue;
44+
45+
EventQueue existing = queues.get(taskId);
46+
EventQueue newQueue = null;
47+
if (existing == null) {
48+
newQueue = EventQueue.create();
49+
// Make sure an existing queue has not been added in the meantime
50+
existing = queues.putIfAbsent(taskId, newQueue);
5951
}
52+
return existing == null ? newQueue : existing.tap();
6053
}
6154

6255
@Override

0 commit comments

Comments
 (0)