|
1 | 1 | import multiprocessing
2 | 2 | import os
3 | 3 | import sys
  | 4 | +import json
4 | 5 | from typing import Dict
5 | 6 | import asyncio
6 | 7 | import logging
7 | 8 | from azure.ai.projects.aio import AIProjectClient
8 |   | -from azure.ai.projects.models import FilePurpose, FileSearchTool, AsyncToolSet
  | 9 | +from azure.ai.projects.models import FilePurpose, FileSearchTool, AsyncToolSet, Agent
9 | 10 | from azure.identity import DefaultAzureCredential
10 | 11 |
11 | 12 | from dotenv import load_dotenv
|
32 | 33 | file_handler.setFormatter(file_formatter)
33 | 34 | logger.addHandler(file_handler)
34 | 35 |
35 |    | -async def list_or_create_agent():
36 |    | -    files: Dict[str, Dict[str, str]] = {}  # File name -> {"id": file_id, "path": file_path}
37 |    | -    vector_store = None
38 |    | -    agent = None
   | 36 | +FILES_NAMES = ["product_info_1.md", "product_info_2.md"]
39 | 37 |
   | 38 | +async def create_agent(ai_client: AIProjectClient) -> Agent:
   | 39 | +    files: Dict[str, Dict[str, str]] = {}
   | 40 | +
   | 41 | +    # Create a new agent with the required resources
   | 42 | +    logger.info("Creating new agent with resources")
   | 43 | +
   | 44 | +    # Upload files for file search
   | 45 | +    for file_name in FILES_NAMES:
   | 46 | +        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'files', file_name))
   | 47 | +        file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
   | 48 | +        # Store both file id and the file path using the file name as key.
   | 49 | +        files[file_name] = {"id": file.id, "path": file_path}
   | 50 | +
   | 51 | +    # Serialize and store files information in the environment variable (so workers see it)
   | 52 | +    os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)
   | 53 | +    logger.info(f"Set env UPLOADED_FILE_MAP = {os.environ['UPLOADED_FILE_MAP']}")
   | 54 | +
   | 55 | +    # Create the vector store using the file IDs.
   | 56 | +    vector_store = await ai_client.agents.create_vector_store_and_poll(
   | 57 | +        file_ids=[info["id"] for info in files.values()],
   | 58 | +        name="sample_store"
   | 59 | +    )
   | 60 | +    logger.info("agent: file store and vector store success")
   | 61 | +
   | 62 | +    file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
   | 63 | +    toolset = AsyncToolSet()
   | 64 | +    toolset.add(file_search_tool)
   | 65 | +
   | 66 | +    agent = await ai_client.agents.create_agent(
   | 67 | +        model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
   | 68 | +        name=os.environ["AZURE_AI_AGENT_NAME"],
   | 69 | +        instructions="You are a helpful assistant",
   | 70 | +        toolset=toolset
   | 71 | +    )
   | 72 | +    return agent
   | 73 | +
   | 74 | +
   | 75 | +async def update_agent(agent: Agent, ai_client: AIProjectClient) -> Agent:
   | 76 | +    logger.info("Updating agent with resources")
   | 77 | +    files: Dict[str, Dict[str, str]] = {}
   | 78 | +
   | 79 | +    # Upload files for file search
   | 80 | +    for file_name in FILES_NAMES:
   | 81 | +        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'files', file_name))
   | 82 | +        file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
   | 83 | +        # Store both file id and the file path using the file name as key.
   | 84 | +        files[file_name] = {"id": file.id, "path": file_path}
   | 85 | +
   | 86 | +    # Serialize and store files information in the environment variable (so workers see it)
   | 87 | +    os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)
   | 88 | +    logger.info(f"Set env UPLOADED_FILE_MAP = {os.environ['UPLOADED_FILE_MAP']}")
   | 89 | +
   | 90 | +    # Create the vector store using the file IDs.
   | 91 | +    vector_store = await ai_client.agents.create_vector_store_and_poll(
   | 92 | +        file_ids=[info["id"] for info in files.values()],
   | 93 | +        name="sample_store"
   | 94 | +    )
   | 95 | +    logger.info("agent: file store and vector store success")
   | 96 | +
   | 97 | +    file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
   | 98 | +    toolset = AsyncToolSet()
   | 99 | +    toolset.add(file_search_tool)
   | 100 | +
   | 101 | +    agent = await ai_client.agents.update_agent(
   | 102 | +        assistant_id=agent.id,
   | 103 | +        model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
   | 104 | +        name=os.environ["AZURE_AI_AGENT_NAME"],
   | 105 | +        instructions="You are a helpful assistant",
   | 106 | +        toolset=toolset
   | 107 | +    )
   | 108 | +    return agent
   | 109 | +
   | 110 | +
   | 111 | +async def initialize_resources():
40 | 112 |     try:
41 | 113 |         ai_client = AIProjectClient.from_connection_string(
42 | 114 |             credential=DefaultAzureCredential(exclude_shared_token_cache_credential=True),
43 | 115 |             conn_str=os.environ["AZURE_AIPROJECT_CONNECTION_STRING"],
44 | 116 |         )
45 | 117 |
46 |     | -        if os.environ.get("AZURE_AI_AGENT_ID"):
   | 118 | +        # If the environment already has AZURE_AI_AGENT_ID, try fetching that agent
   | 119 | +        if os.environ.get("AZURE_AI_AGENT_ID") is not None:
47 | 120 |             try:
48 | 121 |                 agent = await ai_client.agents.get_agent(os.environ["AZURE_AI_AGENT_ID"])
   | 122 | +                logger.info(f"Found agent by ID: {agent.id}")
   | 123 | +                # Update the agent with the latest resources
   | 124 | +                agent = await update_agent(agent, ai_client)
49 | 125 |                 return
50 | 126 |             except Exception as e:
51 |     | -                logger.info("Error with agent ID")
   | 127 | +                logger.warning(f"Could not retrieve agent by AZURE_AI_AGENT_ID = {os.environ['AZURE_AI_AGENT_ID']}, error: {e}")
52 | 128 |
53 |     | -        # Check if a previous agent created by the template exists
   | 129 | +        # Check if an agent with the same name already exists
54 | 130 |         agent_list = await ai_client.agents.list_agents()
55 | 131 |         if agent_list.data:
56 | 132 |             for agent_object in agent_list.data:
57 | 133 |                 if agent_object.name == os.environ["AZURE_AI_AGENT_NAME"]:
58 |     | -                    return
59 |     | -
60 |     | -        # Create a new agent with the required resources
61 |     | -        logger.info("Creating new agent with resources")
62 |     | -        file_names = ["product_info_1.md", "product_info_2.md"]
63 |     | -        for file_name in file_names:
64 |     | -            file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'files', file_name))
65 |     | -            file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
66 |     | -            # Store both file id and the file path using the file name as key.
67 |     | -            files[file_name] = {"id": file.id, "path": file_path}
68 |     | -
69 |     | -        # Create the vector store using the file IDs.
70 |     | -        vector_store = await ai_client.agents.create_vector_store_and_poll(
71 |     | -            file_ids=[info["id"] for info in files.values()],
72 |     | -            name="sample_store"
73 |     | -        )
74 |     | -        logger.info("agent: file store and vector store success")
   | 134 | +                    logger.info(f"Found existing agent named '{agent_object.name}', ID: {agent_object.id}")
   | 135 | +                    os.environ["AZURE_AI_AGENT_ID"] = agent_object.id
   | 136 | +                    # Update the agent with the latest resources
   | 137 | +                    agent = await update_agent(agent_object, ai_client)
   | 138 | +                    return
75 | 139 |
76 |     | -        file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
77 |     | -        toolset = AsyncToolSet()
78 |     | -        toolset.add(file_search_tool)
   | 140 | +        # Create a new agent
   | 141 | +        agent = await create_agent(ai_client)
   | 142 | +        os.environ["AZURE_AI_AGENT_ID"] = agent.id
   | 143 | +        logger.info(f"Created agent, agent ID: {agent.id}")
79 | 144 |
80 |     | -        agent = await ai_client.agents.create_agent(
81 |     | -            model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
82 |     | -            name=os.environ["AZURE_AI_AGENT_NAME"],
83 |     | -            instructions="You are helpful assistant",
84 |     | -            toolset=toolset
85 |     | -        )
86 |     | -        logger.info("Created agent, agent ID: {agent.id}")
87 |     | -
88 | 145 |     except Exception as e:
89 | 146 |         logger.info("Error creating agent: {e}", exc_info=True)
90 | 147 |         raise RuntimeError(f"Failed to create the agent: {e}")
91 |     | -
   | 148 | +
92 | 149 | def on_starting(server):
93 | 150 |     """This code runs once before the workers will start."""
94 |     | -    asyncio.get_event_loop().run_until_complete(list_or_create_agent())
   | 151 | +    asyncio.get_event_loop().run_until_complete(initialize_resources())
95 | 152 |
96 | 153 | max_requests = 1000
97 | 154 | max_requests_jitter = 50
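The new functions serialize the uploaded-file map into the UPLOADED_FILE_MAP environment variable from Gunicorn's on_starting hook, which runs in the master process before the workers are forked, so each worker inherits the value. A minimal sketch of how worker-side code might read it back; the get_uploaded_files helper name is illustrative and not part of this PR:

```python
import json
import os
from typing import Dict


def get_uploaded_files() -> Dict[str, Dict[str, str]]:
    """Read the file map written by on_starting() in the Gunicorn master.

    Returns a mapping of file name -> {"id": <agent file id>, "path": <local path>},
    or an empty dict if the variable was never set or is malformed.
    """
    raw = os.environ.get("UPLOADED_FILE_MAP", "{}")
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        return {}
```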
|
|
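create_agent and update_agent share the same upload / vector-store / toolset setup and differ only in the final create_agent versus update_agent call. A possible follow-up refactor (not part of this change) would hoist that shared setup into one coroutine; a rough sketch, assuming it lives in the same config module so the imports and the FILES_NAMES constant above are in scope:

```python
async def build_file_search_toolset(ai_client: AIProjectClient) -> AsyncToolSet:
    """Upload the sample files, build the vector store, and return the toolset.

    Mirrors the duplicated steps in create_agent()/update_agent() above.
    """
    files: Dict[str, Dict[str, str]] = {}
    for file_name in FILES_NAMES:
        file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'files', file_name))
        file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
        files[file_name] = {"id": file.id, "path": file_path}

    # Expose the uploaded files to the workers, as the PR does today.
    os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)

    vector_store = await ai_client.agents.create_vector_store_and_poll(
        file_ids=[info["id"] for info in files.values()],
        name="sample_store",
    )

    toolset = AsyncToolSet()
    toolset.add(FileSearchTool(vector_store_ids=[vector_store.id]))
    return toolset
```

Both functions would then shrink to building the toolset once and passing it to ai_client.agents.create_agent(...) or ai_client.agents.update_agent(...).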