1- import cocoindex
1+ from typing import Any
22import os
3+ import datetime
4+
5+ from dotenv import load_dotenv
6+ from psycopg_pool import ConnectionPool
7+ from pgvector .psycopg import register_vector # type: ignore[import-untyped]
8+ from psycopg .rows import dict_row
9+ from numpy .typing import NDArray
10+
11+ import numpy as np
12+ import cocoindex
313
414
515@cocoindex .op .function ()
@@ -19,6 +29,21 @@ def make_full_description(
1929 return f"Category: { category } \n Name: { name } \n \n { description } "
2030
2131
32+ @cocoindex .transform_flow ()
33+ def text_to_embedding (
34+ text : cocoindex .DataSlice [str ],
35+ ) -> cocoindex .DataSlice [NDArray [np .float32 ]]:
36+ """
37+ Embed the text using a SentenceTransformer model.
38+ This is a shared logic between indexing and querying, so extract it as a function.
39+ """
40+ return text .transform (
41+ cocoindex .functions .SentenceTransformerEmbed (
42+ model = "sentence-transformers/all-MiniLM-L6-v2"
43+ )
44+ )
45+
46+
2247@cocoindex .flow_def (name = "PostgresProductIndexing" )
2348def postgres_product_indexing_flow (
2449 flow_builder : cocoindex .FlowBuilder , data_scope : cocoindex .DataScope
@@ -32,13 +57,14 @@ def postgres_product_indexing_flow(
3257 table_name = "source_products" ,
3358 # Optional. Use the default CocoIndex database if not specified.
3459 database = cocoindex .add_transient_auth_entry (
35- cocoindex .sources . DatabaseConnectionSpec (
36- url = os .getenv ( "SOURCE_DATABASE_URL" ) ,
60+ cocoindex .DatabaseConnectionSpec (
61+ url = os .environ [ "SOURCE_DATABASE_URL" ] ,
3762 )
3863 ),
3964 # Optional.
4065 ordinal_column = "modified_time" ,
41- )
66+ ),
67+ refresh_interval = datetime .timedelta (seconds = 30 ),
4268 )
4369
4470 indexed_product = data_scope .add_collector ()
@@ -80,3 +106,59 @@ def postgres_product_indexing_flow(
80106 )
81107 ],
82108 )
109+
110+
111+ def search (pool : ConnectionPool , query : str , top_k : int = 5 ) -> list [dict [str , Any ]]:
112+ # Get the table name, for the export target in the text_embedding_flow above.
113+ table_name = cocoindex .utils .get_target_default_name (
114+ postgres_product_indexing_flow , "output"
115+ )
116+ # Evaluate the transform flow defined above with the input query, to get the embedding.
117+ query_vector = text_to_embedding .eval (query )
118+ # Run the query and get the results.
119+ with pool .connection () as conn :
120+ register_vector (conn )
121+ with conn .cursor (row_factory = dict_row ) as cur :
122+ cur .execute (
123+ f"""
124+ SELECT
125+ product_category,
126+ product_name,
127+ description,
128+ amount,
129+ total_value,
130+ (embedding <=> %s) AS distance
131+ FROM { table_name }
132+ ORDER BY distance ASC
133+ LIMIT %s
134+ """ ,
135+ (query_vector , top_k ),
136+ )
137+ return cur .fetchall ()
138+
139+
140+ def _main () -> None :
141+ # Initialize the database connection pool.
142+ pool = ConnectionPool (os .environ ["COCOINDEX_DATABASE_URL" ])
143+ # Run queries in a loop to demonstrate the query capabilities.
144+ while True :
145+ query = input ("Enter search query (or Enter to quit): " )
146+ if query == "" :
147+ break
148+ # Run the query function with the database connection pool and the query.
149+ results = search (pool , query )
150+ print ("\n Search results:" )
151+ for result in results :
152+ score = 1.0 - result ["distance" ]
153+ print (
154+ f"[{ score :.3f} ] { result ['product_category' ]} | { result ['product_name' ]} | { result ['amount' ]} | { result ['total_value' ]} "
155+ )
156+ print (f" { result ['description' ]} " )
157+ print ("---" )
158+ print ()
159+
160+
161+ if __name__ == "__main__" :
162+ load_dotenv ()
163+ cocoindex .init ()
164+ _main ()
0 commit comments