Skip to content

Commit 99b0fb1

Browse files
authored
examples: add query path for postgres source example (#944)
1 parent aea63d7 commit 99b0fb1

File tree

2 files changed

+93
-5
lines changed

2 files changed

+93
-5
lines changed

examples/postgres_source/main.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1-
import cocoindex
1+
from typing import Any
22
import os
3+
import datetime
4+
5+
from dotenv import load_dotenv
6+
from psycopg_pool import ConnectionPool
7+
from pgvector.psycopg import register_vector # type: ignore[import-untyped]
8+
from psycopg.rows import dict_row
9+
from numpy.typing import NDArray
10+
11+
import numpy as np
12+
import cocoindex
313

414

515
@cocoindex.op.function()
@@ -19,6 +29,21 @@ def make_full_description(
1929
return f"Category: {category}\nName: {name}\n\n{description}"
2030

2131

32+
@cocoindex.transform_flow()
def text_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[NDArray[np.float32]]:
    """
    Turn a text slice into its embedding with a SentenceTransformer model.

    Kept as a standalone transform flow so the same embedding logic can be
    reused for both indexing and querying.
    """
    # Build the embedding function spec first, then apply it to the slice.
    embedder = cocoindex.functions.SentenceTransformerEmbed(
        model="sentence-transformers/all-MiniLM-L6-v2"
    )
    return text.transform(embedder)
45+
46+
2247
@cocoindex.flow_def(name="PostgresProductIndexing")
2348
def postgres_product_indexing_flow(
2449
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
@@ -32,13 +57,14 @@ def postgres_product_indexing_flow(
3257
table_name="source_products",
3358
# Optional. Use the default CocoIndex database if not specified.
3459
database=cocoindex.add_transient_auth_entry(
35-
cocoindex.sources.DatabaseConnectionSpec(
36-
url=os.getenv("SOURCE_DATABASE_URL"),
60+
cocoindex.DatabaseConnectionSpec(
61+
url=os.environ["SOURCE_DATABASE_URL"],
3762
)
3863
),
3964
# Optional.
4065
ordinal_column="modified_time",
41-
)
66+
),
67+
refresh_interval=datetime.timedelta(seconds=30),
4268
)
4369

4470
indexed_product = data_scope.add_collector()
@@ -80,3 +106,59 @@ def postgres_product_indexing_flow(
80106
)
81107
],
82108
)
109+
110+
111+
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """
    Return the ``top_k`` indexed products closest to ``query``.

    Each returned row is a dict (``dict_row`` row factory) with the selected
    product columns plus ``distance``, the value of ``embedding <=> query
    embedding``. Rows are ordered by ascending distance.
    """
    # Default table name of the "output" export target of
    # postgres_product_indexing_flow.
    target_table = cocoindex.utils.get_target_default_name(
        postgres_product_indexing_flow, "output"
    )
    # Embed the query by evaluating the text_to_embedding transform flow on it.
    query_embedding = text_to_embedding.eval(query)

    # Only the table name is interpolated; the vector and the limit are bound
    # as query parameters.
    statement = f"""
        SELECT
            product_category,
            product_name,
            description,
            amount,
            total_value,
            (embedding <=> %s) AS distance
        FROM {target_table}
        ORDER BY distance ASC
        LIMIT %s
    """
    params = (query_embedding, top_k)

    with pool.connection() as conn:
        # Register the pgvector adapters on this connection before the
        # embedding is passed as a parameter.
        register_vector(conn)
        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute(statement, params)
            rows = cur.fetchall()
    return rows
138+
139+
140+
def _main() -> None:
    """
    Run an interactive search loop over the indexed products.

    Reads queries from stdin until an empty line or end of input, calls
    ``search`` for each one and prints every match with a score computed as
    ``1 - distance``.

    Raises:
        KeyError: if ``COCOINDEX_DATABASE_URL`` is not set in the environment.
    """
    # Use the pool as a context manager so it is closed when the loop ends,
    # including when a query raises.
    with ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"]) as pool:
        # Run queries in a loop to demonstrate the query capabilities.
        while True:
            try:
                query = input("Enter search query (or Enter to quit): ")
            except EOFError:
                # stdin was closed (Ctrl-D, or piped input ran out): quit
                # cleanly instead of surfacing a traceback.
                break
            if query == "":
                break
            # Run the query function with the connection pool and the query.
            results = search(pool, query)
            print("\nSearch results:")
            for result in results:
                # Turn the distance into a score where higher means closer.
                score = 1.0 - result["distance"]
                print(
                    f"[{score:.3f}] {result['product_category']} | {result['product_name']} | {result['amount']} | {result['total_value']}"
                )
                print(f"    {result['description']}")
                print("---")
            print()
159+
160+
161+
if __name__ == "__main__":
    # Script entry point. Environment variables are loaded first because
    # _main reads COCOINDEX_DATABASE_URL from os.environ.
    # NOTE(review): presumably cocoindex.init() also relies on the loaded
    # environment, so it must run after load_dotenv() — confirm.
    load_dotenv()
    cocoindex.init()
    _main()

examples/postgres_source/pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,13 @@ name = "postgres-source"
33
version = "0.1.0"
44
description = "Demonstrate how to use Postgres tables as the source for CocoIndex."
55
requires-python = ">=3.11"
6-
dependencies = ["cocoindex[embeddings]>=0.2.1"]
6+
dependencies = [
7+
"cocoindex[embeddings]>=0.2.1",
8+
"python-dotenv>=1.0.1",
9+
"pgvector>=0.4.1",
10+
"psycopg[binary,pool]",
11+
"numpy",
12+
]
713

814
[tool.setuptools]
915
packages = []

0 commit comments

Comments
 (0)