@@ -3,12 +3,13 @@
 import os
 from typing import Any, cast

-from chromadb import GetResult, Where
+from chromadb import Where
 from chromadb.api.models.AsyncCollection import AsyncCollection
-from chromadb.api.types import IncludeEnum
+from chromadb.api.types import IncludeEnum, QueryResult
 from chromadb.errors import InvalidCollectionException, InvalidDimensionException
+from tree_sitter import Point

-from vectorcode.chunking import StringChunker
+from vectorcode.chunking import Chunk, StringChunker
 from vectorcode.cli_utils import (
     Config,
     QueryInclude,
@@ -22,6 +23,7 @@
     get_embedding_function,
     verify_ef,
 )
+from vectorcode.subcommands.query import types as vectorcode_types
 from vectorcode.subcommands.query.reranker import (
     RerankerError,
     get_reranker,
@@ -30,14 +32,49 @@
 logger = logging.getLogger(name=__name__)


+def convert_query_results(
+    chroma_result: QueryResult, queries: list[str]
+) -> list[vectorcode_types.QueryResult]:
+    """Convert chromadb query result to in-house query results"""
+    assert chroma_result["documents"] is not None
+    assert chroma_result["distances"] is not None
+    assert chroma_result["metadatas"] is not None
+    assert chroma_result["ids"] is not None
+
+    chroma_results_list: list[vectorcode_types.QueryResult] = []
+    for q_i in range(len(queries)):
+        q = queries[q_i]
+        documents = chroma_result["documents"][q_i]
+        distances = chroma_result["distances"][q_i]
+        metadatas = chroma_result["metadatas"][q_i]
+        ids = chroma_result["ids"][q_i]
+        for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
+            chunk = Chunk(text=doc, id=_id)
+            if meta.get("start"):
+                chunk.start = Point(int(meta.get("start", 0)), 0)
+            if meta.get("end"):
+                chunk.end = Point(int(meta.get("end", 0)), 0)
+            if meta.get("path"):
+                chunk.path = str(meta["path"])
+            chroma_results_list.append(
+                vectorcode_types.QueryResult(
+                    chunk=chunk,
+                    path=str(meta.get("path", "")),
+                    query=(q,),
+                    scores=(-dist,),
+                )
+            )
+    return chroma_results_list
+
+
 async def get_query_result_files(
     collection: AsyncCollection, configs: Config
-) -> list[str]:
+) -> list[str | Chunk]:
     query_chunks = []
-    if configs.query:
-        chunker = StringChunker(configs)
-        for q in configs.query:
-            query_chunks.extend(str(i) for i in chunker.chunk(q))
+    assert configs.query, "Query messages cannot be empty."
+    chunker = StringChunker(configs)
+    for q in configs.query:
+        query_chunks.extend(str(i) for i in chunker.chunk(q))

     configs.query_exclude = [
         expand_path(i, True)
@@ -70,7 +107,7 @@ async def get_query_result_files(
     query_embeddings = get_embedding_function(configs)(query_chunks)
     if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0:
         query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings]
-    results = await collection.query(
+    chroma_query_results: QueryResult = await collection.query(
         query_embeddings=query_embeddings,
         n_results=num_query,
         include=[
@@ -85,69 +122,51 @@ async def get_query_result_files(
         return []

     reranker = get_reranker(configs)
-    return await reranker.rerank(results)
+    return await reranker.rerank(
+        convert_query_results(chroma_query_results, configs.query)
+    )


 async def build_query_results(
     collection: AsyncCollection, configs: Config
 ) -> list[dict[str, str | int]]:
-    structured_result = []
-    for identifier in await get_query_result_files(collection, configs):
-        if os.path.isfile(identifier):
-            if configs.use_absolute_path:
-                output_path = os.path.abspath(identifier)
-            else:
-                output_path = os.path.relpath(identifier, configs.project_root)
-            full_result = {"path": output_path}
-            with open(identifier) as fin:
-                document = fin.read()
-                full_result["document"] = document
+    assert configs.project_root

-            structured_result.append(
-                {str(key): full_result[str(key)] for key in configs.include}
-            )
-        elif QueryInclude.chunk in configs.include:
-            chunks: GetResult = await collection.get(
-                identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
-            )
-            meta = chunks.get(
-                "metadatas",
-            )
-            if meta is not None and len(meta) != 0:
-                chunk_texts = chunks.get("documents")
-                assert chunk_texts is not None, (
-                    "QueryResult does not contain `documents`!"
-                )
-                full_result: dict[str, str | int] = {
-                    "chunk": str(chunk_texts[0]),
-                    "chunk_id": identifier,
-                }
-                if meta[0].get("start") is not None and meta[0].get("end") is not None:
-                    path = str(meta[0].get("path"))
-                    with open(path) as fin:
-                        start: int = int(meta[0]["start"])
-                        end: int = int(meta[0]["end"])
-                        full_result["chunk"] = "".join(fin.readlines()[start : end + 1])
-                    full_result["start_line"] = start
-                    full_result["end_line"] = end
-                if QueryInclude.path in configs.include:
-                    full_result["path"] = str(
-                        meta[0]["path"]
-                        if configs.use_absolute_path
-                        else os.path.relpath(
-                            str(meta[0]["path"]), str(configs.project_root)
-                        )
-                    )
-
-                structured_result.append(full_result)
-            else:  # pragma: nocover
-                logger.error(
-                    "This collection doesn't support chunk-mode output because it lacks the necessary metadata. Please re-vectorise it.",
-                )
+    def make_output_path(path: str, absolute: bool) -> str:
+        if absolute:
+            if os.path.isabs(path):
+                return path
+            return os.path.abspath(os.path.join(str(configs.project_root), path))
+        else:
+            rel_path = os.path.relpath(path, configs.project_root)
+            if isinstance(rel_path, bytes):  # pragma: nocover
+                # for some reason, some Python versions report that `os.path.relpath` may return bytes.
+                rel_path = rel_path.decode()
+            return rel_path

+    structured_result = []
+    for res in await get_query_result_files(collection, configs):
+        if isinstance(res, str):
+            output_path = make_output_path(res, configs.use_absolute_path)
+            io_path = make_output_path(res, True)
+            if not os.path.isfile(io_path):
+                logger.warning(f"{io_path} is no longer a valid file.")
+                continue
+            with open(io_path) as fin:
+                structured_result.append({"path": output_path, "document": fin.read()})
         else:
-            logger.warning(
-                f"{identifier} is no longer a valid file! Please re-run vectorcode vectorise to refresh the database.",
+            res = cast(Chunk, res)
+            assert res.path, f"{res} has no `path` attribute."
+            structured_result.append(
+                {
+                    "path": make_output_path(res.path, configs.use_absolute_path)
+                    if res.path is not None
+                    else None,
+                    "chunk": res.text,
+                    "start_line": res.start.row if res.start is not None else None,
+                    "end_line": res.end.row if res.end is not None else None,
+                    "chunk_id": res.id,
+                }
             )
     for result in structured_result:
         if result.get("path") is not None:
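
For review context, here is a minimal sketch of how the new `convert_query_results` helper is expected to behave. It assumes this module lives at `vectorcode.subcommands.query` (as the imports in the diff suggest), that the in-house `QueryResult` exposes `chunk`, `path` and `scores` as attributes, and that the ChromaDB payload below, including its paths and values, is purely hypothetical.

```python
from typing import cast

from chromadb.api.types import QueryResult

# Module path assumed from the imports in the diff above; adjust if it differs.
from vectorcode.subcommands.query import convert_query_results

# Hypothetical ChromaDB payload for a single query; chromadb's QueryResult keeps
# one inner list per query, so every top-level list has length 1 here.
payload = cast(
    QueryResult,
    {
        "ids": [["chunk-1", "chunk-2"]],
        "documents": [["def parse(): ...", "class Parser: ..."]],
        "metadatas": [
            [
                {"path": "src/parser.py", "start": 10, "end": 14},
                {"path": "src/parser.py", "start": 20, "end": 42},
            ]
        ],
        "distances": [[0.12, 0.34]],
    },
)

results = convert_query_results(payload, ["find the parser"])
for r in results:
    # Each in-house QueryResult pairs the reconstructed Chunk with its source path
    # and a negated-distance score, so a larger score means a closer match.
    print(r.path, r.scores, r.chunk.start, r.chunk.end)
```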
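Also for reference, a sketch of the dict shapes that `build_query_results` now emits, one per branch of the new loop; the concrete values are illustrative only.

```python
# Document-mode entry, produced when get_query_result_files returns a plain path string:
doc_entry = {"path": "src/parser.py", "document": "<full file contents>"}

# Chunk-mode entry, produced when it returns a Chunk; start_line/end_line come from
# the tree_sitter Point rows stored on the chunk, and chunk_id is the ChromaDB id.
chunk_entry = {
    "path": "src/parser.py",
    "chunk": "def parse(): ...",
    "start_line": 10,
    "end_line": 14,
    "chunk_id": "chunk-1",
}
```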