diff --git a/examples/postgres_source/main.py b/examples/postgres_source/main.py index d43a60823..45bfa5e06 100644 --- a/examples/postgres_source/main.py +++ b/examples/postgres_source/main.py @@ -1,5 +1,15 @@ -import cocoindex +from typing import Any import os +import datetime + +from dotenv import load_dotenv +from psycopg_pool import ConnectionPool +from pgvector.psycopg import register_vector # type: ignore[import-untyped] +from psycopg.rows import dict_row +from numpy.typing import NDArray + +import numpy as np +import cocoindex @cocoindex.op.function() @@ -19,6 +29,21 @@ def make_full_description( return f"Category: {category}\nName: {name}\n\n{description}" +@cocoindex.transform_flow() +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[NDArray[np.float32]]: + """ + Embed the text using a SentenceTransformer model. + This logic is shared between indexing and querying, so it is extracted as a function. + """ + return text.transform( + cocoindex.functions.SentenceTransformerEmbed( + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + + @cocoindex.flow_def(name="PostgresProductIndexing") def postgres_product_indexing_flow( flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope @@ -32,13 +57,14 @@ def postgres_product_indexing_flow( table_name="source_products", # Optional. Use the default CocoIndex database if not specified. database=cocoindex.add_transient_auth_entry( - cocoindex.sources.DatabaseConnectionSpec( - url=os.getenv("SOURCE_DATABASE_URL"), + cocoindex.DatabaseConnectionSpec( + url=os.environ["SOURCE_DATABASE_URL"], ) ), # Optional. ordinal_column="modified_time", - ) + ), + refresh_interval=datetime.timedelta(seconds=30), ) indexed_product = data_scope.add_collector() @@ -80,3 +106,59 @@ def postgres_product_indexing_flow( ) ], ) + + +def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]: + # Get the table name, for the export target in the postgres_product_indexing_flow above. 
+ table_name = cocoindex.utils.get_target_default_name( + postgres_product_indexing_flow, "output" + ) + # Evaluate the transform flow defined above with the input query, to get the embedding. + query_vector = text_to_embedding.eval(query) + # Run the query and get the results. + with pool.connection() as conn: + register_vector(conn) + with conn.cursor(row_factory=dict_row) as cur: + cur.execute( + f""" + SELECT + product_category, + product_name, + description, + amount, + total_value, + (embedding <=> %s) AS distance + FROM {table_name} + ORDER BY distance ASC + LIMIT %s + """, + (query_vector, top_k), + ) + return cur.fetchall() + + +def _main() -> None: + # Initialize the database connection pool. + pool = ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"]) + # Run queries in a loop to demonstrate the query capabilities. + while True: + query = input("Enter search query (or Enter to quit): ") + if query == "": + break + # Run the query function with the database connection pool and the query. + results = search(pool, query) + print("\nSearch results:") + for result in results: + score = 1.0 - result["distance"] + print( + f"[{score:.3f}] {result['product_category']} | {result['product_name']} | {result['amount']} | {result['total_value']}" + ) + print(f" {result['description']}") + print("---") + print() + + +if __name__ == "__main__": + load_dotenv() + cocoindex.init() + _main() diff --git a/examples/postgres_source/pyproject.toml b/examples/postgres_source/pyproject.toml index 83876f07b..5bd7c58be 100644 --- a/examples/postgres_source/pyproject.toml +++ b/examples/postgres_source/pyproject.toml @@ -3,7 +3,13 @@ name = "postgres-source" version = "0.1.0" description = "Demonstrate how to use Postgres tables as the source for CocoIndex." 
requires-python = ">=3.11" -dependencies = [ "cocoindex[embeddings]>=0.2.1"] +dependencies = [ + "cocoindex[embeddings]>=0.2.1", + "python-dotenv>=1.0.1", + "pgvector>=0.4.1", + "psycopg[binary,pool]", + "numpy", +] [tool.setuptools] packages = []