55from typing import Annotated
66
77import typer
8+ from azure .cosmos import CosmosClient , PartitionKey
89from dotenv import load_dotenv
910from langchain_community .document_loaders .csv_loader import CSVLoader
1011from langchain_community .vectorstores import FAISS , VectorStore
12+ from langchain_community .vectorstores .azure_cosmos_db_no_sql import AzureCosmosDBNoSqlVectorSearch
1113from langchain_community .vectorstores .azuresearch import AzureSearch
1214from langchain_core .documents import Document
1315from langchain_openai import AzureOpenAIEmbeddings
1618
1719
1820class VectorStoreType (str , Enum ):
19- AzureAISearch = "azureaisearch"
2021 Faiss = "faiss"
22+ AzureAISearch = "azureaisearch"
23+ AzureCosmosDbNoSql = "azurecosmosdbnosql"
2124
2225
2326def get_log_level (debug : bool ) -> int :
@@ -41,10 +44,39 @@ def get_embeddings():
4144 )
4245
4346
47+ def get_cosmos_client ():
48+ return CosmosClient (
49+ url = getenv ("AZURE_COSMOSDB_ENDPOINT" ),
50+ credential = getenv ("AZURE_COSMOSDB_KEY" ),
51+ )
52+
53+
4454def get_local_vector_store_path (identifier : str ):
4555 return f"./artifacts/vectorstore/{ identifier } "
4656
4757
58+ def get_vector_embedding_policy ():
59+ return {
60+ "vectorEmbeddings" : [
61+ {
62+ "path" : "/embedding" ,
63+ "dataType" : "float32" ,
64+ "distanceFunction" : "cosine" ,
65+ "dimensions" : 1536 ,
66+ }
67+ ]
68+ }
69+
70+
71+ def get_indexing_policy ():
72+ return {
73+ "indexingMode" : "consistent" ,
74+ "includedPaths" : [{"path" : "/*" }],
75+ "excludedPaths" : [{"path" : '/"_etag"/?' }],
76+ "vectorIndexes" : [{"path" : "/embedding" , "type" : "quantizedFlat" }],
77+ }
78+
79+
4880def create_azure_search (index_name : str ) -> AzureSearch :
4981 return AzureSearch (
5082 azure_search_endpoint = getenv ("AZURE_AI_SEARCH_ENDPOINT" ),
@@ -59,40 +91,69 @@ def get_vector_store(
5991 vector_store_type : VectorStoreType ,
6092 identifier : str ,
6193) -> VectorStore :
62- if vector_store_type == VectorStoreType .AzureAISearch :
63- logging .info ("Creating Azure AI Search vector store" )
64- return create_azure_search (identifier )
65- elif vector_store_type == VectorStoreType .Faiss :
94+ if vector_store_type == VectorStoreType .Faiss :
6695 logging .info ("Creating Faiss vector store" )
6796 return FAISS .load_local (
6897 folder_path = get_local_vector_store_path (identifier ),
6998 embeddings = get_embeddings (),
7099 allow_dangerous_deserialization = True ,
71100 )
101+ elif vector_store_type == VectorStoreType .AzureAISearch :
102+ logging .info ("Creating Azure AI Search vector store" )
103+ return create_azure_search (identifier )
104+ elif vector_store_type == VectorStoreType .AzureCosmosDbNoSql :
105+ logging .info ("Creating Azure Cosmos DB NoSQL vector store" )
106+ cosmos_database_name = "langchain_python_db"
107+ return AzureCosmosDBNoSqlVectorSearch (
108+ cosmos_client = get_cosmos_client (),
109+ embedding = get_embeddings (),
110+ vector_embedding_policy = get_vector_embedding_policy (),
111+ indexing_policy = get_indexing_policy (),
112+ cosmos_container_properties = {"partition_key" : PartitionKey (path = "/id" )},
113+ cosmos_database_properties = {"id" : cosmos_database_name },
114+ database_name = cosmos_database_name ,
115+ container_name = "langchain_python_container" ,
116+ )
72117
73118
74- def _create_vector_store (
119+ def create_vector_store (
75120 vector_store_type : VectorStoreType ,
76121 identifier : str ,
77122 documents : list [Document ],
78123) -> VectorStore :
79- if vector_store_type == VectorStoreType .AzureAISearch :
124+ if vector_store_type == VectorStoreType .Faiss :
125+ logging .info ("Creating Faiss vector store" )
126+ vector_store : FAISS = FAISS .from_documents (
127+ documents = documents ,
128+ embedding = get_embeddings (),
129+ )
130+ vector_store .save_local (folder_path = get_local_vector_store_path (identifier ))
131+ return
132+ elif vector_store_type == VectorStoreType .AzureAISearch :
80133 logging .info ("Creating Azure AI Search vector store" )
81134 vector_store = create_azure_search (identifier )
82135 vector_store .add_documents (documents = documents )
83136 return
84- elif vector_store_type == VectorStoreType .Faiss :
85- logging .info ("Creating Faiss vector store" )
86- vector_store : FAISS = FAISS .from_documents (
137+ elif vector_store_type == VectorStoreType .AzureCosmosDbNoSql :
138+ logging .info ("Creating Azure Cosmos DB NoSQL vector store" )
139+ cosmos_database_name = "langchain_python_db"
140+
141+ AzureCosmosDBNoSqlVectorSearch .from_documents (
87142 documents = documents ,
88143 embedding = get_embeddings (),
144+ cosmos_client = get_cosmos_client (),
145+ database_name = cosmos_database_name ,
146+ container_name = "langchain_python_container" ,
147+ vector_embedding_policy = get_vector_embedding_policy (),
148+ indexing_policy = get_indexing_policy (),
149+ cosmos_database_properties = {"id" : cosmos_database_name },
150+ cosmos_container_properties = {"partition_key" : PartitionKey (path = "/id" )},
89151 )
90- vector_store .save_local (folder_path = get_local_vector_store_path (identifier ))
91152 return
92153
93154
94155@app .command ()
95- def create_vector_store (
156+ def create (
96157 input_csv_file_path : Annotated [str , typer .Option (help = "Path to the input CSV file" )] = "./data/contoso_rules.csv" ,
97158 identifier = "contoso_rules" ,
98159 vector_store_type : Annotated [VectorStoreType , typer .Option (case_sensitive = False )] = VectorStoreType .Faiss ,
@@ -108,7 +169,7 @@ def create_vector_store(
108169 return
109170
110171 # Create vector store
111- _create_vector_store (
172+ create_vector_store (
112173 vector_store_type = vector_store_type ,
113174 identifier = identifier ,
114175 documents = documents ,
@@ -120,6 +181,7 @@ def search(
120181 identifier = "contoso_rules" ,
121182 vector_store_type : Annotated [VectorStoreType , typer .Option (case_sensitive = False )] = VectorStoreType .Faiss ,
122183 query : Annotated [str , typer .Option (help = "Query to search" )] = "社内の機密情報は外部に漏らさないでください" ,
184+ k : Annotated [int , typer .Option (help = "Number of documents to retrieve" )] = 5 ,
123185 debug : Annotated [bool , typer .Option (help = "Enable debug mode" )] = False ,
124186):
125187 setup_logging (debug )
@@ -131,12 +193,12 @@ def search(
131193 )
132194
133195 # Search
134- result = vector_store .similarity_search_with_relevance_scores (
196+ got_documents = vector_store .similarity_search (
135197 query = query ,
136- k = 5 ,
137- score_threshold = 0.5 ,
198+ k = k ,
138199 )
139- pprint (result )
200+ for document in got_documents :
201+ pprint (document .page_content )
140202
141203
142204if __name__ == "__main__" :
0 commit comments