Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.a2a.server.agentexecution;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -31,12 +32,12 @@ public RequestContext(MessageSendParams params, String taskId, String contextId,

// if the taskId and contextId were specified, they must match the params
if (params != null) {
if (taskId != null && ! params.message().getTaskId().equals(taskId)) {
if (taskId != null && !taskId.equals(params.message().getTaskId())) {
throw new InvalidParamsError("bad task id");
} else {
checkOrGenerateTaskId();
}
if (contextId != null && ! params.message().getContextId().equals(contextId)) {
if (contextId != null && !contextId.equals(params.message().getContextId())) {
throw new InvalidParamsError("bad context id");
} else {
checkOrGenerateContextId();
Expand All @@ -61,7 +62,7 @@ public Task getTask() {
}

public List<Task> getRelatedTasks() {
return relatedTasks;
return Collections.unmodifiableList(relatedTasks);
}

public Message getMessage() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

public abstract class EnhancedRunnable implements Runnable {
private volatile Throwable error;
private final List<DoneCallback> doneCallbacks = new ArrayList<>();
private final List<DoneCallback> doneCallbacks = new CopyOnWriteArrayList<>();

public Throwable getError() {
return error;
Expand All @@ -16,16 +17,12 @@ public void setError(Throwable error) {
}

public void addDoneCallback(DoneCallback doneCallback) {
synchronized (doneCallbacks) {
doneCallbacks.add(doneCallback);
}
doneCallbacks.add(doneCallback);
}

public void invokeDoneCallbacks() {
synchronized (doneCallbacks) {
for (DoneCallback doneCallback : doneCallbacks) {
doneCallback.done(this);
}
for (DoneCallback doneCallback : doneCallbacks) {
doneCallback.done(this);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ public Flow.Publisher<Event> consumeAll() {
completed = true;
tube.complete();
return;
} catch (Exception e) {
// Continue polling until there is a final event
continue;
} catch (Throwable t) {
tube.fail(t);
return;
}

boolean isFinalEvent = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

Expand All @@ -13,13 +15,17 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class EventQueue {
public abstract class EventQueue implements AutoCloseable {

private static final Logger log = new TempLoggerWrapper(LoggerFactory.getLogger(EventQueue.class));

// TODO decide on a capacity
private static final int queueSize = 1000;

private final EventQueue parent;
// TODO decide on a capacity (or more appropriate queue data structures)
private final BlockingQueue<Event> queue = new ArrayBlockingQueue<Event>(1000);

private final BlockingQueue<Event> queue = new LinkedBlockingDeque<>();
private final Semaphore semaphore = new Semaphore(queueSize, true);
private volatile boolean closed = false;


Expand Down Expand Up @@ -47,6 +53,12 @@ public void enqueueEvent(Event event) {
return;
}
// Call toString() since for errors we don't really want the full stacktrace
try {
semaphore.acquire();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Unable to acquire the semaphore to enqueue the event", e);
}
queue.add(event);
log.debug("Enqueued event {} {}", event instanceof Throwable ? event.toString() : event, this);
}
Expand All @@ -64,6 +76,7 @@ public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException
if (event != null) {
// Call toString() since for errors we don't really want the full stacktrace
log.debug("Dequeued event (no wait) {} {}", this, event instanceof Throwable ? event.toString() : event);
semaphore.release();
}
return event;
}
Expand All @@ -72,6 +85,7 @@ public Event dequeueEvent(int waitMilliSeconds) throws EventQueueClosedException
if (event != null) {
// Call toString() since for errors we don't really want the full stacktrace
log.debug("Dequeued event (waiting) {} {}", this, event instanceof Throwable ? event.toString() : event);
semaphore.release();
}
return event;
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import jakarta.enterprise.context.ApplicationScoped;

@ApplicationScoped
public class InMemoryQueueManager implements QueueManager {
private final Map<String, EventQueue> queues = Collections.synchronizedMap(new HashMap<>());
private final ConcurrentMap<String, EventQueue> queues = new ConcurrentHashMap<>();

@Override
public void add(String taskId, EventQueue queue) {
synchronized (queues) {
if (queues.containsKey(taskId)) {
throw new TaskQueueExistsException();
}
queues.put(taskId, queue);
EventQueue existing = queues.putIfAbsent(taskId, queue);
if (existing != null) {
throw new TaskQueueExistsException();
}
}

Expand All @@ -27,36 +27,29 @@ public EventQueue get(String taskId) {

@Override
public EventQueue tap(String taskId) {
synchronized (taskId) {
EventQueue queue = queues.get(taskId);
if (queue == null) {
return queue;
}
return queue.tap();
}
EventQueue queue = queues.get(taskId);
return queue == null ? null : queue.tap();
}

@Override
public void close(String taskId) {
synchronized (queues) {
EventQueue existing = queues.remove(taskId);
if (existing == null) {
throw new NoTaskQueueException();
}
EventQueue existing = queues.remove(taskId);
if (existing == null) {
throw new NoTaskQueueException();
}
}

@Override
public EventQueue createOrTap(String taskId) {
synchronized (queues) {
EventQueue queue = queues.get(taskId);
if (queue != null) {
return queue.tap();
}
queue = EventQueue.create();
queues.put(taskId, queue);
return queue;

EventQueue existing = queues.get(taskId);
EventQueue newQueue = null;
if (existing == null) {
newQueue = EventQueue.create();
// Make sure an existing queue has not been added in the meantime
existing = queues.putIfAbsent(taskId, newQueue);
}
return existing == null ? newQueue : existing.tap();
}

@Override
Expand Down
Loading