Skip to content

Commit db53e99

Browse files
Merge pull request #2 from sophia-ramsey/jhakulin/updates-1
Jhakulin/updates 1
2 parents efd264a + c330a98 commit db53e99

File tree

7 files changed

+387
-259
lines changed

7 files changed

+387
-259
lines changed

src/Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ RUN pip install --no-cache-dir --upgrade -r requirements.txt
1010

1111
EXPOSE 50505
1212

13-
CMD ["gunicorn", "api.main:create_app()"]
13+
CMD ["gunicorn", "api.main:create_app"]

src/api/config_helper.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

src/api/main.py

Lines changed: 108 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,113 +1,133 @@
1+
# Copyright (c) Microsoft. All rights reserved.
2+
# Licensed under the MIT license. See LICENSE.md file in the project root for full license information.
3+
14
import contextlib
25
import logging
36
import os
7+
import sys
8+
from typing import Dict
49

5-
import fastapi
6-
from azure.ai.projects.aio import AIProjectClient
7-
from dotenv import load_dotenv
8-
from fastapi.staticfiles import StaticFiles
9-
10-
11-
from typing import AsyncGenerator, Dict, Optional, Tuple
12-
13-
14-
import os
1510
from azure.ai.projects.aio import AIProjectClient
11+
from azure.ai.projects.models import FilePurpose, FileSearchTool, AsyncToolSet
1612
from azure.identity import DefaultAzureCredential
1713

18-
from azure.ai.projects.models import (
19-
MessageDeltaChunk,
20-
ThreadMessage,
21-
FileSearchTool,
22-
AsyncToolSet,
23-
FilePurpose,
24-
ThreadMessage,
25-
StreamEventData,
26-
AsyncAgentEventHandler,
27-
Agent,
28-
VectorStore
29-
)
30-
31-
from .shared import bp
32-
33-
14+
import fastapi
15+
from fastapi.staticfiles import StaticFiles
16+
from fastapi import Request
17+
from fastapi.responses import JSONResponse
18+
from dotenv import load_dotenv
3419

20+
# Create a central logger for the application
3521
logger = logging.getLogger("azureaiapp")
3622
logger.setLevel(logging.INFO)
3723

24+
# Configure the stream handler (stdout)
25+
stream_handler = logging.StreamHandler(sys.stdout)
26+
stream_handler.setLevel(logging.INFO)
27+
stream_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
28+
stream_handler.setFormatter(stream_formatter)
29+
logger.addHandler(stream_handler)
3830

39-
@contextlib.asynccontextmanager
40-
async def lifespan(app: fastapi.FastAPI):
41-
42-
43-
44-
ai_client = AIProjectClient.from_connection_string(
45-
credential=DefaultAzureCredential(exclude_shared_token_cache_credential=True),
46-
conn_str=os.environ["AZURE_AIPROJECT_CONNECTION_STRING"],
47-
)
48-
49-
# TODO: add more files are not supported for citation at the moment
50-
file_names = ["product_info_1.md", "product_info_2.md"]
51-
files: Dict[str, str] = {}
52-
for file_name in file_names:
53-
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
54-
print(f"Uploading file {file_path}")
55-
file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
56-
files.update({file.id: file_path})
57-
58-
vector_store = await ai_client.agents.create_vector_store_and_poll(file_ids=list(files.keys()), name="sample_store")
59-
60-
file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
61-
62-
tool_set = AsyncToolSet()
63-
tool_set.add(file_search_tool)
64-
65-
print(f"ToolResource: {tool_set.resources}")
66-
67-
agent = await ai_client.agents.create_agent(
68-
model="gpt-4o-mini", name="my-assistant", instructions="You are helpful assistant", tools = tool_set.definitions, tool_resources=tool_set.resources
69-
)
70-
71-
print(f"Created agent, agent ID: {agent.id}")
31+
# Configure the file handler
32+
log_file_path = os.getenv("APP_LOG_FILE", "app.log")
33+
file_handler = logging.FileHandler(log_file_path)
34+
file_handler.setLevel(logging.INFO)
35+
file_formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s")
36+
file_handler.setFormatter(file_formatter)
37+
logger.addHandler(file_handler)
7238

73-
bp.ai_client = ai_client
74-
bp.agent = agent
75-
bp.vector_store = vector_store
76-
bp.files = files
77-
78-
yield
7939

80-
await stop_server()
40+
@contextlib.asynccontextmanager
41+
async def lifespan(app: fastapi.FastAPI):
42+
files: Dict[str, Dict[str, str]] = {} # File name -> {"id": file_id, "path": file_path}
43+
vector_store = None
44+
agent = None
45+
46+
try:
47+
if not os.getenv("RUNNING_IN_PRODUCTION"):
48+
logger.info("Loading .env file")
49+
load_dotenv(override=True)
50+
51+
ai_client = AIProjectClient.from_connection_string(
52+
credential=DefaultAzureCredential(exclude_shared_token_cache_credential=True),
53+
conn_str=os.environ["AZURE_AIPROJECT_CONNECTION_STRING"],
54+
)
55+
logger.info("Created AIProjectClient")
56+
57+
file_names = ["product_info_1.md", "product_info_2.md"]
58+
for file_name in file_names:
59+
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'files', file_name))
60+
file = await ai_client.agents.upload_file_and_poll(file_path=file_path, purpose=FilePurpose.AGENTS)
61+
logger.info(f"Uploaded file {file_path}, file ID: {file.id}")
62+
# Store both file id and the file path using the file name as key.
63+
files[file_name] = {"id": file.id, "path": file_path}
8164

82-
83-
async def stop_server():
84-
for file_id in bp.files.keys():
85-
await bp.ai_client.agents.delete_file(file_id)
86-
print(f"Deleted file {file_id}")
87-
88-
await bp.ai_client.agents.delete_vector_store(bp.vector_store.id)
89-
print(f"Deleted vector store {bp.vector_store.id}")
90-
91-
await bp.ai_client.agents.delete_agent(bp.agent.id)
92-
93-
print(f"Deleted agent {bp.agent.id}")
94-
95-
await bp.ai_client.close()
96-
print("Closed AIProjectClient")
97-
65+
# Create the vector store using the file IDs.
66+
vector_store = await ai_client.agents.create_vector_store_and_poll(
67+
file_ids=[info["id"] for info in files.values()],
68+
name="sample_store"
69+
)
70+
71+
file_search_tool = FileSearchTool(vector_store_ids=[vector_store.id])
72+
toolset = AsyncToolSet()
73+
toolset.add(file_search_tool)
74+
75+
agent = await ai_client.agents.create_agent(
76+
model="gpt-4o-mini",
77+
name="my-assistant",
78+
instructions="You are helpful assistant",
79+
toolset=toolset
80+
)
81+
logger.info(f"Created agent, agent ID: {agent.id}")
82+
except Exception as e:
83+
logger.error(f"Error creating agent: {e}", exc_info=True)
84+
raise RuntimeError(f"Failed to create the agent: {e}")
85+
86+
app.state.ai_client = ai_client
87+
app.state.agent = agent
88+
app.state.files = files
89+
90+
try:
91+
yield
92+
finally:
93+
# Cleanup on shutdown.
94+
try:
95+
for info in files.values():
96+
await ai_client.agents.delete_file(info["id"])
97+
logger.info(f"Deleted file {info['id']}")
98+
99+
if vector_store:
100+
await ai_client.agents.delete_vector_store(vector_store.id)
101+
logger.info(f"Deleted vector store {vector_store.id}")
102+
103+
if agent:
104+
await ai_client.agents.delete_agent(agent.id)
105+
logger.info(f"Deleted agent {agent.id}")
106+
except Exception as e:
107+
logger.error(f"Error during cleanup: {e}", exc_info=True)
108+
109+
try:
110+
await ai_client.close()
111+
logger.info("Closed AIProjectClient")
112+
except Exception as e:
113+
logger.error("Error closing AIProjectClient", exc_info=True)
98114

99115

100116
def create_app():
101-
if not os.getenv("RUNNING_IN_PRODUCTION"):
102-
logger.info("Loading .env file")
103-
load_dotenv(override=True)
104-
105117
directory = os.path.join(os.path.dirname(__file__), "static")
106118
app = fastapi.FastAPI(lifespan=lifespan)
107119
app.mount("/static", StaticFiles(directory=directory), name="static")
108120

109-
from . import routes # noqa
110-
121+
from . import routes # Import routes
111122
app.include_router(routes.router)
112123

124+
# Global exception handler for any unhandled exceptions
125+
@app.exception_handler(Exception)
126+
async def global_exception_handler(request: Request, exc: Exception):
127+
logger.error("Unhandled exception occurred", exc_info=exc)
128+
return JSONResponse(
129+
status_code=500,
130+
content={"detail": "Internal server error"}
131+
)
132+
113133
return app

0 commit comments

Comments
 (0)