Skip to content

Commit 28767b3

Browse files
committed
Added simple backend agent api
Includes in-memory conversation history
1 parent 4db9cbd commit 28767b3

File tree

4 files changed

+224
-1
lines changed

4 files changed

+224
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
AIRequest model
3+
"""
4+
from pydantic import BaseModel
5+
6+
class AIRequest(BaseModel):
7+
"""
8+
AIRequest model encapsulates the session_id
9+
and incoming user prompt for the AI agent
10+
to respond to.
11+
"""
12+
session_id: str
13+
prompt: str
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""
2+
The CosmicWorksAIAgent class encapsulates a LangChain
3+
agent that can be used to answer questions about Cosmic Works
4+
products, customers, and sales.
5+
"""
6+
import os
7+
import json
8+
from typing import List
9+
import pymongo
10+
from dotenv import load_dotenv
11+
from langchain.chat_models import AzureChatOpenAI
12+
from langchain.embeddings import AzureOpenAIEmbeddings
13+
from langchain.vectorstores.azure_cosmos_db import AzureCosmosDBVectorSearch
14+
from langchain.schema.document import Document
15+
from langchain.agents import Tool
16+
from langchain.agents.agent_toolkits import create_conversational_retrieval_agent
17+
from langchain.tools import StructuredTool
18+
from langchain_core.messages import SystemMessage
19+
20+
load_dotenv("../../.env")
21+
DB_CONNECTION_STRING = os.environ.get("DB_CONNECTION_STRING")
22+
db = pymongo.MongoClient(DB_CONNECTION_STRING).cosmic_works
23+
24+
class CosmicWorksAIAgent:
25+
"""
26+
The CosmicWorksAIAgent class creates Cosmo, an AI agent
27+
that can be used to answer questions about Cosmic Works
28+
products, customers, and sales.
29+
"""
30+
def __init__(self, session_id: str):
31+
llm = AzureChatOpenAI(
32+
temperature = 0,
33+
openai_api_version = "2023-09-01-preview",
34+
azure_endpoint = os.environ.get("AOAI_ENDPOINT"),
35+
openai_api_key = os.environ.get("AOAI_KEY"),
36+
azure_deployment = "completions"
37+
)
38+
self.embedding_model = AzureOpenAIEmbeddings(
39+
openai_api_version = "2023-09-01-preview",
40+
azure_endpoint = os.environ.get("AOAI_ENDPOINT"),
41+
openai_api_key = os.environ.get("AOAI_KEY"),
42+
azure_deployment = "embeddings",
43+
chunk_size=10
44+
)
45+
system_message = SystemMessage(
46+
content = """
47+
You are a helpful, fun and friendly sales assistant for Cosmic Works,
48+
a bicycle and bicycle accessories store.
49+
50+
Your name is Cosmo.
51+
52+
You are designed to answer questions about the products that Cosmic Works sells,
53+
the customers that buy them, and the sales orders that are placed by customers.
54+
55+
If you don't know the answer to a question, respond with "I don't know."
56+
"""
57+
)
58+
self.agent_executor = create_conversational_retrieval_agent(
59+
llm,
60+
self.__create_agent_tools(),
61+
system_message = system_message,
62+
memory_key=session_id,
63+
verbose=True
64+
)
65+
66+
def run(self, prompt: str) -> str:
67+
"""
68+
Run the AI agent.
69+
"""
70+
result = self.agent_executor({"input": prompt})
71+
return result["output"]
72+
73+
def __create_cosmic_works_vector_store_retriever(
74+
self,
75+
collection_name: str,
76+
top_k: int = 3
77+
):
78+
"""
79+
Returns a vector store retriever for the given collection.
80+
"""
81+
vector_store = AzureCosmosDBVectorSearch.from_connection_string(
82+
connection_string = DB_CONNECTION_STRING,
83+
namespace = f"cosmic_works.{collection_name}",
84+
embedding = self.embedding_model,
85+
index_name = "VectorSearchIndex",
86+
embedding_key = "contentVector",
87+
text_key = "_id"
88+
)
89+
return vector_store.as_retriever(search_kwargs={"k": top_k})
90+
91+
def __create_agent_tools(self) -> List[Tool]:
92+
"""
93+
Returns a list of agent tools.
94+
"""
95+
products_retriever = self.__create_cosmic_works_vector_store_retriever("products")
96+
customers_retriever = self.__create_cosmic_works_vector_store_retriever("customers")
97+
sales_retriever = self.__create_cosmic_works_vector_store_retriever("sales")
98+
99+
# create a chain on the retriever to format the documents as JSON
100+
products_retriever_chain = products_retriever | format_docs
101+
customers_retriever_chain = customers_retriever | format_docs
102+
sales_retriever_chain = sales_retriever | format_docs
103+
104+
tools = [
105+
Tool(
106+
name = "vector_search_products",
107+
func = products_retriever_chain.invoke,
108+
description = """
109+
Searches Cosmic Works product information for similar products based
110+
on the question. Returns the product information in JSON format.
111+
"""
112+
),
113+
Tool(
114+
name = "vector_search_customers",
115+
func = customers_retriever_chain.invoke,
116+
description = """
117+
Searches Cosmic Works customer information and retrieves similar
118+
customers based on the question. Returns the customer information
119+
in JSON format.
120+
"""
121+
),
122+
Tool(
123+
name = "vector_search_sales",
124+
func = sales_retriever_chain.invoke,
125+
description = """
126+
Searches Cosmic Works customer sales information and retrieves sales order
127+
details based on the question. Returns the sales order information in JSON format.
128+
"""
129+
),
130+
StructuredTool.from_function(get_product_by_id),
131+
StructuredTool.from_function(get_product_by_sku),
132+
StructuredTool.from_function(get_sales_by_id)
133+
]
134+
return tools
135+
136+
def format_docs(docs:List[Document]) -> str:
137+
"""
138+
Prepares the product list for the system prompt.
139+
"""
140+
str_docs = []
141+
for doc in docs:
142+
# Build the product document without the contentVector
143+
doc_dict = {"_id": doc.page_content}
144+
doc_dict.update(doc.metadata)
145+
if "contentVector" in doc_dict:
146+
del doc_dict["contentVector"]
147+
str_docs.append(json.dumps(doc_dict, default=str))
148+
# Return a single string containing each product JSON representation
149+
# separated by two newlines
150+
return "\n\n".join(str_docs)
151+
152+
def get_product_by_id(product_id: str) -> str:
153+
"""
154+
Retrieves a product by its ID.
155+
"""
156+
doc = db.products.find_one({"_id": product_id})
157+
if "contentVector" in doc:
158+
del doc["contentVector"]
159+
return json.dumps(doc)
160+
161+
def get_product_by_sku(sku: str) -> str:
162+
"""
163+
Retrieves a product by its sku.
164+
"""
165+
doc = db.products.find_one({"sku": sku})
166+
if "contentVector" in doc:
167+
del doc["contentVector"]
168+
return json.dumps(doc, default=str)
169+
170+
def get_sales_by_id(sales_id: str) -> str:
171+
"""
172+
Retrieves a sales order by its ID.
173+
"""
174+
doc = db.sales.find_one({"_id": sales_id})
175+
if "contentVector" in doc:
176+
del doc["contentVector"]
177+
return json.dumps(doc, default=str)
178+

Labs/app/backend/main.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""
2+
API entrypoint for backend API.
3+
"""
4+
import uvicorn
5+
from fastapi import FastAPI
6+
from api_models.ai_request import AIRequest
7+
from cosmic_works.cosmic_works_ai_agent import CosmicWorksAIAgent
8+
9+
app = FastAPI()
10+
# Agent pool keyed by session_id to retain memories/history in-memory.
11+
agent_pool = {}
12+
13+
@app.get("/")
14+
def root():
15+
"""
16+
Health probe endpoint.
17+
"""
18+
return {"status": "ready"}
19+
20+
@app.post("/ai")
21+
def run_cosmic_works_ai_agent(request: AIRequest):
22+
"""
23+
Run the Cosmic Works AI agent.
24+
"""
25+
if request.session_id not in agent_pool:
26+
agent_pool[request.session_id] = CosmicWorksAIAgent(request.session_id)
27+
return { "message": agent_pool[request.session_id].run(request.prompt) }
28+
29+
if __name__ == "__main__":
30+
uvicorn.run("main:app", host="0.0.0.0", port=4242, reload=True)

Labs/requirements.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ pydantic==2.5.2
55
openai==1.6.0
66
tenacity==8.2.3
77
langchain==0.0.352
8-
tiktoken==0.5.2
8+
tiktoken==0.5.2
9+
fastapi==0.108.0
10+
uvicorn==0.25.0

0 commit comments

Comments
 (0)