|
2 | 2 | from psycopg_pool import ConnectionPool |
3 | 3 | from pgvector.psycopg import register_vector |
4 | 4 | from typing import Any |
| 5 | +import functools |
5 | 6 | import cocoindex |
6 | 7 | import os |
7 | 8 | from numpy.typing import NDArray |
@@ -84,52 +85,74 @@ def code_embedding_flow( |
84 | 85 | ) |
85 | 86 |
|
86 | 87 |
|
87 | | -def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]: |
| 88 | +@functools.cache |
| 89 | +def connection_pool() -> ConnectionPool: |
| 90 | + """ |
| 91 | + Get a connection pool to the database. |
| 92 | + """ |
| 93 | + return ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"]) |
| 94 | + |
| 95 | + |
| 96 | +TOP_K = 5 |
| 97 | + |
| 98 | + |
| 99 | +# Declaring it ss a query handler, so that you can easily run queries in CocoInsight. |
| 100 | +@code_embedding_flow.query_handler( |
| 101 | + result_fields=cocoindex.QueryHandlerResultFields( |
| 102 | + embedding=["embedding"], score="score" |
| 103 | + ) |
| 104 | +) |
| 105 | +def search(query: str) -> cocoindex.QueryOutput: |
88 | 106 | # Get the table name, for the export target in the code_embedding_flow above. |
89 | 107 | table_name = cocoindex.utils.get_target_default_name( |
90 | 108 | code_embedding_flow, "code_embeddings" |
91 | 109 | ) |
92 | 110 | # Evaluate the transform flow defined above with the input query, to get the embedding. |
93 | 111 | query_vector = code_to_embedding.eval(query) |
94 | 112 | # Run the query and get the results. |
95 | | - with pool.connection() as conn: |
| 113 | + with connection_pool().connection() as conn: |
96 | 114 | register_vector(conn) |
97 | 115 | with conn.cursor() as cur: |
98 | 116 | cur.execute( |
99 | 117 | f""" |
100 | | - SELECT filename, code, embedding <=> %s AS distance, start, "end" |
| 118 | + SELECT filename, code, embedding, embedding <=> %s AS distance, start, "end" |
101 | 119 | FROM {table_name} ORDER BY distance LIMIT %s |
102 | 120 | """, |
103 | | - (query_vector, top_k), |
| 121 | + (query_vector, TOP_K), |
| 122 | + ) |
| 123 | + return cocoindex.QueryOutput( |
| 124 | + query_info=cocoindex.QueryInfo( |
| 125 | + embedding=query_vector, |
| 126 | + similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, |
| 127 | + ), |
| 128 | + results=[ |
| 129 | + { |
| 130 | + "filename": row[0], |
| 131 | + "code": row[1], |
| 132 | + "embedding": row[2], |
| 133 | + "score": 1.0 - row[3], |
| 134 | + "start": row[4], |
| 135 | + "end": row[5], |
| 136 | + } |
| 137 | + for row in cur.fetchall() |
| 138 | + ], |
104 | 139 | ) |
105 | | - return [ |
106 | | - { |
107 | | - "filename": row[0], |
108 | | - "code": row[1], |
109 | | - "score": 1.0 - row[2], |
110 | | - "start": row[3], |
111 | | - "end": row[4], |
112 | | - } |
113 | | - for row in cur.fetchall() |
114 | | - ] |
115 | 140 |
|
116 | 141 |
|
117 | 142 | def _main() -> None: |
118 | 143 | # Make sure the flow is built and up-to-date. |
119 | 144 | stats = code_embedding_flow.update() |
120 | 145 | print("Updated index: ", stats) |
121 | 146 |
|
122 | | - # Initialize the database connection pool. |
123 | | - pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) |
124 | 147 | # Run queries in a loop to demonstrate the query capabilities. |
125 | 148 | while True: |
126 | 149 | query = input("Enter search query (or Enter to quit): ") |
127 | 150 | if query == "": |
128 | 151 | break |
129 | 152 | # Run the query function with the database connection pool and the query. |
130 | | - results = search(pool, query) |
| 153 | + query_output = search(query) |
131 | 154 | print("\nSearch results:") |
132 | | - for result in results: |
| 155 | + for result in query_output.results: |
133 | 156 | print( |
134 | 157 | f"[{result['score']:.3f}] {result['filename']} (L{result['start']['line']}-L{result['end']['line']})" |
135 | 158 | ) |
|
0 commit comments