Skip to content

Commit 6e3451d

Browse files
committed
define terminal states in a constant
1 parent 7f8a6ba commit 6e3451d

File tree

2 files changed

+14
-25
lines changed

2 files changed

+14
-25
lines changed

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@
4848

4949
logger = logging.getLogger(__name__)
5050

51+
TERMINAL_TASK_STATES = {
52+
TaskState.completed,
53+
TaskState.canceled,
54+
TaskState.failed,
55+
TaskState.rejected,
56+
}
5157

5258
@trace_class(kind=SpanKind.SERVER)
5359
class DefaultRequestHandler(RequestHandler):
@@ -180,12 +186,7 @@ async def on_message_send(
180186
)
181187
task: Task | None = await task_manager.get_task()
182188
if task:
183-
if task.status.state in {
184-
TaskState.completed,
185-
TaskState.canceled,
186-
TaskState.failed,
187-
TaskState.rejected,
188-
}:
189+
if task.status.state in TERMINAL_TASK_STATES:
189190
raise ServerError(
190191
error=InvalidParamsError(
191192
message=f'Task {task.id} is in terminal state: {task.status.state}'
@@ -278,12 +279,7 @@ async def on_message_send_stream(
278279
task: Task | None = await task_manager.get_task()
279280

280281
if task:
281-
if task.status.state in {
282-
TaskState.completed,
283-
TaskState.canceled,
284-
TaskState.failed,
285-
TaskState.rejected,
286-
}:
282+
if task.status.state in TERMINAL_TASK_STATES:
287283
raise ServerError(
288284
error=InvalidParamsError(
289285
message=f'Task {task.id} is in terminal state: {task.status.state}'
@@ -438,12 +434,7 @@ async def on_resubscribe_to_task(
438434
if not task:
439435
raise ServerError(error=TaskNotFoundError())
440436

441-
if task.status.state in {
442-
TaskState.completed,
443-
TaskState.canceled,
444-
TaskState.failed,
445-
TaskState.rejected,
446-
}:
437+
if task.status.state in TERMINAL_TASK_STATES:
447438
raise ServerError(
448439
error=InvalidParamsError(
449440
message=f'Task {task.id} is in terminal state: {task.status.state}'

tests/server/request_handlers/test_default_request_handler.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1139,17 +1139,15 @@ async def consume_stream():
11391139
texts = [p.root.text for e in events for p in e.status.message.parts]
11401140
assert texts == ['Event 0', 'Event 1', 'Event 2']
11411141

1142-
1143-
TERMINAL_STATES = [
1142+
TERMINAL_TASK_STATES = {
11441143
TaskState.completed,
11451144
TaskState.canceled,
11461145
TaskState.failed,
11471146
TaskState.rejected,
1148-
]
1149-
1147+
}
11501148

11511149
@pytest.mark.asyncio
1152-
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1150+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
11531151
async def test_on_message_send_task_in_terminal_state(terminal_state):
11541152
"""Test on_message_send when task is already in a terminal state."""
11551153
task_id = f'terminal_task_{terminal_state.value}'
@@ -1196,7 +1194,7 @@ async def test_on_message_send_task_in_terminal_state(terminal_state):
11961194

11971195

11981196
@pytest.mark.asyncio
1199-
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1197+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
12001198
async def test_on_message_send_stream_task_in_terminal_state(terminal_state):
12011199
"""Test on_message_send_stream when task is already in a terminal state."""
12021200
task_id = f'terminal_stream_task_{terminal_state.value}'
@@ -1240,7 +1238,7 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state):
12401238

12411239

12421240
@pytest.mark.asyncio
1243-
@pytest.mark.parametrize('terminal_state', TERMINAL_STATES)
1241+
@pytest.mark.parametrize('terminal_state', TERMINAL_TASK_STATES)
12441242
async def test_on_resubscribe_to_task_in_terminal_state(terminal_state):
12451243
"""Test on_resubscribe_to_task when task is in a terminal state."""
12461244
task_id = f'resub_terminal_task_{terminal_state.value}'

0 commit comments

Comments
 (0)