1717from langgraph .graph .message import add_messages
1818from langgraph .prebuilt import create_react_agent
1919from chromadb .config import DEFAULT_TENANT , DEFAULT_DATABASE , Settings
20-
20+ from langchain_chroma import Chroma as ChromaClient
2121
2222from .extensions import postgresdb
2323from .config import Config
3131from chromadb .utils .embedding_functions import OpenAIEmbeddingFunction
3232
3333
34- async def get_chroma_collection ( api_key ):
35- chroma_client = await chromadb .AsyncHttpClient (
34+ async def get_chroma_client ( ):
35+ chroma_client = chromadb .HttpClient (
3636 host = Config .CHROMA_HOST ,
3737 port = Config .CHROMA_PORT ,
3838 ssl = False ,
@@ -41,30 +41,43 @@ async def get_chroma_collection(api_key):
4141 tenant = DEFAULT_TENANT ,
4242 database = DEFAULT_DATABASE ,
4343 )
44+ return chroma_client
45+
46+
47+ def get_embedding_function (api_key ):
48+ return OpenAIEmbeddingFunction (
49+ api_key = api_key ,
50+ model_name = "text-embedding-3-large" ,
51+ )
52+
4453
45- collection = await chroma_client .get_or_create_collection (
46- name = "chats" ,
47- embedding_function = OpenAIEmbeddingFunction (
48- api_key = api_key ,
49- model_name = "text-embedding-3-large" ,
50- ),
54+ def get_chroma_vectorstore (api_key ):
55+ chroma_client = get_chroma_client ()
56+ vectorstore = ChromaClient (
57+ client = chroma_client ,
58+ collection_name = "chats" ,
59+ create_collection_if_not_exists = True ,
60+ embedding_function = get_embedding_function (api_key ),
5161 )
52- return collection
62+ return vectorstore
5363
5464
55- async def add_to_chroma_collection (api_key , session_id , new_messages ):
56- collection = await get_chroma_collection (api_key )
57- res = await collection .add (
65+ def add_to_chroma_collection (
66+ api_key , session_id , new_messages : dict [str , str ]
67+ ) -> list :
68+ vectorstore = get_chroma_vectorstore (api_key )
69+ res : list = vectorstore .add_documents (
5870 documents = [
5971 {"content" : content , "metadata" : {"session_id" : session_id , "role" : role }}
6072 for role , content in new_messages .items ()
6173 ]
6274 )
75+ return res
6376
6477
6578async def get_retriever_tool (api_key ):
66- collection = await get_chroma_collection (api_key )
67- retriever = collection .as_retriever ()
79+ vectorstore = get_chroma_vectorstore (api_key )
80+ retriever = vectorstore .as_retriever ()
6881 retriever_tool = create_retriever_tool (
6982 retriever ,
7083 name = "chat_rag" ,
0 commit comments