Skip to content

Commit 7f8a6ba

Browse files
committed
feat: raise error for tasks in terminal states
1 parent b88ca85 commit 7f8a6ba

File tree

2 files changed

+174
-1
lines changed

2 files changed

+174
-1
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from a2a.types import (
3030
GetTaskPushNotificationConfigParams,
3131
InternalError,
32+
InvalidParamsError,
3233
Message,
3334
MessageSendConfiguration,
3435
MessageSendParams,
@@ -38,6 +39,7 @@
3839
TaskNotFoundError,
3940
TaskPushNotificationConfig,
4041
TaskQueryParams,
42+
TaskState,
4143
UnsupportedOperationError,
4244
)
4345
from a2a.utils.errors import ServerError
@@ -178,6 +180,18 @@ async def on_message_send(
178180
)
179181
task: Task | None = await task_manager.get_task()
180182
if task:
183+
if task.status.state in {
184+
TaskState.completed,
185+
TaskState.canceled,
186+
TaskState.failed,
187+
TaskState.rejected,
188+
}:
189+
raise ServerError(
190+
error=InvalidParamsError(
191+
message=f'Task {task.id} is in terminal state: {task.status.state}'
192+
)
193+
)
194+
181195
task = task_manager.update_with_message(params.message, task)
182196
if self.should_add_push_info(params):
183197
assert isinstance(self._push_notifier, PushNotifier)
@@ -264,8 +278,19 @@ async def on_message_send_stream(
264278
task: Task | None = await task_manager.get_task()
265279

266280
if task:
267-
task = task_manager.update_with_message(params.message, task)
281+
if task.status.state in {
282+
TaskState.completed,
283+
TaskState.canceled,
284+
TaskState.failed,
285+
TaskState.rejected,
286+
}:
287+
raise ServerError(
288+
error=InvalidParamsError(
289+
message=f'Task {task.id} is in terminal state: {task.status.state}'
290+
)
291+
)
268292

293+
task = task_manager.update_with_message(params.message, task)
269294
if self.should_add_push_info(params):
270295
assert isinstance(self._push_notifier, PushNotifier)
271296
assert isinstance(
@@ -413,6 +438,18 @@ async def on_resubscribe_to_task(
413438
if not task:
414439
raise ServerError(error=TaskNotFoundError())
415440

441+
if task.status.state in {
442+
TaskState.completed,
443+
TaskState.canceled,
444+
TaskState.failed,
445+
TaskState.rejected,
446+
}:
447+
raise ServerError(
448+
error=InvalidParamsError(
449+
message=f'Task {task.id} is in terminal state: {task.status.state}'
450+
)
451+
)
452+
416453
task_manager = TaskManager(
417454
task_id=task.id,
418455
context_id=task.contextId,

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
)
2929
from a2a.types import (
3030
InternalError,
31+
InvalidParamsError,
3132
Message,
3233
MessageSendConfiguration,
3334
MessageSendParams,
@@ -1137,3 +1138,138 @@ async def consume_stream():
11371138

11381139
texts = [p.root.text for e in events for p in e.status.message.parts]
11391140
assert texts == ['Event 0', 'Event 1', 'Event 2']
1141+
1142+
1143+
TERMINAL_STATES = [
1144+
TaskState.completed,
1145+
TaskState.canceled,
1146+
TaskState.failed,
1147+
TaskState.rejected,
1148+
]
1149+
1150+
1151+
@pytest.mark.asyncio
1152+
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1153+
async def test_on_message_send_task_in_terminal_state(terminal_state):
1154+
"""Test on_message_send when task is already in a terminal state."""
1155+
task_id = f'terminal_task_{terminal_state.value}'
1156+
terminal_task = create_sample_task(
1157+
task_id=task_id, status_state=terminal_state
1158+
)
1159+
1160+
mock_task_store = AsyncMock(spec=TaskStore)
1161+
# The get method of TaskManager calls task_store.get.
1162+
# We mock TaskManager.get_task which is an async method.
1163+
# So we should patch that instead.
1164+
1165+
request_handler = DefaultRequestHandler(
1166+
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
1167+
)
1168+
1169+
params = MessageSendParams(
1170+
message=Message(
1171+
role=Role.user,
1172+
messageId='msg_terminal',
1173+
parts=[],
1174+
taskId=task_id,
1175+
)
1176+
)
1177+
1178+
from a2a.utils.errors import ServerError
1179+
1180+
# Patch the TaskManager's get_task method to return our terminal task
1181+
with patch(
1182+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
1183+
return_value=terminal_task,
1184+
):
1185+
with pytest.raises(ServerError) as exc_info:
1186+
await request_handler.on_message_send(
1187+
params, create_server_call_context()
1188+
)
1189+
1190+
assert isinstance(exc_info.value.error, InvalidParamsError)
1191+
assert exc_info.value.error.message
1192+
assert (
1193+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1194+
in exc_info.value.error.message
1195+
)
1196+
1197+
1198+
@pytest.mark.asyncio
1199+
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1200+
async def test_on_message_send_stream_task_in_terminal_state(terminal_state):
1201+
"""Test on_message_send_stream when task is already in a terminal state."""
1202+
task_id = f'terminal_stream_task_{terminal_state.value}'
1203+
terminal_task = create_sample_task(
1204+
task_id=task_id, status_state=terminal_state
1205+
)
1206+
1207+
mock_task_store = AsyncMock(spec=TaskStore)
1208+
1209+
request_handler = DefaultRequestHandler(
1210+
agent_executor=DummyAgentExecutor(), task_store=mock_task_store
1211+
)
1212+
1213+
params = MessageSendParams(
1214+
message=Message(
1215+
role=Role.user,
1216+
messageId='msg_terminal_stream',
1217+
parts=[],
1218+
taskId=task_id,
1219+
)
1220+
)
1221+
1222+
from a2a.utils.errors import ServerError
1223+
1224+
with patch(
1225+
'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
1226+
return_value=terminal_task,
1227+
):
1228+
with pytest.raises(ServerError) as exc_info:
1229+
async for _ in request_handler.on_message_send_stream(
1230+
params, create_server_call_context()
1231+
):
1232+
pass # pragma: no cover
1233+
1234+
assert isinstance(exc_info.value.error, InvalidParamsError)
1235+
assert exc_info.value.error.message
1236+
assert (
1237+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1238+
in exc_info.value.error.message
1239+
)
1240+
1241+
1242+
@pytest.mark.asyncio
1243+
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1244+
async def test_on_resubscribe_to_task_in_terminal_state(terminal_state):
1245+
"""Test on_resubscribe_to_task when task is in a terminal state."""
1246+
task_id = f'resub_terminal_task_{terminal_state.value}'
1247+
terminal_task = create_sample_task(
1248+
task_id=task_id, status_state=terminal_state
1249+
)
1250+
1251+
mock_task_store = AsyncMock(spec=TaskStore)
1252+
mock_task_store.get.return_value = terminal_task
1253+
1254+
request_handler = DefaultRequestHandler(
1255+
agent_executor=DummyAgentExecutor(),
1256+
task_store=mock_task_store,
1257+
queue_manager=AsyncMock(spec=QueueManager),
1258+
)
1259+
params = TaskIdParams(id=task_id)
1260+
1261+
from a2a.utils.errors import ServerError
1262+
1263+
with pytest.raises(ServerError) as exc_info:
1264+
async for _ in request_handler.on_resubscribe_to_task(
1265+
params, create_server_call_context()
1266+
):
1267+
pass # pragma: no cover
1268+
1269+
assert isinstance(exc_info.value.error, InvalidParamsError)
1270+
assert exc_info.value.error.message
1271+
assert (
1272+
f'Task {task_id} is in terminal state: {terminal_state.value}'
1273+
in exc_info.value.error.message
1274+
)
1275+
mock_task_store.get.assert_awaited_once_with(task_id)

0 commit comments

Comments
 (0)