Skip to content

Commit 40e04cf

Browse files
committed
reorganize the example to make it simpler
1 parent 0c2bbb9 commit 40e04cf

File tree

4 files changed

+186
-204
lines changed

4 files changed

+186
-204
lines changed

examples/postgres_embedding/.env

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,4 @@
33
COCOINDEX_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex
44

55
# Source Database (for reading data - can be different from CocoIndex DB)
6-
SOURCE_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/source_data
7-
8-
# ========================================
9-
# Configuration for test_multiple table
10-
# ========================================
11-
TABLE_NAME=test_multiple
12-
KEY_COLUMNS_FOR_MULTIPLE_KEYS=product_category,product_name
13-
INDEXING_COLUMN=description
14-
ORDINAL_COLUMN=modified_time
6+
SOURCE_DATABASE_URL=postgres://cocoindex:cocoindex@localhost/cocoindex
Lines changed: 89 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,228 +1,123 @@
1-
from dotenv import load_dotenv
2-
from psycopg_pool import ConnectionPool
3-
from pgvector.psycopg import register_vector
4-
from typing import Any
51
import cocoindex
62
import os
7-
import sys
83

9-
os.environ["RUST_BACKTRACE"] = "1"
10-
os.environ["COCOINDEX_LOG"] = "debug"
114

12-
13-
@cocoindex.transform_flow()
14-
def text_to_embedding(
15-
text: cocoindex.DataSlice[str],
16-
) -> cocoindex.DataSlice[list[float]]:
5+
@cocoindex.flow_def(name="PostgresMessageEmbedding")
6+
def postgres_message_embedding_flow(
7+
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
8+
) -> None:
179
"""
18-
Embed the text using a SentenceTransformer model.
19-
This is a shared logic between indexing and querying, so extract it as a function.
10+
Define a flow that reads data from a PostgreSQL table, generates embeddings,
11+
and stores them in another PostgreSQL table with pgvector.
2012
"""
21-
return text.transform(
22-
cocoindex.functions.SentenceTransformerEmbed(
23-
model="sentence-transformers/all-MiniLM-L6-v2"
13+
14+
data_scope["messages"] = flow_builder.add_source(
15+
cocoindex.sources.Postgres(
16+
table_name="source_messages",
17+
database=cocoindex.add_transient_auth_entry(
18+
cocoindex.sources.DatabaseConnectionSpec(
19+
url=os.getenv("SOURCE_DATABASE_URL"),
20+
)
21+
),
22+
ordinal_column="created_at",
2423
)
2524
)
2625

27-
28-
def get_key_columns_from_env() -> list[str]:
29-
"""
30-
Get key columns from environment variables.
31-
Ensures only one of KEY_COLUMN_FOR_SINGLE_KEY or KEY_COLUMNS_FOR_MULTIPLE_KEYS is set.
32-
33-
Returns:
34-
List of key column names
35-
36-
Raises:
37-
SystemExit: If configuration is invalid
38-
"""
39-
single_key = os.environ.get("KEY_COLUMN_FOR_SINGLE_KEY")
40-
multiple_keys = os.environ.get("KEY_COLUMNS_FOR_MULTIPLE_KEYS")
41-
42-
# Check that exactly one is set
43-
if single_key and multiple_keys:
44-
print(
45-
"❌ Error: Both KEY_COLUMN_FOR_SINGLE_KEY and KEY_COLUMNS_FOR_MULTIPLE_KEYS are set"
26+
message_embeddings = data_scope.add_collector()
27+
with data_scope["messages"].row() as message_row:
28+
# Use the indexing column for embedding generation
29+
message_row["embedding"] = message_row["message"].transform(
30+
cocoindex.functions.SentenceTransformerEmbed(
31+
model="sentence-transformers/all-MiniLM-L6-v2"
32+
)
4633
)
47-
print(" Please set only one of them:")
48-
print(" - KEY_COLUMN_FOR_SINGLE_KEY=id (for single primary key)")
49-
print(
50-
" - KEY_COLUMNS_FOR_MULTIPLE_KEYS=product_category,product_name (for composite primary key)"
34+
# Collect the data - include key columns and content
35+
message_embeddings.collect(
36+
id=message_row["id"],
37+
author=message_row["author"],
38+
message=message_row["message"],
39+
embedding=message_row["embedding"],
5140
)
52-
sys.exit(1)
5341

54-
if not single_key and not multiple_keys:
55-
print(
56-
"❌ Error: Neither KEY_COLUMN_FOR_SINGLE_KEY nor KEY_COLUMNS_FOR_MULTIPLE_KEYS is set"
57-
)
58-
print(" Please set one of them:")
59-
print(" - KEY_COLUMN_FOR_SINGLE_KEY=id (for single primary key)")
60-
print(
61-
" - KEY_COLUMNS_FOR_MULTIPLE_KEYS=product_category,product_name (for composite primary key)"
62-
)
63-
sys.exit(1)
42+
message_embeddings.export(
43+
"message_embeddings",
44+
cocoindex.targets.Postgres(),
45+
primary_key_fields=["id"],
46+
vector_indexes=[
47+
cocoindex.VectorIndexDef(
48+
field_name="embedding",
49+
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
50+
)
51+
],
52+
)
6453

65-
if single_key:
66-
# Single primary key
67-
return [single_key.strip()]
68-
else:
69-
# Multiple primary keys (composite key)
70-
return [col.strip() for col in multiple_keys.split(",")]
7154

55+
@cocoindex.op.function()
56+
def calculate_total_value(
57+
price: float,
58+
amount: int,
59+
) -> float:
60+
return price * amount
7261

73-
def is_single_key() -> bool:
74-
"""
75-
Check if using single key or composite key configuration.
7662

77-
Returns:
78-
bool: True if using single key, False if using composite key
79-
"""
80-
return bool(os.environ.get("KEY_COLUMN_FOR_SINGLE_KEY"))
63+
@cocoindex.op.function()
64+
def make_full_description(
65+
category: str,
66+
name: str,
67+
description: str,
68+
) -> str:
69+
return f"Category: {category}\nName: {name}\n\n{description}"
8170

8271

83-
@cocoindex.flow_def(name="PostgresEmbedding")
84-
def postgres_embedding_flow(
72+
@cocoindex.flow_def(name="PostgresProductEmbedding")
73+
def postgres_product_embedding_flow(
8574
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
8675
) -> None:
8776
"""
8877
Define a flow that reads data from a PostgreSQL table, generates embeddings,
8978
and stores them in another PostgreSQL table with pgvector.
9079
"""
91-
# Required environment variables
92-
table_name = os.environ["TABLE_NAME"]
93-
indexing_column = os.environ["INDEXING_COLUMN"]
94-
95-
# Get key columns from environment
96-
key_columns = get_key_columns_from_env()
97-
98-
# Optional environment variables
99-
ordinal_column = os.environ.get("ORDINAL_COLUMN")
100-
101-
# Only include the data column - primary keys are automatically read by the PostgreSQL source
102-
included_columns = [indexing_column]
103-
104-
# Get source database URL for the Postgres source
105-
source_db_url = os.environ.get("SOURCE_DATABASE_URL")
106-
if not source_db_url:
107-
print("❌ Error: SOURCE_DATABASE_URL environment variable is required")
108-
print(" This should point to the database containing the source table")
109-
sys.exit(1)
110-
111-
# Create auth entry for the source database
112-
source_db_conn = cocoindex.add_auth_entry(
113-
"source_db_conn", cocoindex.DatabaseConnectionSpec(url=source_db_url)
114-
)
115-
116-
# Read from source PostgreSQL table with only the specified columns
117-
postgres_source_kwargs = {
118-
"table_name": table_name,
119-
"database": source_db_conn,
120-
"included_columns": included_columns,
121-
}
122-
if ordinal_column:
123-
postgres_source_kwargs["ordinal_column"] = ordinal_column
124-
125-
data_scope["documents"] = flow_builder.add_source(
126-
cocoindex.sources.Postgres(**postgres_source_kwargs)
80+
data_scope["products"] = flow_builder.add_source(
81+
cocoindex.sources.Postgres(
82+
table_name="source_products",
83+
database=cocoindex.add_transient_auth_entry(
84+
cocoindex.sources.DatabaseConnectionSpec(
85+
url=os.getenv("SOURCE_DATABASE_URL"),
86+
)
87+
),
88+
)
12789
)
12890

129-
document_embeddings = data_scope.add_collector()
130-
131-
with data_scope["documents"].row() as row:
132-
# Use the indexing column for embedding generation
133-
row["text_embedding"] = text_to_embedding(row[indexing_column])
134-
# Collect the data - include key columns and content
135-
collect_data = {
136-
"content": row[indexing_column],
137-
"text_embedding": row["text_embedding"],
138-
}
139-
140-
# Add each key column as a separate field
141-
for key_col in key_columns:
142-
if is_single_key():
143-
collect_data[key_col] = row[key_col]
144-
else:
145-
collect_data[key_col] = row["_key"][key_col]
146-
147-
document_embeddings.collect(**collect_data)
91+
product_embeddings = data_scope.add_collector()
92+
with data_scope["products"].row() as product:
93+
product["full_description"] = flow_builder.transform(
94+
make_full_description,
95+
product["_key"]["product_category"],
96+
product["_key"]["product_name"],
97+
product["description"],
98+
)
99+
product["embedding"] = product["full_description"].transform(
100+
cocoindex.functions.SentenceTransformerEmbed(
101+
model="sentence-transformers/all-MiniLM-L6-v2"
102+
)
103+
)
104+
product_embeddings.collect(
105+
product_category=product["_key"]["product_category"],
106+
product_name=product["_key"]["product_name"],
107+
description=product["description"],
108+
price=product["price"],
109+
amount=product["amount"],
110+
embedding=product["embedding"],
111+
)
148112

149-
document_embeddings.export(
150-
"document_embeddings",
113+
product_embeddings.export(
114+
"product_embeddings",
151115
cocoindex.targets.Postgres(),
152-
primary_key_fields=key_columns,
116+
primary_key_fields=["product_category", "product_name"],
153117
vector_indexes=[
154118
cocoindex.VectorIndexDef(
155-
field_name="text_embedding",
119+
field_name="embedding",
156120
metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
157121
)
158122
],
159123
)
160-
161-
162-
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
163-
# Get the table name, for the export target in the postgres_embedding_flow above.
164-
table_name = cocoindex.utils.get_target_default_name(
165-
postgres_embedding_flow, "document_embeddings"
166-
)
167-
168-
# Get key columns configuration
169-
key_columns = get_key_columns_from_env()
170-
# Build SELECT clause with all key columns
171-
key_columns_select = ", ".join(key_columns)
172-
173-
# Evaluate the transform flow defined above with the input query, to get the embedding.
174-
query_vector = text_to_embedding.eval(query)
175-
# Run the query and get the results.
176-
with pool.connection() as conn:
177-
register_vector(conn)
178-
with conn.cursor() as cur:
179-
cur.execute(
180-
f"""
181-
SELECT content, text_embedding <=> %s::vector AS distance, {key_columns_select}
182-
FROM {table_name} ORDER BY distance LIMIT %s
183-
""",
184-
(query_vector, top_k),
185-
)
186-
results = []
187-
for row in cur.fetchall():
188-
result = {
189-
"content": row[0],
190-
"score": 1.0 - row[1],
191-
"key": "__".join(str(x) for x in row[2:]),
192-
}
193-
results.append(result)
194-
return results
195-
196-
197-
def _main() -> None:
198-
# Initialize the database connection pool for CocoIndex database (where embeddings are stored)
199-
cocoindex_db_url = os.getenv("COCOINDEX_DATABASE_URL")
200-
if not cocoindex_db_url:
201-
print("❌ Error: COCOINDEX_DATABASE_URL environment variable is required")
202-
print(" This should point to the database where embeddings will be stored")
203-
sys.exit(1)
204-
205-
pool = ConnectionPool(cocoindex_db_url)
206-
207-
postgres_embedding_flow.setup()
208-
with cocoindex.FlowLiveUpdater(postgres_embedding_flow) as updater:
209-
# Run queries in a loop to demonstrate the query capabilities.
210-
while True:
211-
query = input("Enter search query (or Enter to quit): ")
212-
if query == "":
213-
break
214-
# Run the query function with the database connection pool and the query.
215-
results = search(pool, query)
216-
print("\nSearch results:")
217-
for result in results:
218-
print(
219-
f"[{result['score']:.3f}] {result['content']} key: {result['key']}"
220-
)
221-
print("---")
222-
print()
223-
224-
225-
if __name__ == "__main__":
226-
load_dotenv()
227-
cocoindex.init()
228-
_main()

0 commit comments

Comments
 (0)