diff --git a/tests/eval/eval_sample_a2a_expense_reimburse_agent_server.ipynb b/tests/eval/eval_sample_a2a_expense_reimburse_agent_server.ipynb new file mode 100644 index 00000000..07135975 --- /dev/null +++ b/tests/eval/eval_sample_a2a_expense_reimburse_agent_server.ipynb @@ -0,0 +1,1148 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ouFs9L8c5cx5" + }, + "source": [ + "## Vertex AI End-to-End Evaluation for A2A Reimbursement Agent (in-memory server)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l1NpzAqE5cx6" + }, + "source": [ + "This notebook demonstrates how to evaluate a Reimbursement A2A Agent using Vertex AI Evaluation services.\n", + "\n", + "**Updated**: jbu@google, 06/10/2025\n", + "\n", + "**Prerequisites:**\n", + "1. **Google Cloud Project:** You need a Google Cloud Project with the Vertex AI API enabled.\n", + "2. **Authentication:** You need to be authenticated to Google Cloud. In a Colab environment, this is usually handled by running `from google.colab import auth` and `auth.authenticate_user()`.\n", + "3. **Agent Logic:** The core logic for the Reimbursement Agent (e.g., a `ReimbursementAgentExecutor` class) must be defined or importable within this notebook. This executor must subclass `a2a.server.agent_execution.AgentExecutor` and implement `async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:`, publishing its results to the event queue." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5HABRB6L5cx6" + }, + "source": [ + "### 1. 
Setup and Installs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2o7_ic4K5cx6" + }, + "outputs": [], + "source": [ + "!pip install google-cloud-aiplatform httpx \"a2a-sdk==0.2.5\" --quiet\n", + "!pip install --upgrade --quiet 'google-adk==1.2.0'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qxQocCV05cx6" + }, + "outputs": [], + "source": [ + "import os\n", + "import uuid\n", + "import json\n", + "import asyncio\n", + "from typing import Any, Dict, List, Optional\n", + "\n", + "from google.colab import auth\n", + "from google.cloud import aiplatform\n", + "\n", + "import a2a.types as a2a_types\n", + "from a2a.server.tasks import InMemoryTaskStore\n", + "from uuid import uuid4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ogSu_6Pe5cx7" + }, + "source": [ + "### 2. Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_c3_Me9D5cx7" + }, + "outputs": [], + "source": [ + "# --- Google Cloud Configuration ---\n", + "PROJECT_ID = \"\" # @param {type: \"string\", placeholder: \"[your-project-id]\", isTemplate: true}\n", + "LOCATION = \"us-central1\" # @param {type:\"string\"} Fill in your Google Cloud region\n", + "\n", + "# --- Authentication (for Colab) ---\n", + "if not PROJECT_ID:\n", + " raise ValueError(\"Please set your PROJECT_ID.\")\n", + "os.environ[\"GOOGLE_CLOUD_PROJECT\"] = PROJECT_ID\n", + "\n", + "try:\n", + " auth.authenticate_user()\n", + " print(\"Colab user authenticated.\")\n", + "except Exception as e:\n", + " print(f\"Not in a Colab environment or auth failed: {e}. 
Assuming local gcloud auth.\")\n", + "\n", + "aiplatform.init(project=PROJECT_ID, location=LOCATION)" + ] + }, + { + "cell_type": "code", + "source": [ + "EXPERIMENT_NAME = \"evaluate-a2a\" # @param {type:\"string\"}\n", + "BUCKET_NAME = \"a2a-sdk-eval\" # @param {type: \"string\", placeholder: \"[your-bucket-name]\", isTemplate: true}\n", + "BUCKET_URI = f\"gs://{BUCKET_NAME}\"" + ], + "metadata": { + "id": "obMo8ht8-b0J" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import asyncio\n", + "import json\n", + "import logging\n", + "from collections.abc import AsyncGenerator\n", + "from typing import Any\n", + "\n", + "# A2A Core Imports\n", + "from a2a.server.apps.starlette_app import A2AStarletteApplication\n", + "from a2a.server.context import ServerCallContext\n", + "from a2a.server.request_handlers.request_handler import RequestHandler\n", + "from a2a.server.request_handlers.default_request_handler import DefaultRequestHandler\n", + "from a2a.types import (\n", + " AgentSkill, AgentCard, AgentCapabilities, Artifact, CancelTaskRequest,\n", + " CancelTaskResponse, GetTaskPushNotificationConfigRequest,\n", + " GetTaskPushNotificationConfigResponse, GetTaskRequest, GetTaskResponse,\n", + " Message, Part, SendMessageRequest, SendMessageResponse, MessageSendParams,\n", + " SendStreamingMessageRequest, SendStreamingMessageResponse,\n", + " SetTaskPushNotificationConfigRequest, SetTaskPushNotificationConfigResponse,\n", + " Task, TaskArtifactUpdateEvent, TaskState, TaskStatus, TextPart,\n", + " PushNotificationConfig, TaskPushNotificationConfig, TaskQueryParams\n", + "# TaskResubscriptionRequest,TaskResubscriptionResponse,\n", + ")\n", + "from a2a.utils.errors import MethodNotImplementedError # For unhandled methods\n", + "\n", + "# Starlette Test Client\n", + "from starlette.testclient import TestClient" + ], + "metadata": { + "id": "u8mkWCHaxqLc" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": 
# Local cache of created request_ids for demo purposes.
request_ids = set()


def create_request_form(
    date: Optional[str] = None,
    amount: Optional[str] = None,
    purpose: Optional[str] = None,
) -> dict[str, Any]:
    """
    Create a request form for the employee to fill out.

    A fresh pseudo-random request id is generated and remembered in the
    module-level ``request_ids`` set so that ``reimburse`` can validate it.

    Args:
        date (str): The date of the request. Can be an empty string.
        amount (str): The requested amount. Can be an empty string.
        purpose (str): The purpose of the request. Can be an empty string.

    Returns:
        dict[str, Any]: A dictionary containing the request form data. Missing
        (``None`` or empty) fields are returned as empty strings.
    """
    request_id = 'request_id_' + str(random.randint(1000000, 9999999))
    request_ids.add(request_id)
    return {
        'request_id': request_id,
        'date': date or '',
        'amount': amount or '',
        'purpose': purpose or '',
    }


def return_form(
    form_request: dict[str, Any],
    tool_context: ToolContext,
    instructions: Optional[str] = None,
) -> dict[str, Any]:
    """
    Returns a structured json object indicating a form to complete.

    Args:
        form_request (dict[str, Any]): The request form data. A JSON-encoded
            string is also accepted and decoded first.
        tool_context (ToolContext): The context in which the tool operates.
        instructions (str): Instructions for processing the form. Can be an empty string.

    Returns:
        The form description serialized as a JSON *string* (``json.dumps``),
        not a dict. ``ReimbursementAgentExecutor.execute`` relies on this: it
        ``json.loads`` the ``result`` field of the function response.
        NOTE(review): the ``dict[str, Any]`` return annotation is kept as-is
        because the tool declaration may be derived from the signature --
        confirm before changing it.
    """
    if isinstance(form_request, str):
        form_request = json.loads(form_request)

    # Hand the raw tool output back to the caller instead of letting the
    # model summarize it, and stop the agent loop at this point.
    tool_context.actions.skip_summarization = True
    tool_context.actions.escalate = True
    form_dict = {
        'type': 'form',
        'form': {
            'type': 'object',
            'properties': {
                'date': {
                    'type': 'string',
                    'format': 'date',
                    'description': 'Date of expense',
                    'title': 'Date',
                },
                'amount': {
                    'type': 'string',
                    'format': 'number',
                    'description': 'Amount of expense',
                    'title': 'Amount',
                },
                'purpose': {
                    'type': 'string',
                    'description': 'Purpose of expense',
                    'title': 'Purpose',
                },
                'request_id': {
                    'type': 'string',
                    'description': 'Request id',
                    'title': 'Request ID',
                },
            },
            'required': list(form_request.keys()),
        },
        'form_data': form_request,
        'instructions': instructions,
    }
    return json.dumps(form_dict)


def reimburse(request_id: str) -> dict[str, Any]:
    """Reimburse the amount of money to the employee for a given request_id.

    Only ids previously issued by ``create_request_form`` are accepted; any
    other id yields an error status instead of an approval.
    """
    if request_id not in request_ids:
        return {
            'request_id': request_id,
            'status': 'Error: Invalid request_id.',
        }
    return {'request_id': request_id, 'status': 'approved'}


class ReimbursementAgent:
    """An agent that handles reimbursement requests."""

    SUPPORTED_CONTENT_TYPES = ['text', 'text/plain']

    def __init__(self):
        self._agent = self._build_agent()
        self._user_id = 'remote_agent'
        # Everything is kept in memory: sessions, artifacts and memory are
        # lost when the kernel restarts.
        self._runner = Runner(
            app_name=self._agent.name,
            agent=self._agent,
            artifact_service=InMemoryArtifactService(),
            session_service=InMemorySessionService(),
            memory_service=InMemoryMemoryService(),
        )

    def get_processing_message(self) -> str:
        """Return the status text reported while the agent is still working."""
        return 'Processing the reimbursement request...'

    def _build_agent(self) -> LlmAgent:
        """Builds the LLM agent for the reimbursement agent."""
        return LlmAgent(
            model='gemini-2.0-flash-001',
            name='reimbursement_agent',
            description=(
                'This agent handles the reimbursement process for the employees'
                ' given the amount and purpose of the reimbursement.'
            ),
            # NOTE(review): the prompt mentions a "request_form method", but the
            # registered tools are create_request_form / return_form / reimburse.
            # Left untouched on purpose; confirm the intended tool name.
            instruction="""
    You are an agent who handles the reimbursement process for employees.

    When you receive a reimbursement request, you should first create a new request form using create_request_form(). Only provide default values if they are provided by the user, otherwise use an empty string as the default value.
      1. 'Date': the date of the transaction.
      2. 'Amount': the dollar amount of the transaction.
      3. 'Business Justification/Purpose': the reason for the reimbursement.

    Once you created the form, you should return the result of calling return_form with the form data from the create_request_form call.

    Once you received the filled-out form back from the user, you should then check the form contains all required information:
      1. 'Date': the date of the transaction.
      2. 'Amount': the value of the amount of the reimbursement being requested.
      3. 'Business Justification/Purpose': the item/object/artifact of the reimbursement.

    If you don't have all of the information, you should reject the request directly by calling the request_form method, providing the missing fields.


    For valid reimbursement requests, you can then use reimburse() to reimburse the employee.
      * In your response, you should include the request_id and the status of the reimbursement request.

    """,
            tools=[
                create_request_form,
                reimburse,
                return_form,
            ],
        )

    async def stream(self, query, session_id) -> AsyncIterable[dict[str, Any]]:
        """Run one user turn and yield progress / completion dictionaries.

        Args:
            query: The user's text for this turn.
            session_id: Session to continue; it is created when it does not
                exist yet.

        Yields:
            ``{'is_task_complete': False, 'updates': <str>}`` for every
            non-final event, and ``{'is_task_complete': True, 'content': ...}``
            for a final one. ``content`` is the joined text of the final event,
            or -- when the final event carries a function response instead of
            text -- that function response dumped to a dict, or ``''`` when
            the event has neither.
        """
        session = await self._runner.session_service.get_session(
            app_name=self._agent.name,
            user_id=self._user_id,
            session_id=session_id,
        )
        content = types.Content(
            role='user', parts=[types.Part.from_text(text=query)]
        )
        if session is None:
            session = await self._runner.session_service.create_session(
                app_name=self._agent.name,
                user_id=self._user_id,
                state={},
                session_id=session_id,
            )
        async for event in self._runner.run_async(
            user_id=self._user_id, session_id=session.id, new_message=content
        ):
            if not event.is_final_response():
                yield {
                    'is_task_complete': False,
                    'updates': self.get_processing_message(),
                }
                continue

            parts = (event.content.parts if event.content else None) or []
            if parts and parts[0].text:
                response = '\n'.join(p.text for p in parts if p.text)
            else:
                # Pick the first part that really carries a function response.
                # The previous unfiltered next() dereferenced
                # parts[0].function_response even when it was None.
                response = next(
                    (
                        p.function_response.model_dump()
                        for p in parts
                        if p.function_response
                    ),
                    '',
                )
            yield {
                'is_task_complete': True,
                'content': response,
            }


class ReimbursementAgentExecutor(AgentExecutor):
    """Reimbursement AgentExecutor Example.

    Bridges the A2A server and ``ReimbursementAgent``: each streamed item is
    translated into task status updates or artifacts on the event queue.
    """

    def __init__(self):
        self.agent = ReimbursementAgent()

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Handle one incoming request by driving the agent to completion.

        NOTE(review): ``enqueue_event`` and the ``TaskUpdater`` methods are
        called without ``await`` here; confirm this still matches the
        installed a2a-sdk version before upgrading the pinned dependency.
        """
        query = context.get_user_input()
        task = context.current_task

        # This agent always produces Task objects. If this request does
        # not have current task, create a new one and use it.
        if not task:
            task = new_task(context.message)
            event_queue.enqueue_event(task)
        updater = TaskUpdater(event_queue, task.id, task.contextId)
        # invoke the underlying agent, using streaming results. The streams
        # now are update events.
        async for item in self.agent.stream(query, task.contextId):
            if not item['is_task_complete']:
                updater.update_status(
                    TaskState.working,
                    new_agent_text_message(
                        item['updates'], task.contextId, task.id
                    ),
                )
                continue
            # If the response is a dictionary, assume its a form
            if isinstance(item['content'], dict):
                # Verify it is a valid form
                if (
                    'response' in item['content']
                    and 'result' in item['content']['response']
                ):
                    # return_form() serialized the form with json.dumps.
                    data = json.loads(item['content']['response']['result'])
                    updater.update_status(
                        TaskState.input_required,
                        new_agent_parts_message(
                            [Part(root=DataPart(data=data))],
                            task.contextId,
                            task.id,
                        ),
                        final=True,
                    )
                    continue
                else:
                    updater.update_status(
                        TaskState.failed,
                        new_agent_text_message(
                            'Reaching an unexpected state',
                            task.contextId,
                            task.id,
                        ),
                        final=True,
                    )
                    break
            else:
                # Emit the appropriate events
                updater.add_artifact(
                    [Part(root=TextPart(text=item['content']))], name='form'
                )
                updater.complete()
                break

    async def cancel(
        self, request: RequestContext, event_queue: EventQueue
    ) -> Task | None:
        """Cancellation is not supported; always raises ``ServerError``."""
        raise ServerError(error=UnsupportedOperationError())
{ + "cell_type": "code", + "source": [ + "os.environ[\"GOOGLE_CLOUD_PROJECT\"] = PROJECT_ID\n", + "os.environ[\"GOOGLE_CLOUD_LOCATION\"] = LOCATION\n", + "os.environ[\"GOOGLE_GENAI_USE_VERTEXAI\"] = \"True\"" + ], + "metadata": { + "id": "5qABW6Yb_OEn" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "capabilities = AgentCapabilities(streaming=True)\n", + "skill = AgentSkill(\n", + " id='process_reimbursement',\n", + " name='Process Reimbursement Tool',\n", + " description='Helps with the reimbursement process for users given the amount and purpose of the reimbursement.',\n", + " tags=['reimbursement'],\n", + " examples=[\n", + " 'Can you reimburse me $20 for my lunch with the clients?'\n", + " ],\n", + ")\n", + "agent_card = AgentCard(\n", + " name='Reimbursement Agent',\n", + " description='This agent handles the reimbursement process for the employees given the amount and purpose of the reimbursement.',\n", + " url='http://localhost/agent', # Placeholder, not used by TestClient\n", + " # url=f'http://{host}:{port}/',\n", + " version='1.0.0',\n", + " defaultInputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,\n", + " defaultOutputModes=ReimbursementAgent.SUPPORTED_CONTENT_TYPES,\n", + " capabilities=capabilities,\n", + " skills=[skill],\n", + ")\n", + "request_handler = DefaultRequestHandler(\n", + " agent_executor=ReimbursementAgentExecutor(),\n", + " task_store=InMemoryTaskStore(),\n", + ")\n", + "server = A2AStarletteApplication(\n", + " agent_card=agent_card, http_handler=request_handler\n", + ")\n", + "\n", + "# Build the Starlette ASGI app\n", + "# This `starlette_app` can be served by Uvicorn or used with TestClient\n", + "expense_starlette_app = server.build()" + ], + "metadata": { + "id": "DMfa8tTz9C3x" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Basic logging setup (helpful for seeing what the handler does)\n", + 
"logging.basicConfig(level=logging.INFO)\n", + "logger = logging.getLogger(__name__)" + ], + "metadata": { + "id": "VrX5UlCHz-vY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# TestClient should be used as a context manager or closed explicitly\n", + "with TestClient(expense_starlette_app) as client:\n", + " logger.info(\"\\n--- Test 1: Get Agent Card ---\")\n", + " response = client.get(\"/.well-known/agent.json\")\n", + " assert response.status_code == 200\n", + " agent_card_data = response.json()\n", + " print(f\"SUCCESS: Agent Card received: {agent_card_data['name']}\")\n", + " print(\"A2AClient initialized.\")\n", + "\n", + " print(\"\\n--- Quick Test : Non-streaming RPC - message/send ---\")\n", + " message_id_send = \"colab-msg-007\"\n", + " rpc_request_send_msg = {\n", + " \"jsonrpc\": \"2.0\",\n", + " \"id\": \"colab-req-send-msg-1\",\n", + " \"method\": \"message/send\",\n", + " \"params\": {\n", + " \"message\": {\n", + " \"role\": \"user\",\n", + " \"parts\": [{\"kind\": \"text\", \"text\": \"Hello Agent, Please reimburse me $20 for my lunch with the clients on 06/01/2025?\"}], # good one\n", + " \"messageId\": message_id_send,\n", + " \"kind\": \"message\",\n", + " \"contextId\": \"colab-session-xyz\"\n", + " }\n", + " }\n", + " }\n", + " response = client.post(\"/\", json=rpc_request_send_msg)\n", + " assert response.status_code == 200\n", + " rpc_response_send_msg = response.json()\n", + " print(f\"message/send response: {json.dumps(rpc_response_send_msg, indent=2)}\")\n", + " print(f\"SUCCESS: message/send for '{message_id_send}' passed.\")" + ], + "metadata": { + "id": "3MgWEUam8VN4" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "IPi0qgrdVzl2" + } + }, + { + "cell_type": "markdown", + "source": [ + "### Define Eval helper functions\n", + "\n", + "Initiate a set of helper functions to print tutorial results." 
+ ], + "metadata": { + "id": "zgD2FaExpTXH" + } + }, + { + "cell_type": "code", + "source": [ + "#@title imports\n", + "import json\n", + "\n", + "# General\n", + "import random\n", + "import string\n", + "from typing import Any\n", + "\n", + "from IPython.display import HTML, Markdown, display\n", + "from google.adk.agents import Agent\n", + "\n", + "# Build agent with adk\n", + "from google.adk.events import Event\n", + "from google.adk.runners import Runner, InMemoryRunner\n", + "from google.adk.sessions import InMemorySessionService\n", + "\n", + "# Evaluate agent\n", + "from google.cloud import aiplatform\n", + "from google.genai import types\n", + "import pandas as pd\n", + "import plotly.graph_objects as go\n", + "from vertexai.preview.evaluation import EvalTask\n", + "from vertexai.preview.evaluation.metrics import (\n", + " PointwiseMetric,\n", + " PointwiseMetricPromptTemplate,\n", + " TrajectorySingleToolUse,\n", + ")" + ], + "metadata": { + "id": "xXD6nXe9qKLt", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#@title helper functions\n", + "def get_id(length: int = 8) -> str:\n", + " \"\"\"Generate a uuid of a specified length (default=8).\"\"\"\n", + " return \"\".join(random.choices(string.ascii_lowercase + string.digits, k=length))\n", + "\n", + "def parse_a2a_output_to_dictionary(rpc_response_send_msg:dict) -> dict[str, Any]:\n", + " \"\"\"\n", + " Parse ADK event output into a structured dictionary format,\n", + " with the predicted trajectory dumped as a JSON string.\n", + "\n", + " \"\"\"\n", + "\n", + " final_response = \"\"\n", + " predicted_trajectory_list = []\n", + "\n", + " if (\n", + " \"result\" in rpc_response_send_msg\n", + " and \"artifacts\" in rpc_response_send_msg[\"result\"]\n", + " ):\n", + " for artifact in rpc_response_send_msg[\"result\"][\"artifacts\"]:\n", + " if artifact and \"parts\" in artifact:\n", + " for part in artifact[\"parts\"]:\n", + " if 
\"kind\" in part and part[\"kind\"] == \"text\" and \"text\" in part:\n", + " final_response = part[\"text\"]\n", + "\n", + " if final_response == \"\":\n", + " state = \"\"\n", + " if \"result\" in rpc_response_send_msg and \"status\" in rpc_response_send_msg[\"result\"]:\n", + " state = rpc_response_send_msg[\"result\"][\"status\"][\"state\"]\n", + " final_response = state\n", + " # Dump the collected trajectory list into a JSON string\n", + " final_output = {\n", + " \"response\": str(final_response),\n", + " \"predicted_trajectory\": json.dumps(predicted_trajectory_list),\n", + " }\n", + " return final_output\n", + "\n", + "\n", + "def parse_adk_output_to_dictionary(events: list[Event]) -> dict[str, Any]:\n", + " \"\"\"\n", + " Parse ADK event output into a structured dictionary format,\n", + " with the predicted trajectory dumped as a JSON string.\n", + "\n", + " \"\"\"\n", + "\n", + " final_response = \"\"\n", + " predicted_trajectory_list = []\n", + "\n", + " for event in events:\n", + " # Ensure content and parts exist before accessing them\n", + " if not event.content or not event.content.parts:\n", + " continue\n", + "\n", + " # Iterate through ALL parts in the event's content\n", + " for part in event.content.parts:\n", + " if part.function_call:\n", + " tool_info = {\n", + " \"tool_name\": part.function_call.name,\n", + " \"tool_input\": dict(part.function_call.args),\n", + " }\n", + " # Ensure we don't add duplicates if the same call appears somehow\n", + " if tool_info not in predicted_trajectory_list:\n", + " predicted_trajectory_list.append(tool_info)\n", + "\n", + " # The final text response is usually in the last event from the model\n", + " if event.content.role == \"model\" and part.text:\n", + " # Overwrite response; the last text response found is likely the final one\n", + " final_response = part.text.strip()\n", + "\n", + " # Dump the collected trajectory list into a JSON string\n", + " final_output = {\n", + " \"response\": 
str(final_response),\n", + " \"predicted_trajectory\": json.dumps(predicted_trajectory_list),\n", + " }\n", + "\n", + " return final_output\n", + "\n", + "\n", + "def format_output_as_markdown(output: dict) -> str:\n", + " \"\"\"Convert the output dictionary to a formatted markdown string.\"\"\"\n", + " markdown = \"### AI Response\\n\"\n", + " markdown += f\"{output['response']}\\n\\n\"\n", + "\n", + " if output[\"predicted_trajectory\"]:\n", + " output[\"predicted_trajectory\"] = json.loads(output[\"predicted_trajectory\"])\n", + " markdown += \"### Function Calls\\n\"\n", + " for call in output[\"predicted_trajectory\"]:\n", + " markdown += f\"- **Function**: `{call['tool_name']}`\\n\"\n", + " markdown += \" - **Arguments**:\\n\"\n", + " for key, value in call[\"tool_input\"].items():\n", + " markdown += f\" - `{key}`: `{value}`\\n\"\n", + "\n", + " return markdown\n", + "\n", + "\n", + "def display_eval_report(eval_result: pd.DataFrame) -> None:\n", + " \"\"\"Display the evaluation results.\"\"\"\n", + " metrics_df = pd.DataFrame.from_dict(eval_result.summary_metrics, orient=\"index\").T\n", + " display(Markdown(\"### Summary Metrics\"))\n", + " display(metrics_df)\n", + "\n", + " display(Markdown(\"### Row-wise Metrics\"))\n", + " display(eval_result.metrics_table)\n", + "\n", + "\n", + "def display_drilldown(row: pd.Series) -> None:\n", + " \"\"\"Displays a drill-down view for trajectory data within a row.\"\"\"\n", + "\n", + " style = \"white-space: pre-wrap; width: 800px; overflow-x: auto;\"\n", + "\n", + " if not (\n", + " isinstance(row[\"predicted_trajectory\"], list)\n", + " and isinstance(row[\"reference_trajectory\"], list)\n", + " ):\n", + " return\n", + "\n", + " for predicted_trajectory, reference_trajectory in zip(\n", + " row[\"predicted_trajectory\"], row[\"reference_trajectory\"]\n", + " ):\n", + " display(\n", + " HTML(\n", + " f\"

Tool Names:

{predicted_trajectory['tool_name'], reference_trajectory['tool_name']}
\"\n", + " )\n", + " )\n", + "\n", + " if not (\n", + " isinstance(predicted_trajectory.get(\"tool_input\"), dict)\n", + " and isinstance(reference_trajectory.get(\"tool_input\"), dict)\n", + " ):\n", + " continue\n", + "\n", + " for tool_input_key in predicted_trajectory[\"tool_input\"]:\n", + " print(\"Tool Input Key: \", tool_input_key)\n", + "\n", + " if tool_input_key in reference_trajectory[\"tool_input\"]:\n", + " print(\n", + " \"Tool Values: \",\n", + " predicted_trajectory[\"tool_input\"][tool_input_key],\n", + " reference_trajectory[\"tool_input\"][tool_input_key],\n", + " )\n", + " else:\n", + " print(\n", + " \"Tool Values: \",\n", + " predicted_trajectory[\"tool_input\"][tool_input_key],\n", + " \"N/A\",\n", + " )\n", + " print(\"\\n\")\n", + " display(HTML(\"
\"))\n", + "\n", + "\n", + "def display_dataframe_rows(\n", + " df: pd.DataFrame,\n", + " columns: list[str] | None = None,\n", + " num_rows: int = 3,\n", + " display_drilldown: bool = False,\n", + ") -> None:\n", + " \"\"\"Displays a subset of rows from a DataFrame, optionally including a drill-down view.\"\"\"\n", + "\n", + " if columns:\n", + " df = df[columns]\n", + "\n", + " base_style = \"font-family: monospace; font-size: 14px; white-space: pre-wrap; width: auto; overflow-x: auto;\"\n", + " header_style = base_style + \"font-weight: bold;\"\n", + "\n", + " for _, row in df.head(num_rows).iterrows():\n", + " for column in df.columns:\n", + " display(\n", + " HTML(\n", + " f\"{column.replace('_', ' ').title()}: \"\n", + " )\n", + " )\n", + " display(HTML(f\"{row[column]}
\"))\n", + "\n", + " display(HTML(\"
\"))\n", + "\n", + " if (\n", + " display_drilldown\n", + " and \"predicted_trajectory\" in df.columns\n", + " and \"reference_trajectory\" in df.columns\n", + " ):\n", + " display_drilldown(row)\n", + "\n", + "\n", + "def plot_bar_plot(\n", + " eval_result: pd.DataFrame, title: str, metrics: list[str] = None\n", + ") -> None:\n", + " fig = go.Figure()\n", + " data = []\n", + "\n", + " summary_metrics = eval_result.summary_metrics\n", + " if metrics:\n", + " summary_metrics = {\n", + " k: summary_metrics[k]\n", + " for k, v in summary_metrics.items()\n", + " if any(selected_metric in k for selected_metric in metrics)\n", + " }\n", + "\n", + " data.append(\n", + " go.Bar(\n", + " x=list(summary_metrics.keys()),\n", + " y=list(summary_metrics.values()),\n", + " name=title,\n", + " )\n", + " )\n", + "\n", + " fig = go.Figure(data=data)\n", + "\n", + " # Change the bar mode\n", + " fig.update_layout(barmode=\"group\")\n", + " fig.show()\n", + "\n", + "\n", + "def display_radar_plot(eval_results, title: str, metrics=None):\n", + " \"\"\"Plot the radar plot.\"\"\"\n", + " fig = go.Figure()\n", + " summary_metrics = eval_results.summary_metrics\n", + " if metrics:\n", + " summary_metrics = {\n", + " k: summary_metrics[k]\n", + " for k, v in summary_metrics.items()\n", + " if any(selected_metric in k for selected_metric in metrics)\n", + " }\n", + "\n", + " min_val = min(summary_metrics.values())\n", + " max_val = max(summary_metrics.values())\n", + "\n", + " fig.add_trace(\n", + " go.Scatterpolar(\n", + " r=list(summary_metrics.values()),\n", + " theta=list(summary_metrics.keys()),\n", + " fill=\"toself\",\n", + " name=title,\n", + " )\n", + " )\n", + " fig.update_layout(\n", + " title=title,\n", + " polar=dict(radialaxis=dict(visible=True, range=[min_val, max_val])),\n", + " showlegend=True,\n", + " )\n", + " fig.show()" + ], + "metadata": { + "id": "MC42rmmBVzWH", + "cellView": "form" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + 
def a2a_parsed_outcome(query, context_id=None):
    """Send ``query`` to the in-memory A2A server and parse the outcome.

    Args:
        query: The user prompt to send with a non-streaming ``message/send``
            JSON-RPC call.
        context_id: A2A context (conversation) id to use. By default a fresh
            id is generated for every call so that separate prompts -- e.g.
            the rows of an evaluation dataset -- do not share agent session
            history. Pass an explicit id to continue a conversation.

    Returns:
        The dictionary produced by ``parse_a2a_output_to_dictionary`` with the
        keys ``response`` and ``predicted_trajectory``.

    Raises:
        RuntimeError: If the agent card or the RPC call does not return
            HTTP 200.
    """
    if context_id is None:
        context_id = f"colab-session-{get_id()}"

    message_id_send = f"colab-msg-{get_id()}"
    rpc_request_send_msg = {
        "jsonrpc": "2.0",
        "id": f"colab-req-send-msg-{get_id()}",
        "method": "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": query}],
                "messageId": message_id_send,
                "kind": "message",
                "contextId": context_id,
            }
        },
    }

    # TestClient should be used as a context manager or closed explicitly
    with TestClient(expense_starlette_app) as client:
        print("\n--- Get Agent Card ---")
        response = client.get("/.well-known/agent.json")
        # Explicit checks instead of `assert`, which is removed under `python -O`.
        if response.status_code != 200:
            raise RuntimeError(
                f"Agent card request failed with HTTP {response.status_code}"
            )
        agent_card_data = response.json()
        print(f"--- SUCCESS: Agent Card received: {agent_card_data['name']} ---")
        print("--- A2AClient initialized. ---")
        print(f"Query: {query}")

        response = client.post("/", json=rpc_request_send_msg)
        if response.status_code != 200:
            raise RuntimeError(
                f"message/send failed with HTTP {response.status_code}"
            )
        rpc_response_send_msg = response.json()

    if "error" in rpc_response_send_msg:
        # A JSON-RPC error still arrives with HTTP 200; surface it instead of
        # silently evaluating an empty response.
        print(
            f"WARNING: message/send for '{message_id_send}' returned an error: "
            f"{rpc_response_send_msg['error']}"
        )
    else:
        print(f"SUCCESS: message/send for '{message_id_send}' Finished")
    return parse_a2a_output_to_dictionary(rpc_response_send_msg)
+ ], + "metadata": { + "id": "wm1dLRwo59Wt" + } + }, + { + "cell_type": "code", + "source": [ + "response = a2a_parsed_outcome(query=\"Get product details for shoes\")\n", + "display(Markdown(format_output_as_markdown(response)))\n", + "\n", + "response = a2a_parsed_outcome(query=\"Hello Agent, Please reimburse me $20 for my lunch with the clients on 06/01/2025?\")\n", + "display(Markdown(format_output_as_markdown(response)))\n", + "\n", + "response = a2a_parsed_outcome(query=\"Hello Agent, Please reimburse me $311 for my flights from SFO to SEA on 06/11/2025?\")\n", + "display(Markdown(format_output_as_markdown(response)))\n", + "\n", + "response = a2a_parsed_outcome(query=\"Hello Agent, Please reimburse me $50 for my lunch with the clients on Jan 2nd,2024?\")\n", + "display(Markdown(format_output_as_markdown(response)))" + ], + "metadata": { + "id": "jugxifmk586y" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Prepare Agent Evaluation dataset\n", + "\n", + "To evaluate your AI agent using the Vertex AI Gen AI Evaluation service, you need a dataset whose contents depend on which aspects of your agent you want to evaluate.\n", + "\n", + "This dataset should include the prompts given to the agent. It can also contain the ideal or expected response (ground truth) and the reference trajectory, i.e. the sequence of tool calls you expect the agent to make for each given prompt.\n", + "\n", + "> Optionally, you can provide both generated responses and predicted trajectory (**Bring-Your-Own-Dataset scenario**).\n", + "\n", + "Below is an example dataset for the reimbursement agent, containing the user prompts and their reference trajectories." 
+ ], + "metadata": { + "id": "WChuA3Ip7bxg" + } + }, + { + "cell_type": "code", + "source": [ + "#@title Define eval datasets\n", + "# The reference trajectories are empty in this example.\n", + "eval_data_a2a = {\n", + " \"prompt\": [\n", + " \"Get product details for shoes\",\n", + " \"Hello Agent, Please reimburse me $20 for my lunch with the clients on 06/01/2025?\",\n", + " \"Hello Agent, Please reimburse me $20 for my lunch with the clients\",\n", + " \"Please reimburse me $312 for my meal with the clients on 06/05/2025?\",\n", + " \"Please reimburse me $1234 for my flight to Seattle on 06/11/2025?\",\n", + " ],\n", + " \"reference_trajectory\": [\n", + " [],[],[],[],[],\n", + " ],\n", + "}\n", + "\n", + "eval_sample_dataset = pd.DataFrame(eval_data_a2a)" + ], + "metadata": { + "id": "f3ogOs-57czS" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "display_dataframe_rows(eval_sample_dataset, num_rows=30)" + ], + "metadata": { + "id": "WWgEqYP792zf" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Evaluate final response\n", + "\n", + "Similar to model evaluation, you can evaluate the final response of the agent using Vertex AI Gen AI Evaluation." + ], + "metadata": { + "id": "r7ApbDMh-FOm" + } + }, + { + "cell_type": "markdown", + "source": [ + "#### Set response metrics\n", + "\n", + "After agent inference, Vertex AI Gen AI Evaluation provides several metrics to evaluate generated responses. 
You can use computation-based metrics to compare the response to a reference (if needed), and existing or custom model-based metrics to determine the quality of the final response.\n", + "\n", + "Check out the [documentation](https://cloud.google.com/vertex-ai/generative-ai/docs/models/determine-eval) to learn more.\n" + ], + "metadata": { + "id": "NkF4tDZE-KMD" + } + }, + { + "cell_type": "code", + "source": [ + "response_metrics = [\"safety\", \"coherence\"]" + ], + "metadata": { + "id": "GrPsJZqx-GNM" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "EXPERIMENT_RUN = f\"response-{get_id()}\"\n", + "\n", + "response_eval_task = EvalTask(\n", + " dataset=eval_sample_dataset,\n", + " metrics=response_metrics,\n", + " experiment=EXPERIMENT_NAME,\n", + " output_uri_prefix=BUCKET_URI + \"/response-metric-eval\",\n", + ")\n", + "\n", + "response_eval_result = response_eval_task.evaluate(\n", + " runnable=a2a_parsed_outcome, experiment_run_name=EXPERIMENT_RUN\n", + ")\n", + "\n", + "display_eval_report(response_eval_result)" + ], + "metadata": { + "id": "1uMp7XIo-Pp2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "#### Visualize evaluation results\n", + "\n", + "\n", + "Display a sample of the evaluation results." + ], + "metadata": { + "id": "yOgvurZpBx-d" + } + }, + { + "cell_type": "code", + "source": [ + "display_dataframe_rows(response_eval_result.metrics_table, num_rows=5)" + ], + "metadata": { + "id": "7S_0L9RUB2P3" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file