Skip to content

Commit 8fc19a4

Browse files
committed
examples: update examples for query handler
1 parent 5fa0771 commit 8fc19a4

File tree

3 files changed

+118
-44
lines changed

3 files changed

+118
-44
lines changed

examples/code_embedding/main.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from psycopg_pool import ConnectionPool
33
from pgvector.psycopg import register_vector
44
from typing import Any
5+
import functools
56
import cocoindex
67
import os
78
from numpy.typing import NDArray
@@ -84,52 +85,74 @@ def code_embedding_flow(
8485
)
8586

8687

87-
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
88+
@functools.cache
89+
def connection_pool() -> ConnectionPool:
90+
"""
91+
Get a connection pool to the database.
92+
"""
93+
return ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"])
94+
95+
96+
TOP_K = 5
97+
98+
99+
# Declaring it ss a query handler, so that you can easily run queries in CocoInsight.
100+
@code_embedding_flow.query_handler(
101+
result_fields=cocoindex.QueryHandlerResultFields(
102+
embedding=["embedding"], score="score"
103+
)
104+
)
105+
def search(query: str) -> cocoindex.QueryOutput:
88106
# Get the table name, for the export target in the code_embedding_flow above.
89107
table_name = cocoindex.utils.get_target_default_name(
90108
code_embedding_flow, "code_embeddings"
91109
)
92110
# Evaluate the transform flow defined above with the input query, to get the embedding.
93111
query_vector = code_to_embedding.eval(query)
94112
# Run the query and get the results.
95-
with pool.connection() as conn:
113+
with connection_pool().connection() as conn:
96114
register_vector(conn)
97115
with conn.cursor() as cur:
98116
cur.execute(
99117
f"""
100-
SELECT filename, code, embedding <=> %s AS distance, start, "end"
118+
SELECT filename, code, embedding, embedding <=> %s AS distance, start, "end"
101119
FROM {table_name} ORDER BY distance LIMIT %s
102120
""",
103-
(query_vector, top_k),
121+
(query_vector, TOP_K),
122+
)
123+
return cocoindex.QueryOutput(
124+
query_info=cocoindex.QueryInfo(
125+
embedding=query_vector,
126+
similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
127+
),
128+
results=[
129+
{
130+
"filename": row[0],
131+
"code": row[1],
132+
"embedding": row[2],
133+
"score": 1.0 - row[3],
134+
"start": row[4],
135+
"end": row[5],
136+
}
137+
for row in cur.fetchall()
138+
],
104139
)
105-
return [
106-
{
107-
"filename": row[0],
108-
"code": row[1],
109-
"score": 1.0 - row[2],
110-
"start": row[3],
111-
"end": row[4],
112-
}
113-
for row in cur.fetchall()
114-
]
115140

116141

117142
def _main() -> None:
118143
# Make sure the flow is built and up-to-date.
119144
stats = code_embedding_flow.update()
120145
print("Updated index: ", stats)
121146

122-
# Initialize the database connection pool.
123-
pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL"))
124147
# Run queries in a loop to demonstrate the query capabilities.
125148
while True:
126149
query = input("Enter search query (or Enter to quit): ")
127150
if query == "":
128151
break
129152
# Run the query function with the database connection pool and the query.
130-
results = search(pool, query)
153+
query_output = search(query)
131154
print("\nSearch results:")
132-
for result in results:
155+
for result in query_output.results:
133156
print(
134157
f"[{result['score']:.3f}] {result['filename']} (L{result['start']['line']}-L{result['end']['line']})"
135158
)

examples/text_embedding/main.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any
55
import cocoindex
66
import os
7+
import functools
78
from numpy.typing import NDArray
89
import numpy as np
910
from datetime import timedelta
@@ -74,42 +75,65 @@ def text_embedding_flow(
7475
)
7576

7677

77-
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
78+
@functools.cache
79+
def connection_pool() -> ConnectionPool:
80+
"""
81+
Get a connection pool to the database.
82+
"""
83+
return ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"])
84+
85+
86+
TOP_K = 5
87+
88+
89+
# Declaring it ss a query handler, so that you can easily run queries in CocoInsight.
90+
@text_embedding_flow.query_handler(
91+
result_fields=cocoindex.QueryHandlerResultFields(
92+
embedding=["embedding"],
93+
score="score",
94+
),
95+
)
96+
def search(query: str) -> cocoindex.QueryOutput:
7897
# Get the table name, for the export target in the text_embedding_flow above.
7998
table_name = cocoindex.utils.get_target_default_name(
8099
text_embedding_flow, "doc_embeddings"
81100
)
82101
# Evaluate the transform flow defined above with the input query, to get the embedding.
83102
query_vector = text_to_embedding.eval(query)
84103
# Run the query and get the results.
85-
with pool.connection() as conn:
104+
with connection_pool().connection() as conn:
86105
register_vector(conn)
87106
with conn.cursor() as cur:
88107
cur.execute(
89108
f"""
90109
SELECT filename, text, embedding <=> %s AS distance
91110
FROM {table_name} ORDER BY distance LIMIT %s
92111
""",
93-
(query_vector, top_k),
112+
(query_vector, TOP_K),
94113
)
95-
return [
114+
results = [
96115
{"filename": row[0], "text": row[1], "score": 1.0 - row[2]}
97116
for row in cur.fetchall()
98117
]
118+
return cocoindex.QueryOutput(
119+
results=results,
120+
query_info=cocoindex.QueryInfo(
121+
embedding=query_vector,
122+
similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
123+
),
124+
)
99125

100126

101127
def _main() -> None:
102-
# Initialize the database connection pool.
103-
pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL"))
104128
# Run queries in a loop to demonstrate the query capabilities.
105129
while True:
106130
query = input("Enter search query (or Enter to quit): ")
107131
if query == "":
108132
break
109133
# Run the query function with the database connection pool and the query.
110-
results = search(pool, query)
134+
query_output = search(query)
111135
print("\nSearch results:")
112-
for result in results:
136+
for result in query_output.results:
113137
print(f"[{result['score']:.3f}] {result['filename']}")
114138
print(f" {result['text']}")
115139
print("---")

examples/text_embedding_qdrant/main.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import functools
12
from dotenv import load_dotenv
23
from qdrant_client import QdrantClient
34
import cocoindex
@@ -61,32 +62,58 @@ def text_embedding_flow(
6162
)
6263

6364

64-
def _main() -> None:
65-
# Initialize Qdrant client
66-
client = QdrantClient(url=QDRANT_URL, prefer_grpc=True)
65+
@functools.cache
66+
def get_qdrant_client() -> QdrantClient:
67+
return QdrantClient(url=QDRANT_URL, prefer_grpc=True)
68+
69+
70+
@text_embedding_flow.query_handler(
71+
result_fields=cocoindex.QueryHandlerResultFields(
72+
embedding=["embedding"],
73+
score="score",
74+
),
75+
)
76+
def search(query: str) -> cocoindex.QueryOutput:
77+
client = get_qdrant_client()
6778

79+
# Get the embedding for the query
80+
query_embedding = text_to_embedding.eval(query)
81+
82+
search_results = client.search(
83+
collection_name=QDRANT_COLLECTION,
84+
query_vector=("text_embedding", query_embedding),
85+
limit=10,
86+
)
87+
return cocoindex.QueryOutput(
88+
results=[
89+
{
90+
"filename": result.payload["filename"],
91+
"text": result.payload["text"],
92+
"embedding": result.vector,
93+
"score": result.score,
94+
}
95+
for result in search_results
96+
],
97+
query_info=cocoindex.QueryInfo(
98+
embedding=query_embedding,
99+
similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
100+
),
101+
)
102+
103+
104+
def _main() -> None:
68105
# Run queries in a loop to demonstrate the query capabilities.
69106
while True:
70107
query = input("Enter search query (or Enter to quit): ")
71108
if query == "":
72109
break
73110

74-
# Get the embedding for the query
75-
query_embedding = text_to_embedding.eval(query)
76-
77-
search_results = client.search(
78-
collection_name=QDRANT_COLLECTION,
79-
query_vector=("text_embedding", query_embedding),
80-
limit=10,
81-
)
111+
# Run the query function with the database connection pool and the query.
112+
query_output = search(query)
82113
print("\nSearch results:")
83-
for result in search_results:
84-
score = result.score
85-
payload = result.payload
86-
if payload is None:
87-
continue
88-
print(f"[{score:.3f}] {payload['filename']}")
89-
print(f" {payload['text']}")
114+
for result in query_output.results:
115+
print(f"[{result['score']:.3f}] {result['filename']}")
116+
print(f" {result['text']}")
90117
print("---")
91118
print()
92119

0 commit comments

Comments
 (0)