11from dotenv import load_dotenv
2-
2+ from psycopg_pool import ConnectionPool
33import cocoindex
44import os
55
66
@cocoindex.transform_flow()
def text_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[list[float]]:
    """Embed *text* with a SentenceTransformer model.

    Shared between the indexing flow and the query path so that both
    sides use exactly the same embedding logic.
    """
    embedder = cocoindex.functions.SentenceTransformerEmbed(
        model="sentence-transformers/all-MiniLM-L6-v2"
    )
    return text.transform(embedder)
20+
21+
722@cocoindex .flow_def (name = "AmazonS3TextEmbedding" )
823def amazon_s3_text_embedding_flow (
924 flow_builder : cocoindex .FlowBuilder , data_scope : cocoindex .DataScope
@@ -19,7 +34,7 @@ def amazon_s3_text_embedding_flow(
1934 cocoindex .sources .AmazonS3 (
2035 bucket_name = bucket_name ,
2136 prefix = prefix ,
22- included_patterns = ["*.md" , "*.txt" , "*.docx" ],
37+ included_patterns = ["*.md" , "*.mdx" , "*.txt" , "*.docx" ],
2338 binary = False ,
2439 sqs_queue_url = sqs_queue_url ,
2540 )
@@ -36,11 +51,7 @@ def amazon_s3_text_embedding_flow(
3651 )
3752
3853 with doc ["chunks" ].row () as chunk :
39- chunk ["embedding" ] = chunk ["text" ].transform (
40- cocoindex .functions .SentenceTransformerEmbed (
41- model = "sentence-transformers/all-MiniLM-L6-v2"
42- )
43- )
54+ chunk ["embedding" ] = text_to_embedding (chunk ["text" ])
4455 doc_embeddings .collect (
4556 filename = doc ["filename" ],
4657 location = chunk ["location" ],
@@ -61,34 +72,45 @@ def amazon_s3_text_embedding_flow(
6172 )
6273
6374
64- query_handler = cocoindex .query .SimpleSemanticsQueryHandler (
65- name = "SemanticsSearch" ,
66- flow = amazon_s3_text_embedding_flow ,
67- target_name = "doc_embeddings" ,
68- query_transform_flow = lambda text : text .transform (
69- cocoindex .functions .SentenceTransformerEmbed (
70- model = "sentence-transformers/all-MiniLM-L6-v2"
71- )
72- ),
73- default_similarity_metric = cocoindex .VectorSimilarityMetric .COSINE_SIMILARITY ,
74- )
def search(pool: ConnectionPool, query: str, top_k: int = 5):
    """Semantic search over the exported doc_embeddings table.

    Embeds *query* with the shared transform flow, runs a pgvector
    cosine-distance query, and returns the *top_k* closest chunks as
    dicts with "filename", "text" and "score" (1 - distance) keys.
    """
    # The export target's table name is derived from the flow definition above.
    table_name = cocoindex.utils.get_target_storage_default_name(
        amazon_s3_text_embedding_flow, "doc_embeddings"
    )
    # Reuse the indexing-side embedding logic for the query text.
    query_vector = text_to_embedding.eval(query)
    with pool.connection() as conn, conn.cursor() as cur:
        cur.execute(
            f"""
            SELECT filename, text, embedding <=> %s::vector AS distance
            FROM {table_name} ORDER BY distance LIMIT %s
            """,
            (query_vector, top_k),
        )
        rows = cur.fetchall()
    return [
        {"filename": filename, "text": text, "score": 1.0 - distance}
        for filename, text, distance in rows
    ]
7596
7697
def _main():
    """Interactive demo loop: read search queries from stdin and print matches.

    An empty query exits the loop.
    """
    # Fail fast with a clear message if the connection string is missing,
    # instead of handing None to ConnectionPool and getting an obscure error.
    database_url = os.environ.get("COCOINDEX_DATABASE_URL")
    if not database_url:
        raise RuntimeError(
            "COCOINDEX_DATABASE_URL is not set; cannot connect to the database."
        )
    # Use the pool as a context manager so connections are closed cleanly
    # when the loop exits (the original leaked the pool).
    with ConnectionPool(database_url) as pool:
        while True:
            query = input("Enter search query (or Enter to quit): ")
            if query == "":
                break
            # Run the query function with the database connection pool and the query.
            results = search(pool, query)
            print("\nSearch results:")
            for result in results:
                print(f"[{result['score']:.3f}] {result['filename']}")
                print(f"    {result['text']}")
                print("---")
            print()
92114
93115
94116if __name__ == "__main__" :
0 commit comments