Skip to content

Commit a2a5c00

Browse files
authored
upgrade query handler for aws example (#567)
Update main.py
1 parent e1cc35e commit a2a5c00

File tree

1 file changed

+54
-32
lines changed
  • examples/amazon_s3_embedding

1 file changed

+54
-32
lines changed

examples/amazon_s3_embedding/main.py

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
11
from dotenv import load_dotenv
2-
2+
from psycopg_pool import ConnectionPool
33
import cocoindex
44
import os
55

66

7+
@cocoindex.transform_flow()
8+
def text_to_embedding(
9+
text: cocoindex.DataSlice[str],
10+
) -> cocoindex.DataSlice[list[float]]:
11+
"""
12+
Embed the text using a SentenceTransformer model.
13+
This is a shared logic between indexing and querying, so extract it as a function.
14+
"""
15+
return text.transform(
16+
cocoindex.functions.SentenceTransformerEmbed(
17+
model="sentence-transformers/all-MiniLM-L6-v2"
18+
)
19+
)
20+
21+
722
@cocoindex.flow_def(name="AmazonS3TextEmbedding")
823
def amazon_s3_text_embedding_flow(
924
flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
@@ -19,7 +34,7 @@ def amazon_s3_text_embedding_flow(
1934
cocoindex.sources.AmazonS3(
2035
bucket_name=bucket_name,
2136
prefix=prefix,
22-
included_patterns=["*.md", "*.txt", "*.docx"],
37+
included_patterns=["*.md", "*.mdx", "*.txt", "*.docx"],
2338
binary=False,
2439
sqs_queue_url=sqs_queue_url,
2540
)
@@ -36,11 +51,7 @@ def amazon_s3_text_embedding_flow(
3651
)
3752

3853
with doc["chunks"].row() as chunk:
39-
chunk["embedding"] = chunk["text"].transform(
40-
cocoindex.functions.SentenceTransformerEmbed(
41-
model="sentence-transformers/all-MiniLM-L6-v2"
42-
)
43-
)
54+
chunk["embedding"] = text_to_embedding(chunk["text"])
4455
doc_embeddings.collect(
4556
filename=doc["filename"],
4657
location=chunk["location"],
@@ -61,34 +72,45 @@ def amazon_s3_text_embedding_flow(
6172
)
6273

6374

64-
query_handler = cocoindex.query.SimpleSemanticsQueryHandler(
65-
name="SemanticsSearch",
66-
flow=amazon_s3_text_embedding_flow,
67-
target_name="doc_embeddings",
68-
query_transform_flow=lambda text: text.transform(
69-
cocoindex.functions.SentenceTransformerEmbed(
70-
model="sentence-transformers/all-MiniLM-L6-v2"
71-
)
72-
),
73-
default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
74-
)
75+
def search(pool: ConnectionPool, query: str, top_k: int = 5):
76+
# Get the table name, for the export target in the amazon_s3_text_embedding_flow above.
77+
table_name = cocoindex.utils.get_target_storage_default_name(
78+
amazon_s3_text_embedding_flow, "doc_embeddings"
79+
)
80+
# Evaluate the transform flow defined above with the input query, to get the embedding.
81+
query_vector = text_to_embedding.eval(query)
82+
# Run the query and get the results.
83+
with pool.connection() as conn:
84+
with conn.cursor() as cur:
85+
cur.execute(
86+
f"""
87+
SELECT filename, text, embedding <=> %s::vector AS distance
88+
FROM {table_name} ORDER BY distance LIMIT %s
89+
""",
90+
(query_vector, top_k),
91+
)
92+
return [
93+
{"filename": row[0], "text": row[1], "score": 1.0 - row[2]}
94+
for row in cur.fetchall()
95+
]
7596

7697

7798
def _main():
78-
# Use a `FlowLiveUpdater` to keep the flow data updated.
79-
with cocoindex.FlowLiveUpdater(amazon_s3_text_embedding_flow):
80-
# Run queries in a loop to demonstrate the query capabilities.
81-
while True:
82-
query = input("Enter search query (or Enter to quit): ")
83-
if query == "":
84-
break
85-
results, _ = query_handler.search(query, 10)
86-
print("\nSearch results:")
87-
for result in results:
88-
print(f"[{result.score:.3f}] {result.data['filename']}")
89-
print(f" {result.data['text']}")
90-
print("---")
91-
print()
99+
# Initialize the database connection pool.
100+
pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL"))
101+
# Run queries in a loop to demonstrate the query capabilities.
102+
while True:
103+
query = input("Enter search query (or Enter to quit): ")
104+
if query == "":
105+
break
106+
# Run the query function with the database connection pool and the query.
107+
results = search(pool, query)
108+
print("\nSearch results:")
109+
for result in results:
110+
print(f"[{result['score']:.3f}] {result['filename']}")
111+
print(f" {result['text']}")
112+
print("---")
113+
print()
92114

93115

94116
if __name__ == "__main__":

0 commit comments

Comments
 (0)