Skip to content

Commit 0e033fc

Browse files
committed
Tighten up TaskManager
1 parent 70d2973 commit 0e033fc

File tree

2 files changed

+31
-15
lines changed

2 files changed

+31
-15
lines changed

src/main/java/io/a2a/server/tasks/TaskManager.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,19 @@ public Task getTask() {
4444
if (taskId == null) {
4545
return null;
4646
}
47-
return taskStore.get(taskId);
47+
if (currentTask != null) {
48+
return currentTask;
49+
}
50+
currentTask = taskStore.get(taskId);
51+
return currentTask;
4852
}
4953

50-
public void saveTaskEvent(Task task) throws A2AServerException {
54+
Task saveTaskEvent(Task task) throws A2AServerException {
5155
checkIdsAndUpdateIfNecessary(task.getId(), task.getContextId());
52-
saveTask(task);
56+
return saveTask(task);
5357
}
5458

55-
public void saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException {
59+
Task saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException {
5660
checkIdsAndUpdateIfNecessary(event.getTaskId(), event.getContextId());
5761
Task task = ensureTask(event.getTaskId(), event.getContextId());
5862

@@ -67,10 +71,10 @@ public void saveTaskEvent(TaskStatusUpdateEvent event) throws A2AServerException
6771
}
6872

6973
task = builder.build();
70-
saveTask(task);
74+
return saveTask(task);
7175
}
7276

73-
public void saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException {
77+
Task saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerException {
7478
checkIdsAndUpdateIfNecessary(event.getTaskId(), event.getContextId());
7579
Task task = ensureTask(event.getTaskId(), event.getContextId());
7680

@@ -122,7 +126,7 @@ public void saveTaskEvent(TaskArtifactUpdateEvent event) throws A2AServerExcepti
122126
.artifacts(artifacts)
123127
.build();
124128

125-
saveTask(task);
129+
return saveTask(task);
126130
}
127131

128132
public Event process(Event event) throws A2AServerException {
@@ -145,7 +149,7 @@ public Task updateWithMessage(Message message, Task task) {
145149
task = new Task.Builder(task)
146150
.history(history)
147151
.build();
148-
currentTask = task;
152+
saveTask(task);
149153
return task;
150154
}
151155

@@ -164,7 +168,11 @@ private void checkIdsAndUpdateIfNecessary(String eventTaskId, String eventContex
164168
}
165169

166170
private Task ensureTask(String eventTaskId, String eventContextId) {
167-
Task task = taskStore.get(taskId);
171+
Task task = currentTask;
172+
if (task != null) {
173+
return task;
174+
}
175+
task = taskStore.get(taskId);
168176
if (task == null) {
169177
task = createTask(eventTaskId, eventContextId);
170178
saveTask(task);
@@ -182,11 +190,13 @@ private Task createTask(String taskId, String contextId) {
182190
.build();
183191
}
184192

185-
private void saveTask(Task task) {
193+
private Task saveTask(Task task) {
186194
taskStore.save(task);
187195
if (taskId == null) {
188196
taskId = task.getId();
189197
contextId = task.getContextId();
190198
}
199+
currentTask = task;
200+
return currentTask;
191201
}
192202
}

src/test/java/io/a2a/server/tasks/TaskManagerTest.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,10 @@ public void testGetTaskNonExistent() {
5757

5858
@Test
5959
public void testSaveTaskEventNewTask() throws A2AServerException {
60-
taskManager.saveTaskEvent(minimalTask);
60+
Task saved = taskManager.saveTaskEvent(minimalTask);
6161
Task retrieved = taskManager.getTask();
6262
assertSame(minimalTask, retrieved);
63+
assertSame(retrieved, saved);
6364
}
6465

6566
@Test
@@ -83,10 +84,11 @@ public void testSaveTaskEventStatusUpdate() throws A2AServerException {
8384
new HashMap<>());
8485

8586

86-
taskManager.saveTaskEvent(event);
87+
Task saved = taskManager.saveTaskEvent(event);
8788
Task updated = taskManager.getTask();
8889

8990
assertNotSame(initialTask, updated);
91+
assertSame(updated, saved);
9092

9193
assertEquals(initialTask.getId(), updated.getId());
9294
assertEquals(initialTask.getContextId(), updated.getContextId());
@@ -108,9 +110,11 @@ public void testSaveTaskEventArtifactUpdate() throws A2AServerException {
108110
.contextId(minimalTask.getContextId())
109111
.artifact(newArtifact)
110112
.build();
111-
taskManager.saveTaskEvent(event);
113+
Task saved = taskManager.saveTaskEvent(event);
112114

113115
Task updatedTask = taskManager.getTask();
116+
assertSame(updatedTask, saved);
117+
114118
assertNotSame(initialTask, updatedTask);
115119
assertEquals(initialTask.getId(), updatedTask.getId());
116120
assertEquals(initialTask.getContextId(), updatedTask.getContextId());
@@ -136,14 +140,15 @@ public void testEnsureTaskNonExistentForStatusUpdate() throws A2AServerException
136140
.isFinal(false)
137141
.build();
138142

139-
taskManagerWithoutId.saveTaskEvent(event);
143+
Task task = taskManagerWithoutId.saveTaskEvent(event);
140144
assertEquals(event.getTaskId(), taskManagerWithoutId.getTaskId());
141145
assertEquals(event.getContextId(), taskManagerWithoutId.getContextId());
142146

143147
Task newTask = taskManagerWithoutId.getTask();
144148
assertEquals(event.getTaskId(), newTask.getId());
145149
assertEquals(event.getContextId(), newTask.getContextId());
146150
assertEquals(TaskState.SUBMITTED, newTask.getStatus().state());
151+
assertSame(newTask, task);
147152
}
148153

149154
@Test
@@ -155,12 +160,13 @@ public void testSaveTaskEventNewTaskNoTaskId() throws A2AServerException {
155160
.status(new TaskStatus(TaskState.WORKING))
156161
.build();
157162

158-
taskManagerWithoutId.saveTaskEvent(task);
163+
Task saved = taskManagerWithoutId.saveTaskEvent(task);
159164
assertEquals(task.getId(), taskManagerWithoutId.getTaskId());
160165
assertEquals(task.getContextId(), taskManagerWithoutId.getContextId());
161166

162167
Task retrieved = taskManagerWithoutId.getTask();
163168
assertSame(task, retrieved);
169+
assertSame(retrieved, saved);
164170
}
165171

166172
@Test

0 commit comments

Comments
 (0)