|
1 | | -from dotenv import load_dotenv |
2 | | -from psycopg_pool import ConnectionPool |
3 | | -from pgvector.psycopg import register_vector |
4 | | -from typing import Any |
5 | 1 | import cocoindex |
6 | 2 | import os |
7 | | -import sys |
8 | 3 |
|
9 | | -os.environ["RUST_BACKTRACE"] = "1" |
10 | | -os.environ["COCOINDEX_LOG"] = "debug" |
11 | 4 |
|
12 | | - |
13 | | -@cocoindex.transform_flow() |
14 | | -def text_to_embedding( |
15 | | - text: cocoindex.DataSlice[str], |
16 | | -) -> cocoindex.DataSlice[list[float]]: |
| 5 | +@cocoindex.flow_def(name="PostgresMessageEmbedding") |
| 6 | +def postgres_message_embedding_flow( |
| 7 | + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope |
| 8 | +) -> None: |
17 | 9 | """ |
18 | | - Embed the text using a SentenceTransformer model. |
19 | | - This is a shared logic between indexing and querying, so extract it as a function. |
| 10 | + Define a flow that reads data from a PostgreSQL table, generates embeddings, |
| 11 | + and stores them in another PostgreSQL table with pgvector. |
20 | 12 | """ |
21 | | - return text.transform( |
22 | | - cocoindex.functions.SentenceTransformerEmbed( |
23 | | - model="sentence-transformers/all-MiniLM-L6-v2" |
| 13 | + |
| 14 | + data_scope["messages"] = flow_builder.add_source( |
| 15 | + cocoindex.sources.Postgres( |
| 16 | + table_name="source_messages", |
| 17 | + database=cocoindex.add_transient_auth_entry( |
| 18 | + cocoindex.sources.DatabaseConnectionSpec( |
| 19 | + url=os.getenv("SOURCE_DATABASE_URL"), |
| 20 | + ) |
| 21 | + ), |
| 22 | + ordinal_column="created_at", |
24 | 23 | ) |
25 | 24 | ) |
26 | 25 |
|
27 | | - |
28 | | -def get_key_columns_from_env() -> list[str]: |
29 | | - """ |
30 | | - Get key columns from environment variables. |
31 | | - Ensures only one of KEY_COLUMN_FOR_SINGLE_KEY or KEY_COLUMNS_FOR_MULTIPLE_KEYS is set. |
32 | | -
|
33 | | - Returns: |
34 | | - List of key column names |
35 | | -
|
36 | | - Raises: |
37 | | - SystemExit: If configuration is invalid |
38 | | - """ |
39 | | - single_key = os.environ.get("KEY_COLUMN_FOR_SINGLE_KEY") |
40 | | - multiple_keys = os.environ.get("KEY_COLUMNS_FOR_MULTIPLE_KEYS") |
41 | | - |
42 | | - # Check that exactly one is set |
43 | | - if single_key and multiple_keys: |
44 | | - print( |
45 | | - "❌ Error: Both KEY_COLUMN_FOR_SINGLE_KEY and KEY_COLUMNS_FOR_MULTIPLE_KEYS are set" |
| 26 | + message_embeddings = data_scope.add_collector() |
| 27 | + with data_scope["messages"].row() as message_row: |
| 28 | + # Use the indexing column for embedding generation |
| 29 | + message_row["embedding"] = message_row["message"].transform( |
| 30 | + cocoindex.functions.SentenceTransformerEmbed( |
| 31 | + model="sentence-transformers/all-MiniLM-L6-v2" |
| 32 | + ) |
46 | 33 | ) |
47 | | - print(" Please set only one of them:") |
48 | | - print(" - KEY_COLUMN_FOR_SINGLE_KEY=id (for single primary key)") |
49 | | - print( |
50 | | - " - KEY_COLUMNS_FOR_MULTIPLE_KEYS=product_category,product_name (for composite primary key)" |
| 34 | + # Collect the data - include key columns and content |
| 35 | + message_embeddings.collect( |
| 36 | + id=message_row["id"], |
| 37 | + author=message_row["author"], |
| 38 | + message=message_row["message"], |
| 39 | + embedding=message_row["embedding"], |
51 | 40 | ) |
52 | | - sys.exit(1) |
53 | 41 |
|
54 | | - if not single_key and not multiple_keys: |
55 | | - print( |
56 | | - "❌ Error: Neither KEY_COLUMN_FOR_SINGLE_KEY nor KEY_COLUMNS_FOR_MULTIPLE_KEYS is set" |
57 | | - ) |
58 | | - print(" Please set one of them:") |
59 | | - print(" - KEY_COLUMN_FOR_SINGLE_KEY=id (for single primary key)") |
60 | | - print( |
61 | | - " - KEY_COLUMNS_FOR_MULTIPLE_KEYS=product_category,product_name (for composite primary key)" |
62 | | - ) |
63 | | - sys.exit(1) |
| 42 | + message_embeddings.export( |
| 43 | + "message_embeddings", |
| 44 | + cocoindex.targets.Postgres(), |
| 45 | + primary_key_fields=["id"], |
| 46 | + vector_indexes=[ |
| 47 | + cocoindex.VectorIndexDef( |
| 48 | + field_name="embedding", |
| 49 | + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, |
| 50 | + ) |
| 51 | + ], |
| 52 | + ) |
64 | 53 |
|
65 | | - if single_key: |
66 | | - # Single primary key |
67 | | - return [single_key.strip()] |
68 | | - else: |
69 | | - # Multiple primary keys (composite key) |
70 | | - return [col.strip() for col in multiple_keys.split(",")] |
71 | 54 |
|
| 55 | +@cocoindex.op.function() |
| 56 | +def calculate_total_value( |
| 57 | + price: float, |
| 58 | + amount: int, |
| 59 | +) -> float: |
| 60 | + return price * amount |
72 | 61 |
|
73 | | -def is_single_key() -> bool: |
74 | | - """ |
75 | | - Check if using single key or composite key configuration. |
76 | 62 |
|
77 | | - Returns: |
78 | | - bool: True if using single key, False if using composite key |
79 | | - """ |
80 | | - return bool(os.environ.get("KEY_COLUMN_FOR_SINGLE_KEY")) |
| 63 | +@cocoindex.op.function() |
| 64 | +def make_full_description( |
| 65 | + category: str, |
| 66 | + name: str, |
| 67 | + description: str, |
| 68 | +) -> str: |
| 69 | + return f"Category: {category}\nName: {name}\n\n{description}" |
81 | 70 |
|
82 | 71 |
|
83 | | -@cocoindex.flow_def(name="PostgresEmbedding") |
84 | | -def postgres_embedding_flow( |
| 72 | +@cocoindex.flow_def(name="PostgresProductEmbedding") |
| 73 | +def postgres_product_embedding_flow( |
85 | 74 | flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope |
86 | 75 | ) -> None: |
87 | 76 | """ |
88 | 77 | Define a flow that reads data from a PostgreSQL table, generates embeddings, |
89 | 78 | and stores them in another PostgreSQL table with pgvector. |
90 | 79 | """ |
91 | | - # Required environment variables |
92 | | - table_name = os.environ["TABLE_NAME"] |
93 | | - indexing_column = os.environ["INDEXING_COLUMN"] |
94 | | - |
95 | | - # Get key columns from environment |
96 | | - key_columns = get_key_columns_from_env() |
97 | | - |
98 | | - # Optional environment variables |
99 | | - ordinal_column = os.environ.get("ORDINAL_COLUMN") |
100 | | - |
101 | | - # Only include the data column - primary keys are automatically read by the PostgreSQL source |
102 | | - included_columns = [indexing_column] |
103 | | - |
104 | | - # Get source database URL for the Postgres source |
105 | | - source_db_url = os.environ.get("SOURCE_DATABASE_URL") |
106 | | - if not source_db_url: |
107 | | - print("❌ Error: SOURCE_DATABASE_URL environment variable is required") |
108 | | - print(" This should point to the database containing the source table") |
109 | | - sys.exit(1) |
110 | | - |
111 | | - # Create auth entry for the source database |
112 | | - source_db_conn = cocoindex.add_auth_entry( |
113 | | - "source_db_conn", cocoindex.DatabaseConnectionSpec(url=source_db_url) |
114 | | - ) |
115 | | - |
116 | | - # Read from source PostgreSQL table with only the specified columns |
117 | | - postgres_source_kwargs = { |
118 | | - "table_name": table_name, |
119 | | - "database": source_db_conn, |
120 | | - "included_columns": included_columns, |
121 | | - } |
122 | | - if ordinal_column: |
123 | | - postgres_source_kwargs["ordinal_column"] = ordinal_column |
124 | | - |
125 | | - data_scope["documents"] = flow_builder.add_source( |
126 | | - cocoindex.sources.Postgres(**postgres_source_kwargs) |
| 80 | + data_scope["products"] = flow_builder.add_source( |
| 81 | + cocoindex.sources.Postgres( |
| 82 | + table_name="source_products", |
| 83 | + database=cocoindex.add_transient_auth_entry( |
| 84 | + cocoindex.sources.DatabaseConnectionSpec( |
| 85 | + url=os.getenv("SOURCE_DATABASE_URL"), |
| 86 | + ) |
| 87 | + ), |
| 88 | + ) |
127 | 89 | ) |
128 | 90 |
|
129 | | - document_embeddings = data_scope.add_collector() |
130 | | - |
131 | | - with data_scope["documents"].row() as row: |
132 | | - # Use the indexing column for embedding generation |
133 | | - row["text_embedding"] = text_to_embedding(row[indexing_column]) |
134 | | - # Collect the data - include key columns and content |
135 | | - collect_data = { |
136 | | - "content": row[indexing_column], |
137 | | - "text_embedding": row["text_embedding"], |
138 | | - } |
139 | | - |
140 | | - # Add each key column as a separate field |
141 | | - for key_col in key_columns: |
142 | | - if is_single_key(): |
143 | | - collect_data[key_col] = row[key_col] |
144 | | - else: |
145 | | - collect_data[key_col] = row["_key"][key_col] |
146 | | - |
147 | | - document_embeddings.collect(**collect_data) |
| 91 | + product_embeddings = data_scope.add_collector() |
| 92 | + with data_scope["products"].row() as product: |
| 93 | + product["full_description"] = flow_builder.transform( |
| 94 | + make_full_description, |
| 95 | + product["_key"]["product_category"], |
| 96 | + product["_key"]["product_name"], |
| 97 | + product["description"], |
| 98 | + ) |
| 99 | + product["embedding"] = product["full_description"].transform( |
| 100 | + cocoindex.functions.SentenceTransformerEmbed( |
| 101 | + model="sentence-transformers/all-MiniLM-L6-v2" |
| 102 | + ) |
| 103 | + ) |
| 104 | + product_embeddings.collect( |
| 105 | + product_category=product["_key"]["product_category"], |
| 106 | + product_name=product["_key"]["product_name"], |
| 107 | + description=product["description"], |
| 108 | + price=product["price"], |
| 109 | + amount=product["amount"], |
| 110 | + embedding=product["embedding"], |
| 111 | + ) |
148 | 112 |
|
149 | | - document_embeddings.export( |
150 | | - "document_embeddings", |
| 113 | + product_embeddings.export( |
| 114 | + "product_embeddings", |
151 | 115 | cocoindex.targets.Postgres(), |
152 | | - primary_key_fields=key_columns, |
| 116 | + primary_key_fields=["product_category", "product_name"], |
153 | 117 | vector_indexes=[ |
154 | 118 | cocoindex.VectorIndexDef( |
155 | | - field_name="text_embedding", |
| 119 | + field_name="embedding", |
156 | 120 | metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, |
157 | 121 | ) |
158 | 122 | ], |
159 | 123 | ) |
160 | | - |
161 | | - |
162 | | -def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]: |
163 | | - # Get the table name, for the export target in the postgres_embedding_flow above. |
164 | | - table_name = cocoindex.utils.get_target_default_name( |
165 | | - postgres_embedding_flow, "document_embeddings" |
166 | | - ) |
167 | | - |
168 | | - # Get key columns configuration |
169 | | - key_columns = get_key_columns_from_env() |
170 | | - # Build SELECT clause with all key columns |
171 | | - key_columns_select = ", ".join(key_columns) |
172 | | - |
173 | | - # Evaluate the transform flow defined above with the input query, to get the embedding. |
174 | | - query_vector = text_to_embedding.eval(query) |
175 | | - # Run the query and get the results. |
176 | | - with pool.connection() as conn: |
177 | | - register_vector(conn) |
178 | | - with conn.cursor() as cur: |
179 | | - cur.execute( |
180 | | - f""" |
181 | | - SELECT content, text_embedding <=> %s::vector AS distance, {key_columns_select} |
182 | | - FROM {table_name} ORDER BY distance LIMIT %s |
183 | | - """, |
184 | | - (query_vector, top_k), |
185 | | - ) |
186 | | - results = [] |
187 | | - for row in cur.fetchall(): |
188 | | - result = { |
189 | | - "content": row[0], |
190 | | - "score": 1.0 - row[1], |
191 | | - "key": "__".join(str(x) for x in row[2:]), |
192 | | - } |
193 | | - results.append(result) |
194 | | - return results |
195 | | - |
196 | | - |
197 | | -def _main() -> None: |
198 | | - # Initialize the database connection pool for CocoIndex database (where embeddings are stored) |
199 | | - cocoindex_db_url = os.getenv("COCOINDEX_DATABASE_URL") |
200 | | - if not cocoindex_db_url: |
201 | | - print("❌ Error: COCOINDEX_DATABASE_URL environment variable is required") |
202 | | - print(" This should point to the database where embeddings will be stored") |
203 | | - sys.exit(1) |
204 | | - |
205 | | - pool = ConnectionPool(cocoindex_db_url) |
206 | | - |
207 | | - postgres_embedding_flow.setup() |
208 | | - with cocoindex.FlowLiveUpdater(postgres_embedding_flow) as updater: |
209 | | - # Run queries in a loop to demonstrate the query capabilities. |
210 | | - while True: |
211 | | - query = input("Enter search query (or Enter to quit): ") |
212 | | - if query == "": |
213 | | - break |
214 | | - # Run the query function with the database connection pool and the query. |
215 | | - results = search(pool, query) |
216 | | - print("\nSearch results:") |
217 | | - for result in results: |
218 | | - print( |
219 | | - f"[{result['score']:.3f}] {result['content']} key: {result['key']}" |
220 | | - ) |
221 | | - print("---") |
222 | | - print() |
223 | | - |
224 | | - |
225 | | -if __name__ == "__main__": |
226 | | - load_dotenv() |
227 | | - cocoindex.init() |
228 | | - _main() |
0 commit comments