|
| 1 | +from typing import Any |
| 2 | +import pytest |
| 3 | +import httpx |
| 4 | +from uuid import uuid4 |
| 5 | + |
| 6 | +from a2a.client import A2ACardResolver, ClientConfig, ClientFactory |
| 7 | +from a2a.types import Message, Part, Role, TextPart |
| 8 | + |
| 9 | + |
| 10 | +# A2A validation helpers - adapted from https://github.com/a2aproject/a2a-inspector/blob/main/backend/validators.py |
| 11 | + |
| 12 | +def validate_agent_card(card_data: dict[str, Any]) -> list[str]: |
| 13 | + """Validate the structure and fields of an agent card.""" |
| 14 | + errors: list[str] = [] |
| 15 | + |
| 16 | + # Use a frozenset for efficient checking and to indicate immutability. |
| 17 | + required_fields = frozenset( |
| 18 | + [ |
| 19 | + 'name', |
| 20 | + 'description', |
| 21 | + 'url', |
| 22 | + 'version', |
| 23 | + 'capabilities', |
| 24 | + 'defaultInputModes', |
| 25 | + 'defaultOutputModes', |
| 26 | + 'skills', |
| 27 | + ] |
| 28 | + ) |
| 29 | + |
| 30 | + # Check for the presence of all required fields |
| 31 | + for field in required_fields: |
| 32 | + if field not in card_data: |
| 33 | + errors.append(f"Required field is missing: '{field}'.") |
| 34 | + |
| 35 | + # Check if 'url' is an absolute URL (basic check) |
| 36 | + if 'url' in card_data and not ( |
| 37 | + card_data['url'].startswith('http://') |
| 38 | + or card_data['url'].startswith('https://') |
| 39 | + ): |
| 40 | + errors.append( |
| 41 | + "Field 'url' must be an absolute URL starting with http:// or https://." |
| 42 | + ) |
| 43 | + |
| 44 | + # Check if capabilities is a dictionary |
| 45 | + if 'capabilities' in card_data and not isinstance( |
| 46 | + card_data['capabilities'], dict |
| 47 | + ): |
| 48 | + errors.append("Field 'capabilities' must be an object.") |
| 49 | + |
| 50 | + # Check if defaultInputModes and defaultOutputModes are arrays of strings |
| 51 | + for field in ['defaultInputModes', 'defaultOutputModes']: |
| 52 | + if field in card_data: |
| 53 | + if not isinstance(card_data[field], list): |
| 54 | + errors.append(f"Field '{field}' must be an array of strings.") |
| 55 | + elif not all(isinstance(item, str) for item in card_data[field]): |
| 56 | + errors.append(f"All items in '{field}' must be strings.") |
| 57 | + |
| 58 | + # Check skills array |
| 59 | + if 'skills' in card_data: |
| 60 | + if not isinstance(card_data['skills'], list): |
| 61 | + errors.append( |
| 62 | + "Field 'skills' must be an array of AgentSkill objects." |
| 63 | + ) |
| 64 | + elif not card_data['skills']: |
| 65 | + errors.append( |
| 66 | + "Field 'skills' array is empty. Agent must have at least one skill if it performs actions." |
| 67 | + ) |
| 68 | + |
| 69 | + return errors |
| 70 | + |
| 71 | + |
| 72 | +def _validate_task(data: dict[str, Any]) -> list[str]: |
| 73 | + errors = [] |
| 74 | + if 'id' not in data: |
| 75 | + errors.append("Task object missing required field: 'id'.") |
| 76 | + if 'status' not in data or 'state' not in data.get('status', {}): |
| 77 | + errors.append("Task object missing required field: 'status.state'.") |
| 78 | + return errors |
| 79 | + |
| 80 | + |
| 81 | +def _validate_status_update(data: dict[str, Any]) -> list[str]: |
| 82 | + errors = [] |
| 83 | + if 'status' not in data or 'state' not in data.get('status', {}): |
| 84 | + errors.append( |
| 85 | + "StatusUpdate object missing required field: 'status.state'." |
| 86 | + ) |
| 87 | + return errors |
| 88 | + |
| 89 | + |
| 90 | +def _validate_artifact_update(data: dict[str, Any]) -> list[str]: |
| 91 | + errors = [] |
| 92 | + if 'artifact' not in data: |
| 93 | + errors.append( |
| 94 | + "ArtifactUpdate object missing required field: 'artifact'." |
| 95 | + ) |
| 96 | + elif ( |
| 97 | + 'parts' not in data.get('artifact', {}) |
| 98 | + or not isinstance(data.get('artifact', {}).get('parts'), list) |
| 99 | + or not data.get('artifact', {}).get('parts') |
| 100 | + ): |
| 101 | + errors.append("Artifact object must have a non-empty 'parts' array.") |
| 102 | + return errors |
| 103 | + |
| 104 | + |
| 105 | +def _validate_message(data: dict[str, Any]) -> list[str]: |
| 106 | + errors = [] |
| 107 | + if ( |
| 108 | + 'parts' not in data |
| 109 | + or not isinstance(data.get('parts'), list) |
| 110 | + or not data.get('parts') |
| 111 | + ): |
| 112 | + errors.append("Message object must have a non-empty 'parts' array.") |
| 113 | + if 'role' not in data or data.get('role') != 'agent': |
| 114 | + errors.append("Message from agent must have 'role' set to 'agent'.") |
| 115 | + return errors |
| 116 | + |
| 117 | + |
| 118 | +def validate_event(data: dict[str, Any]) -> list[str]: |
| 119 | + """Validate an incoming event from the agent based on its kind.""" |
| 120 | + if 'kind' not in data: |
| 121 | + return ["Response from agent is missing required 'kind' field."] |
| 122 | + |
| 123 | + kind = data.get('kind') |
| 124 | + validators = { |
| 125 | + 'task': _validate_task, |
| 126 | + 'status-update': _validate_status_update, |
| 127 | + 'artifact-update': _validate_artifact_update, |
| 128 | + 'message': _validate_message, |
| 129 | + } |
| 130 | + |
| 131 | + validator = validators.get(str(kind)) |
| 132 | + if validator: |
| 133 | + return validator(data) |
| 134 | + |
| 135 | + return [f"Unknown message kind received: '{kind}'."] |
| 136 | + |
| 137 | + |
| 138 | +# A2A messaging helpers |
| 139 | + |
| 140 | +async def send_text_message(text: str, url: str, context_id: str | None = None, streaming: bool = False): |
| 141 | + async with httpx.AsyncClient(timeout=10) as httpx_client: |
| 142 | + resolver = A2ACardResolver(httpx_client=httpx_client, base_url=url) |
| 143 | + agent_card = await resolver.get_agent_card() |
| 144 | + config = ClientConfig(httpx_client=httpx_client, streaming=streaming) |
| 145 | + factory = ClientFactory(config) |
| 146 | + client = factory.create(agent_card) |
| 147 | + |
| 148 | + msg = Message( |
| 149 | + kind="message", |
| 150 | + role=Role.user, |
| 151 | + parts=[Part(TextPart(text=text))], |
| 152 | + message_id=uuid4().hex, |
| 153 | + context_id=context_id, |
| 154 | + ) |
| 155 | + |
| 156 | + events = [event async for event in client.send_message(msg)] |
| 157 | + |
| 158 | + return events |
| 159 | + |
| 160 | + |
| 161 | +# A2A conformance tests |
| 162 | + |
| 163 | +def test_agent_card(agent): |
| 164 | + """Validate agent card structure and required fields.""" |
| 165 | + response = httpx.get(f"{agent}/.well-known/agent-card.json") |
| 166 | + assert response.status_code == 200, "Agent card endpoint must return 200" |
| 167 | + |
| 168 | + card_data = response.json() |
| 169 | + errors = validate_agent_card(card_data) |
| 170 | + |
| 171 | + assert not errors, f"Agent card validation failed:\n" + "\n".join(errors) |
| 172 | + |
| 173 | +@pytest.mark.asyncio |
| 174 | +@pytest.mark.parametrize("streaming", [True, False]) |
| 175 | +async def test_message(agent, streaming): |
| 176 | + """Test that agent returns valid A2A message format.""" |
| 177 | + events = await send_text_message("Hello", agent, streaming=streaming) |
| 178 | + |
| 179 | + all_errors = [] |
| 180 | + for event in events: |
| 181 | + match event: |
| 182 | + case Message() as msg: |
| 183 | + errors = validate_event(msg.model_dump()) |
| 184 | + all_errors.extend(errors) |
| 185 | + |
| 186 | + case (task, update): |
| 187 | + errors = validate_event(task.model_dump()) |
| 188 | + all_errors.extend(errors) |
| 189 | + if update: |
| 190 | + errors = validate_event(update.model_dump()) |
| 191 | + all_errors.extend(errors) |
| 192 | + |
| 193 | + case _: |
| 194 | + pytest.fail(f"Unexpected event type: {type(event)}") |
| 195 | + |
| 196 | + assert events, "Agent should respond with at least one event" |
| 197 | + assert not all_errors, f"Message validation failed:\n" + "\n".join(all_errors) |
| 198 | + |
| 199 | +# Add your custom tests here |
0 commit comments