Skip to content

Commit 125e111

Browse files
authored
test: Added unit tests for RequestContext (#127)
This PR includes the following changes: 1. Added a unit test class for the `RequestContext` class, and all unit tests have been verified to pass. This is based on the unit test class from the `a2a` Python sdk: `tests/server/agent_execution/test_context.py`, with a few tests that are not applicable to the Java version removed. 2. Based on the unit test class, an issue in the `RequestContext` class was identified and fixed: the `relatedTasks` setting in the constructor was not working properly. Fixes #60 --------- Signed-off-by: Sun Yuhan <[email protected]> Co-authored-by: Sun Yuhan <[email protected]>
1 parent aede3ff commit 125e111

File tree

2 files changed

+249
-3
lines changed

2 files changed

+249
-3
lines changed

sdk-server-common/src/main/java/io/a2a/server/agentexecution/RequestContext.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,7 @@ public RequestContext(MessageSendParams params, String taskId, String contextId,
2727
this.taskId = taskId;
2828
this.contextId = contextId;
2929
this.task = task;
30-
if (relatedTasks == null) {
31-
this.relatedTasks = new ArrayList<>();
32-
}
30+
this.relatedTasks = relatedTasks == null ? new ArrayList<>() : relatedTasks;
3331

3432
// if the taskId and contextId were specified, they must match the params
3533
if (params != null) {
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
package io.a2a.server.agentexecution;
2+
3+
import io.a2a.spec.InvalidParamsError;
4+
import io.a2a.spec.Message;
5+
import io.a2a.spec.MessageSendParams;
6+
import io.a2a.spec.Task;
7+
import io.a2a.spec.TaskStatus;
8+
import io.a2a.spec.TaskState;
9+
import io.a2a.spec.TextPart;
10+
import org.junit.jupiter.api.Test;
11+
import org.mockito.MockedStatic;
12+
13+
import java.util.ArrayList;
14+
import java.util.List;
15+
import java.util.UUID;
16+
17+
import static org.junit.jupiter.api.Assertions.assertEquals;
18+
import static org.junit.jupiter.api.Assertions.assertNull;
19+
import static org.junit.jupiter.api.Assertions.assertThrows;
20+
import static org.junit.jupiter.api.Assertions.assertTrue;
21+
import static org.mockito.Mockito.mock;
22+
import static org.mockito.Mockito.mockStatic;
23+
24+
public class RequestContextTest {
25+
26+
@Test
27+
public void testInitWithoutParams() {
28+
RequestContext context = new RequestContext(null, null, null, null, null);
29+
assertNull(context.getMessage());
30+
assertNull(context.getTaskId());
31+
assertNull(context.getContextId());
32+
assertNull(context.getTask());
33+
assertTrue(context.getRelatedTasks().isEmpty());
34+
}
35+
36+
@Test
37+
public void testInitWithParamsNoIds() {
38+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build();
39+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
40+
41+
UUID taskId = UUID.fromString("00000000-0000-0000-0000-000000000001");
42+
UUID contextId = UUID.fromString("00000000-0000-0000-0000-000000000002");
43+
44+
try (MockedStatic<UUID> mockedUUID = mockStatic(UUID.class)) {
45+
mockedUUID.when(UUID::randomUUID)
46+
.thenReturn(taskId)
47+
.thenReturn(contextId);
48+
49+
RequestContext context = new RequestContext(mockParams, null, null, null, null);
50+
51+
assertEquals(mockParams.message(), context.getMessage());
52+
assertEquals(taskId.toString(), context.getTaskId());
53+
assertEquals(mockParams.message().getTaskId(), taskId.toString());
54+
assertEquals(contextId.toString(), context.getContextId());
55+
assertEquals(mockParams.message().getContextId(), contextId.toString());
56+
}
57+
}
58+
59+
@Test
60+
public void testInitWithTaskId() {
61+
String taskId = "task-123";
62+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).build();
63+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
64+
65+
RequestContext context = new RequestContext(mockParams, taskId, null, null, null);
66+
67+
assertEquals(taskId, context.getTaskId());
68+
assertEquals(taskId, mockParams.message().getTaskId());
69+
}
70+
71+
@Test
72+
public void testInitWithContextId() {
73+
String contextId = "context-456";
74+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(contextId).build();
75+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
76+
RequestContext context = new RequestContext(mockParams, null, contextId, null, null);
77+
78+
assertEquals(contextId, context.getContextId());
79+
assertEquals(contextId, mockParams.message().getContextId());
80+
}
81+
82+
@Test
83+
public void testInitWithBothIds() {
84+
String taskId = "task-123";
85+
String contextId = "context-456";
86+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(taskId).contextId(contextId).build();
87+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
88+
RequestContext context = new RequestContext(mockParams, taskId, contextId, null, null);
89+
90+
assertEquals(taskId, context.getTaskId());
91+
assertEquals(taskId, mockParams.message().getTaskId());
92+
assertEquals(contextId, context.getContextId());
93+
assertEquals(contextId, mockParams.message().getContextId());
94+
}
95+
96+
@Test
97+
public void testInitWithTask() {
98+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build();
99+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
100+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
101+
102+
RequestContext context = new RequestContext(mockParams, null, null, mockTask, null);
103+
104+
assertEquals(mockTask, context.getTask());
105+
}
106+
107+
@Test
108+
public void testGetUserInputNoParams() {
109+
RequestContext context = new RequestContext(null, null, null, null, null);
110+
assertEquals("", context.getUserInput(null));
111+
}
112+
113+
@Test
114+
public void testAttachRelatedTask() {
115+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
116+
117+
RequestContext context = new RequestContext(null, null, null, null, null);
118+
assertEquals(0, context.getRelatedTasks().size());
119+
120+
context.attachRelatedTask(mockTask);
121+
assertEquals(1, context.getRelatedTasks().size());
122+
assertEquals(mockTask, context.getRelatedTasks().get(0));
123+
124+
Task anotherTask = mock(Task.class);
125+
context.attachRelatedTask(anotherTask);
126+
assertEquals(2, context.getRelatedTasks().size());
127+
assertEquals(anotherTask, context.getRelatedTasks().get(1));
128+
}
129+
130+
@Test
131+
public void testCheckOrGenerateTaskIdWithExistingTaskId() {
132+
String existingId = "existing-task-id";
133+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId(existingId).build();
134+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
135+
136+
RequestContext context = new RequestContext(mockParams, null, null, null, null);
137+
138+
assertEquals(existingId, context.getTaskId());
139+
assertEquals(existingId, mockParams.message().getTaskId());
140+
}
141+
142+
@Test
143+
public void testCheckOrGenerateContextIdWithExistingContextId() {
144+
String existingId = "existing-context-id";
145+
146+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).contextId(existingId).build();
147+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
148+
149+
RequestContext context = new RequestContext(mockParams, null, null, null, null);
150+
151+
assertEquals(existingId, context.getContextId());
152+
assertEquals(existingId, mockParams.message().getContextId());
153+
}
154+
155+
@Test
156+
public void testInitRaisesErrorOnTaskIdMismatch() {
157+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").build();
158+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
159+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
160+
161+
InvalidParamsError error = assertThrows(InvalidParamsError.class, () ->
162+
new RequestContext(mockParams, "wrong-task-id", null, mockTask, null));
163+
164+
assertTrue(error.getMessage().contains("bad task id"));
165+
}
166+
167+
@Test
168+
public void testInitRaisesErrorOnContextIdMismatch() {
169+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build();
170+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
171+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
172+
173+
InvalidParamsError error = assertThrows(InvalidParamsError.class, () ->
174+
new RequestContext(mockParams, mockTask.getId(), "wrong-context-id", mockTask, null));
175+
176+
assertTrue(error.getMessage().contains("bad context id"));
177+
}
178+
179+
@Test
180+
public void testWithRelatedTasksProvided() {
181+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
182+
183+
List<Task> relatedTasks = new ArrayList<>();
184+
relatedTasks.add(mockTask);
185+
relatedTasks.add(mock(Task.class));
186+
187+
RequestContext context = new RequestContext(null, null, null, null, relatedTasks);
188+
189+
assertEquals(relatedTasks, context.getRelatedTasks());
190+
assertEquals(2, context.getRelatedTasks().size());
191+
}
192+
193+
@Test
194+
public void testMessagePropertyWithoutParams() {
195+
RequestContext context = new RequestContext(null, null, null, null, null);
196+
assertNull(context.getMessage());
197+
}
198+
199+
@Test
200+
public void testMessagePropertyWithParams() {
201+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).build();
202+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
203+
204+
RequestContext context = new RequestContext(mockParams, null, null, null, null);
205+
assertEquals(mockParams.message(), context.getMessage());
206+
}
207+
208+
@Test
209+
public void testInitWithExistingIdsInMessage() {
210+
String existingTaskId = "existing-task-id";
211+
String existingContextId = "existing-context-id";
212+
213+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart("")))
214+
.taskId(existingTaskId).contextId(existingContextId).build();
215+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
216+
217+
RequestContext context = new RequestContext(mockParams, null, null, null, null);
218+
219+
assertEquals(existingTaskId, context.getTaskId());
220+
assertEquals(existingContextId, context.getContextId());
221+
}
222+
223+
@Test
224+
public void testInitWithTaskIdAndExistingTaskIdMatch() {
225+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build();
226+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
227+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
228+
229+
230+
RequestContext context = new RequestContext(mockParams, mockTask.getId(), null, mockTask, null);
231+
232+
assertEquals(mockTask.getId(), context.getTaskId());
233+
assertEquals(mockTask, context.getTask());
234+
}
235+
236+
@Test
237+
public void testInitWithContextIdAndExistingContextIdMatch() {
238+
var mockMessage = new Message.Builder().role(Message.Role.USER).parts(List.of(new TextPart(""))).taskId("task-123").contextId("context-456").build();
239+
var mockParams = new MessageSendParams.Builder().message(mockMessage).build();
240+
var mockTask = new Task.Builder().id("task-123").contextId("context-456").status(new TaskStatus(TaskState.COMPLETED)).build();
241+
242+
243+
RequestContext context = new RequestContext(mockParams, mockTask.getId(), mockTask.getContextId(), mockTask, null);
244+
245+
assertEquals(mockTask.getContextId(), context.getContextId());
246+
assertEquals(mockTask, context.getTask());
247+
}
248+
}

0 commit comments

Comments
 (0)