Skip to content

Commit e11f086

Browse files
fix for citations (#24)
* fix for citations
* fix citations with files env variable
* address comments
1 parent db4233b commit e11f086

File tree

3 files changed

+121
-56
lines changed

3 files changed

+121
-56
lines changed

src/api/main.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
import os
77
import sys
8+
import json
89
from typing import Dict
910

1011
from azure.ai.projects.aio import AIProjectClient
@@ -56,10 +57,7 @@
5657

5758
@contextlib.asynccontextmanager
5859
async def lifespan(app: fastapi.FastAPI):
59-
files: Dict[str, Dict[str, str]] = {} # File name -> {"id": file_id, "path": file_path}
60-
vector_store = None
6160
agent = None
62-
create_new_agent = True
6361

6462
try:
6563
if not os.getenv("RUNNING_IN_PRODUCTION"):
@@ -89,33 +87,35 @@ async def lifespan(app: fastapi.FastAPI):
8987
if os.environ.get("AZURE_AI_AGENT_ID") is not None:
9088
try:
9189
agent = await ai_client.agents.get_agent(os.environ["AZURE_AI_AGENT_ID"])
92-
create_new_agent = False
9390
logger.info("Agent already exists, skipping creation")
9491
logger.info(f"Fetched agent, agent ID: {agent.id}")
9592
logger.info(f"Fetched agent, model name: {agent.model}")
9693
except Exception as e:
9794
logger.error(f"Error fetching agent: {e}", exc_info=True)
98-
create_new_agent = True
99-
if create_new_agent:
100-
# Check if a previous agent created by the template exists
95+
96+
if not agent:
97+
# Fallback to searching by name
98+
agent_name = os.environ["AZURE_AI_AGENT_NAME"]
10199
agent_list = await ai_client.agents.list_agents()
102100
if agent_list.data:
103101
for agent_object in agent_list.data:
104-
if agent_object.name == os.environ["AZURE_AI_AGENT_NAME"]:
102+
if agent_object.name == agent_name:
105103
agent = agent_object
106-
if agent == None:
107-
raise Exception("Agent not found")
104+
logger.info(f"Found agent by name '{agent_name}', ID={agent_object.id}")
105+
break
108106

109-
except Exception as e:
110-
logger.error(f"Error creating agent: {e}", exc_info=True)
111-
raise RuntimeError(f"Failed to create the agent: {e}")
107+
if not agent:
108+
raise RuntimeError("No agent found. Ensure qunicorn.py created one or set AZURE_AI_AGENT_ID.")
112109

113-
app.state.ai_client = ai_client
114-
app.state.agent = agent
115-
app.state.files = files
110+
app.state.ai_client = ai_client
111+
app.state.agent = agent
116112

117-
try:
118113
yield
114+
115+
except Exception as e:
116+
logger.error(f"Error during startup: {e}", exc_info=True)
117+
raise RuntimeError(f"Error during startup: {e}")
118+
119119
finally:
120120
try:
121121
await ai_client.close()

src/api/routes.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,15 @@ async def fetch_document(request: Request):
197197
if not file_name:
198198
raise HTTPException(status_code=400, detail="file_name is required")
199199

200-
files = getattr(request.app.state, "files", {})
200+
# Reconstruct the file dictionary from the env variable:
201+
files_env = os.environ['UPLOADED_FILE_MAP']
202+
try:
203+
files = json.loads(files_env)
204+
logger.info("Successfully parsed UPLOADED_FILE_MAP from environment variable.")
205+
except json.JSONDecodeError:
206+
files = {}
207+
logger.warning("Failed to parse UPLOADED_FILE_MAP from environment variable.", exc_info=True)
208+
201209
logger.info(f"File requested: {file_name}. Current file keys: {list(files.keys())}")
202210

203211
if file_name not in files:

src/gunicorn.conf.py

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import multiprocessing
22
import os
33
import sys
4+
import json
45
from typing import Dict
56
import asyncio
67
import logging
78
from azure.ai.projects.aio import AIProjectClient
8-
from azure.ai.projects.models import FilePurpose, FileSearchTool, AsyncToolSet
9+
from azure.ai.projects.models import FilePurpose, FileSearchTool, AsyncToolSet, Agent
910
from azure.identity import DefaultAzureCredential
1011

1112
from dotenv import load_dotenv
@@ -32,66 +33,122 @@
3233
file_handler.setFormatter(file_formatter)
3334
logger.addHandler(file_handler)
3435

35-
async def list_or_create_agent():
36-
files: Dict[str, Dict[str, str]] = {} # File name -> {"id": file_id, "path": file_path}
37-
vector_store = None
38-
agent = None
36+
FILES_NAMES = ["product_info_1.md", "product_info_2.md"]
3937

38+
async def _build_file_search_toolset(ai_client: AIProjectClient) -> AsyncToolSet:
    """Upload the sample files, index them, and return a file-search toolset.

    Shared by ``create_agent`` and ``update_agent`` so both attach exactly
    the same resources to the agent.

    Side effect: stores the JSON-serialized ``{file_name: {"id", "path"}}``
    map in ``os.environ["UPLOADED_FILE_MAP"]``.

    Args:
        ai_client: Client used to upload files and create the vector store.

    Returns:
        An ``AsyncToolSet`` holding a ``FileSearchTool`` bound to the newly
        created vector store.
    """
    # File name -> {"id": file_id, "path": file_path}
    files: Dict[str, Dict[str, str]] = {}
    files_dir = os.path.join(os.path.dirname(__file__), 'files')

    # Upload files for file search
    for file_name in FILES_NAMES:
        file_path = os.path.abspath(os.path.join(files_dir, file_name))
        uploaded = await ai_client.agents.upload_file_and_poll(
            file_path=file_path, purpose=FilePurpose.AGENTS
        )
        # Store both file id and the file path using the file name as key.
        files[file_name] = {"id": uploaded.id, "path": file_path}

    # Serialize and store files information in the environment variable
    # (so workers see it)
    os.environ["UPLOADED_FILE_MAP"] = json.dumps(files)
    logger.info(f"Set env UPLOADED_FILE_MAP = {os.environ['UPLOADED_FILE_MAP']}")

    # Create the vector store using the file IDs.
    vector_store = await ai_client.agents.create_vector_store_and_poll(
        file_ids=[info["id"] for info in files.values()],
        name="sample_store"
    )
    logger.info("agent: file store and vector store success")

    toolset = AsyncToolSet()
    toolset.add(FileSearchTool(vector_store_ids=[vector_store.id]))
    return toolset


async def create_agent(ai_client: AIProjectClient) -> Agent:
    """Create a new agent equipped with the file-search toolset.

    The model and agent name are read from the ``AZURE_AI_AGENT_DEPLOYMENT_NAME``
    and ``AZURE_AI_AGENT_NAME`` environment variables.

    Args:
        ai_client: Client used to talk to the agents service.

    Returns:
        The newly created agent.

    Raises:
        KeyError: If a required environment variable is not set.
    """
    # Create a new agent with the required resources
    logger.info("Creating new agent with resources")
    toolset = await _build_file_search_toolset(ai_client)

    return await ai_client.agents.create_agent(
        model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
        name=os.environ["AZURE_AI_AGENT_NAME"],
        instructions="You are helpful assistant",
        toolset=toolset
    )


async def update_agent(agent: Agent, ai_client: AIProjectClient) -> Agent:
    """Re-upload the sample files and point an existing agent at them.

    The files are uploaded again and a fresh vector store is created on every
    call; the agent's model, name, instructions and toolset are then
    overwritten with the current configuration.

    Args:
        agent: The existing agent to update; only its ``id`` is used.
        ai_client: Client used to talk to the agents service.

    Returns:
        The agent object returned by the update call.

    Raises:
        KeyError: If a required environment variable is not set.
    """
    logger.info("Updating agent with resources")
    toolset = await _build_file_search_toolset(ai_client)

    return await ai_client.agents.update_agent(
        assistant_id=agent.id,
        model=os.environ["AZURE_AI_AGENT_DEPLOYMENT_NAME"],
        name=os.environ["AZURE_AI_AGENT_NAME"],
        instructions="You are helpful assistant",
        toolset=toolset
    )
109+
110+
111+
async def initialize_resources():
    """Make sure an agent exists and is wired to the latest file resources.

    Resolution order:
      1. If ``AZURE_AI_AGENT_ID`` is set, fetch that agent and update it.
      2. Otherwise (or if step 1 fails), look for an agent whose name equals
         ``AZURE_AI_AGENT_NAME``; if found, update it.
      3. Otherwise create a new agent.

    In cases 2 and 3 the resolved agent ID is written back to
    ``os.environ["AZURE_AI_AGENT_ID"]``.

    Raises:
        RuntimeError: If any step fails; the original exception is chained.
    """
    ai_client = None
    try:
        ai_client = AIProjectClient.from_connection_string(
            credential=DefaultAzureCredential(exclude_shared_token_cache_credential=True),
            conn_str=os.environ["AZURE_AIPROJECT_CONNECTION_STRING"],
        )

        # If the environment already has AZURE_AI_AGENT_ID, try fetching that agent
        if os.environ.get("AZURE_AI_AGENT_ID") is not None:
            try:
                agent = await ai_client.agents.get_agent(os.environ["AZURE_AI_AGENT_ID"])
                logger.info(f"Found agent by ID: {agent.id}")
                # Update the agent with the latest resources
                await update_agent(agent, ai_client)
                return
            except Exception as e:
                # Best effort: fall through to the lookup by name below.
                logger.warning(f"Could not retrieve agent by AZURE_AI_AGENT_ID = {os.environ['AZURE_AI_AGENT_ID']}, error: {e}")

        # Check if an agent with the same name already exists
        agent_name = os.environ["AZURE_AI_AGENT_NAME"]
        agent_list = await ai_client.agents.list_agents()
        if agent_list.data:
            for agent_object in agent_list.data:
                if agent_object.name == agent_name:
                    logger.info(f"Found existing agent named '{agent_object.name}', ID: {agent_object.id}")
                    os.environ["AZURE_AI_AGENT_ID"] = agent_object.id
                    # Update the agent with the latest resources
                    await update_agent(agent_object, ai_client)
                    return

        # Create a new agent
        agent = await create_agent(ai_client)
        os.environ["AZURE_AI_AGENT_ID"] = agent.id
        logger.info(f"Created agent, agent ID: {agent.id}")

    except Exception as e:
        # The message must be an f-string, otherwise "{e}" is logged literally.
        logger.error(f"Error creating agent: {e}", exc_info=True)
        raise RuntimeError(f"Failed to create the agent: {e}") from e
    finally:
        # The client is local to this function; release its connections.
        if ai_client is not None:
            try:
                await ai_client.close()
            except Exception as close_error:
                logger.warning(f"Error closing AI client: {close_error}", exc_info=True)
91-
148+
92149
def on_starting(server):
    """This code runs once before the workers will start."""
    # asyncio.run creates and closes its own event loop; calling
    # asyncio.get_event_loop() with no running loop is deprecated.
    asyncio.run(initialize_resources())
95152

96153
max_requests = 1000
97154
max_requests_jitter = 50

0 commit comments

Comments
 (0)