90 changes: 86 additions & 4 deletions examples/postgres_source/main.py
@@ -1,5 +1,15 @@
-import cocoindex
 from typing import Any
+import os
+import datetime
+
+from dotenv import load_dotenv
+from psycopg_pool import ConnectionPool
+from pgvector.psycopg import register_vector  # type: ignore[import-untyped]
+from psycopg.rows import dict_row
+from numpy.typing import NDArray
+
+import numpy as np
+import cocoindex


 @cocoindex.op.function()
@@ -19,6 +29,21 @@ def make_full_description(
return f"Category: {category}\nName: {name}\n\n{description}"


@cocoindex.transform_flow()
def text_to_embedding(
text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[NDArray[np.float32]]:
"""
Embed the text using a SentenceTransformer model.
This is a shared logic between indexing and querying, so extract it as a function.
"""
return text.transform(
cocoindex.functions.SentenceTransformerEmbed(
model="sentence-transformers/all-MiniLM-L6-v2"
)
)


@cocoindex.flow_def(name="PostgresProductIndexing")
def postgres_product_indexing_flow(
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
@@ -32,13 +57,14 @@ def postgres_product_indexing_flow(
             table_name="source_products",
             # Optional. Use the default CocoIndex database if not specified.
             database=cocoindex.add_transient_auth_entry(
-                cocoindex.sources.DatabaseConnectionSpec(
-                    url=os.getenv("SOURCE_DATABASE_URL"),
+                cocoindex.DatabaseConnectionSpec(
+                    url=os.environ["SOURCE_DATABASE_URL"],
                 )
             ),
             # Optional.
             ordinal_column="modified_time",
-        )
+        ),
+        refresh_interval=datetime.timedelta(seconds=30),
     )

     indexed_product = data_scope.add_collector()
@@ -80,3 +106,59 @@ def postgres_product_indexing_flow(
             )
         ],
     )
+
+
+def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
+    # Get the table name of the export target in the postgres_product_indexing_flow above.
+    table_name = cocoindex.utils.get_target_default_name(
+        postgres_product_indexing_flow, "output"
+    )
+    # Evaluate the transform flow defined above with the input query, to get the embedding.
+    query_vector = text_to_embedding.eval(query)
+    # Run the query and get the results.
+    with pool.connection() as conn:
+        register_vector(conn)
+        with conn.cursor(row_factory=dict_row) as cur:
+            cur.execute(
+                f"""
+                SELECT
+                    product_category,
+                    product_name,
+                    description,
+                    amount,
+                    total_value,
+                    (embedding <=> %s) AS distance
+                FROM {table_name}
+                ORDER BY distance ASC
+                LIMIT %s
+                """,
+                (query_vector, top_k),
+            )
+            return cur.fetchall()
+
+
+def _main() -> None:
+    # Initialize the database connection pool.
+    pool = ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"])
+    # Run queries in a loop to demonstrate the query capabilities.
+    while True:
+        query = input("Enter search query (or Enter to quit): ")
+        if query == "":
+            break
+        # Run the query function with the database connection pool and the query.
+        results = search(pool, query)
+        print("\nSearch results:")
+        for result in results:
+            score = 1.0 - result["distance"]
+            print(
+                f"[{score:.3f}] {result['product_category']} | {result['product_name']} | {result['amount']} | {result['total_value']}"
+            )
+            print(f" {result['description']}")
+            print("---")
+        print()
+
+
+if __name__ == "__main__":
+    load_dotenv()
+    cocoindex.init()
+    _main()
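
A note on the scoring in _main(): pgvector's <=> operator returns cosine distance, which is 1 minus cosine similarity, so the example converts it back to a similarity score with score = 1.0 - result["distance"]. Below is a minimal numpy sketch of that identity; the vectors are made up purely for illustration.

import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    # Same definition pgvector uses for the <=> operator: 1 - cosine similarity.
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 1.0, 0.0])

distance = cosine_distance(a, b)
score = 1.0 - distance  # the same conversion _main() applies to each row
print(f"distance={distance:.3f}, score={score:.3f}")  # distance=0.500, score=0.500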
8 changes: 7 additions & 1 deletion examples/postgres_source/pyproject.toml
@@ -3,7 +3,13 @@ name = "postgres-source"
version = "0.1.0"
description = "Demonstrate how to use Postgres tables as the source for CocoIndex."
requires-python = ">=3.11"
dependencies = ["cocoindex[embeddings]>=0.2.1"]
dependencies = [
"cocoindex[embeddings]>=0.2.1",
"python-dotenv>=1.0.1",
"pgvector>=0.4.1",
"psycopg[binary,pool]",
"numpy",
]

[tool.setuptools]
packages = []
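
To try the example end to end, you also need a source_products table in the database behind SOURCE_DATABASE_URL. The PR does not ship a schema; the only column it actually references is the ordinal column modified_time, so the sketch below is a hypothetical bootstrap and every other column name is an assumption chosen to plausibly match the fields the flow collects. The psycopg and python-dotenv dependencies added above are enough to run it.

import os

import psycopg
from dotenv import load_dotenv

load_dotenv()

# Hypothetical schema: only "modified_time" is required by the flow's ordinal_column;
# the remaining columns are assumptions for illustration.
DDL = """
CREATE TABLE IF NOT EXISTS source_products (
    product_category TEXT NOT NULL,
    product_name TEXT NOT NULL,
    description TEXT,
    price NUMERIC,
    amount INTEGER,
    modified_time TIMESTAMPTZ NOT NULL DEFAULT now(),
    PRIMARY KEY (product_category, product_name)
);
"""

with psycopg.connect(os.environ["SOURCE_DATABASE_URL"]) as conn:
    conn.execute(DDL)
    conn.execute(
        """
        INSERT INTO source_products (product_category, product_name, description, price, amount)
        VALUES (%s, %s, %s, %s, %s)
        ON CONFLICT (product_category, product_name) DO UPDATE
        SET description = EXCLUDED.description, modified_time = now()
        """,
        ("Electronics", "Wireless Mouse", "A compact 2.4 GHz wireless mouse.", 19.99, 120),
    )

Re-running the insert with a changed description advances modified_time through the ON CONFLICT clause, which is the kind of change the flow's ordinal_column and 30-second refresh_interval are set up to catch.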