diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 60272367..22556cde 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -28,22 +28,10 @@ def new_task(request: Message) -> Task: if isinstance(part.root, TextPart) and not part.root.text: raise ValueError('TextPart content cannot be empty') - context_id_str = request.context_id - if context_id_str is not None: - try: - uuid.UUID(context_id_str) - context_id = context_id_str - except (ValueError, AttributeError, TypeError) as e: - raise ValueError( - f"Invalid context_id: '{context_id_str}' is not a valid UUID." - ) from e - else: - context_id = str(uuid.uuid4()) - return Task( status=TaskStatus(state=TaskState.submitted), - id=(request.task_id if request.task_id else str(uuid.uuid4())), - context_id=context_id, + id=request.task_id or str(uuid.uuid4()), + context_id=request.context_id or str(uuid.uuid4()), history=[request], ) diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 77441316..cb3dc386 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -188,24 +188,6 @@ def test_completed_task_invalid_artifact_type(self): history=[], ) - def test_new_task_with_invalid_context_id(self): - """Test that new_task raises a ValueError for various invalid context_id formats.""" - invalid_ids = ['not-a-uuid', ''] - for invalid_id in invalid_ids: - with self.subTest(invalid_id=invalid_id): - with pytest.raises( - ValueError, - match=f"Invalid context_id: '{invalid_id}' is not a valid UUID.", - ): - new_task( - Message( - role=Role.user, - parts=[Part(root=TextPart(text='test message'))], - message_id=str(uuid.uuid4()), - context_id=invalid_id, - ) - ) - if __name__ == '__main__': unittest.main()