Skip to content

Commit 4580c45

Browse files
authored
refactor(cli): Implement QueryResult type. (#280)
* feat(cli): Implement the `QueryResult` type.
* refactor(cli): Use the in-house `QueryResult` in the query subcommand and the rerankers.
* fix(cli): Improve test coverage.
* Make basedpyright happy.
* Improve test coverage.
* Fix a typo.
* Rename a parameter.
* refactor(cli): Make function signatures more consistent.
* fix(cli): Safeguard the output processing.
* Improve coverage.
* Add tests for `Chunk.export_dict` and the `chunks` pipe mode.
1 parent bf5f94c commit 4580c45

File tree

13 files changed

+545
-352
lines changed

13 files changed

+545
-352
lines changed

src/vectorcode/chunking.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,36 @@ class Chunk:
2525
"""
2626

2727
text: str
28-
start: Point
29-
end: Point
28+
start: Point | None = None
29+
end: Point | None = None
30+
path: str | None = None
31+
id: str | None = None
3032

3133
def __str__(self):
3234
return self.text
3335

36+
def __hash__(self) -> int:
    """Hash a chunk by its identity-relevant fields (path, span, text).

    Two chunks with the same path, start/end points and text collapse to
    the same hash, so duplicates de-duplicate in sets and dict keys.
    """
    identity = "VectorCodeChunk_{}({}:{}@{})".format(
        self.path, self.start, self.end, self.text
    )
    return hash(identity)
38+
3439
def export_dict(self):
    """Serialise this chunk into a plain, JSON-friendly dict.

    The ``text`` field is always present.  ``start``/``end`` points are
    flattened into ``{"row": ..., "column": ...}`` mappings, and ``path`` /
    ``chunk_id`` are emitted only when set (and non-empty).
    """
    result: dict[str, str | dict[str, int]] = {"text": self.text}
    # Flatten whichever boundary points are present.
    for key, point in (("start", self.start), ("end", self.end)):
        if point is not None:
            result[key] = {"row": point.row, "column": point.column}
    if self.path:
        result["path"] = self.path
    if self.id:
        result["chunk_id"] = self.id
    return result
4358

4459

4560
@dataclass
@@ -129,7 +144,7 @@ def chunk(
129144
) -> Generator[Chunk, None, None]:
130145
logger.info("Started chunking %s using FileChunker.", data.name)
131146
lines = data.readlines()
132-
if len(lines) == 0:
147+
if len(lines) == 0: # pragma: nocover
133148
return
134149
if (
135150
self.config.chunk_size < 0

src/vectorcode/mcp_main.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import sys
6+
import traceback
67
from dataclasses import dataclass
78
from pathlib import Path
89
from typing import Optional, cast
@@ -160,14 +161,17 @@ async def vectorise_files(paths: list[str], project_root: str) -> dict[str, int]
160161
await remove_orphanes(collection, collection_lock, stats, stats_lock)
161162

162163
return stats.to_dict()
163-
except Exception as e:
164-
logger.error("Failed to access collection at %s", project_root)
165-
raise McpError(
166-
ErrorData(
167-
code=1,
168-
message=f"{e.__class__.__name__}: Failed to create the collection at {project_root}.",
169-
)
170-
)
164+
except Exception as e: # pragma: nocover
165+
if isinstance(e, McpError):
166+
logger.error("Failed to access collection at %s", project_root)
167+
raise
168+
else:
169+
raise McpError(
170+
ErrorData(
171+
code=1,
172+
message="\n".join(traceback.format_exception(e)),
173+
)
174+
) from e
171175

172176

173177
async def query_tool(
@@ -211,24 +215,28 @@ async def query_tool(
211215
configs=query_config,
212216
)
213217
results: list[str] = []
214-
for path in result_paths:
215-
if os.path.isfile(path):
216-
with open(path) as fin:
217-
rel_path = os.path.relpath(path, config.project_root)
218-
results.append(
219-
f"<path>{rel_path}</path>\n<content>{fin.read()}</content>",
220-
)
218+
for result in result_paths:
219+
if isinstance(result, str):
220+
if os.path.isfile(result):
221+
with open(result) as fin:
222+
rel_path = os.path.relpath(result, config.project_root)
223+
results.append(
224+
f"<path>{rel_path}</path>\n<content>{fin.read()}</content>",
225+
)
221226
logger.info("Retrieved the following files: %s", result_paths)
222227
return results
223228

224-
except Exception as e:
225-
logger.error("Failed to access collection at %s", project_root)
226-
raise McpError(
227-
ErrorData(
228-
code=1,
229-
message=f"{e.__class__.__name__}: Failed to access the collection at {project_root}. Use `list_collections` tool to get a list of valid paths for this field.",
230-
)
231-
)
229+
except Exception as e: # pragma: nocover
230+
if isinstance(e, McpError):
231+
logger.error("Failed to access collection at %s", project_root)
232+
raise
233+
else:
234+
raise McpError(
235+
ErrorData(
236+
code=1,
237+
message="\n".join(traceback.format_exception(e)),
238+
)
239+
) from e
232240

233241

234242
async def ls_files(project_root: str) -> list[str]:

src/vectorcode/subcommands/query/__init__.py

Lines changed: 83 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
import os
44
from typing import Any, cast
55

6-
from chromadb import GetResult, Where
6+
from chromadb import Where
77
from chromadb.api.models.AsyncCollection import AsyncCollection
8-
from chromadb.api.types import IncludeEnum
8+
from chromadb.api.types import IncludeEnum, QueryResult
99
from chromadb.errors import InvalidCollectionException, InvalidDimensionException
10+
from tree_sitter import Point
1011

11-
from vectorcode.chunking import StringChunker
12+
from vectorcode.chunking import Chunk, StringChunker
1213
from vectorcode.cli_utils import (
1314
Config,
1415
QueryInclude,
@@ -22,6 +23,7 @@
2223
get_embedding_function,
2324
verify_ef,
2425
)
26+
from vectorcode.subcommands.query import types as vectorcode_types
2527
from vectorcode.subcommands.query.reranker import (
2628
RerankerError,
2729
get_reranker,
@@ -30,14 +32,49 @@
3032
logger = logging.getLogger(name=__name__)
3133

3234

35+
def convert_query_results(
    chroma_result: QueryResult, queries: list[str]
) -> list[vectorcode_types.QueryResult]:
    """Convert a chromadb ``QueryResult`` into in-house ``QueryResult`` objects.

    Args:
        chroma_result: raw chromadb query result; must contain ``documents``,
            ``distances``, ``metadatas`` and ``ids`` (asserted below).
        queries: the query strings, in the same order as the per-query
            result batches in ``chroma_result``.

    Returns:
        One in-house ``QueryResult`` per retrieved chunk.  Distances are
        negated so that a larger score means a better match.
    """
    assert chroma_result["documents"] is not None
    assert chroma_result["distances"] is not None
    assert chroma_result["metadatas"] is not None
    assert chroma_result["ids"] is not None

    converted: list[vectorcode_types.QueryResult] = []
    for q, documents, distances, metadatas, ids in zip(
        queries,
        chroma_result["documents"],
        chroma_result["distances"],
        chroma_result["metadatas"],
        chroma_result["ids"],
    ):
        for doc, dist, meta, _id in zip(documents, distances, metadatas, ids):
            chunk = Chunk(text=doc, id=_id)
            # Compare against None: tree-sitter rows are 0-indexed, so a
            # chunk starting (or ending) on the first line stores 0 here and
            # a truthiness check would silently drop it.
            if meta.get("start") is not None:
                chunk.start = Point(int(meta["start"]), 0)
            if meta.get("end") is not None:
                chunk.end = Point(int(meta["end"]), 0)
            if meta.get("path"):
                chunk.path = str(meta["path"])
            converted.append(
                vectorcode_types.QueryResult(
                    chunk=chunk,
                    path=str(meta.get("path", "")),
                    query=(q,),
                    scores=(-dist,),
                )
            )
    return converted
68+
69+
3370
async def get_query_result_files(
3471
collection: AsyncCollection, configs: Config
35-
) -> list[str]:
72+
) -> list[str | Chunk]:
3673
query_chunks = []
37-
if configs.query:
38-
chunker = StringChunker(configs)
39-
for q in configs.query:
40-
query_chunks.extend(str(i) for i in chunker.chunk(q))
74+
assert configs.query, "Query messages cannot be empty."
75+
chunker = StringChunker(configs)
76+
for q in configs.query:
77+
query_chunks.extend(str(i) for i in chunker.chunk(q))
4178

4279
configs.query_exclude = [
4380
expand_path(i, True)
@@ -70,7 +107,7 @@ async def get_query_result_files(
70107
query_embeddings = get_embedding_function(configs)(query_chunks)
71108
if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0:
72109
query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings]
73-
results = await collection.query(
110+
chroma_query_results: QueryResult = await collection.query(
74111
query_embeddings=query_embeddings,
75112
n_results=num_query,
76113
include=[
@@ -85,69 +122,51 @@ async def get_query_result_files(
85122
return []
86123

87124
reranker = get_reranker(configs)
88-
return await reranker.rerank(results)
125+
return await reranker.rerank(
126+
convert_query_results(chroma_query_results, configs.query)
127+
)
89128

90129

91130
async def build_query_results(
92131
collection: AsyncCollection, configs: Config
93132
) -> list[dict[str, str | int]]:
94-
structured_result = []
95-
for identifier in await get_query_result_files(collection, configs):
96-
if os.path.isfile(identifier):
97-
if configs.use_absolute_path:
98-
output_path = os.path.abspath(identifier)
99-
else:
100-
output_path = os.path.relpath(identifier, configs.project_root)
101-
full_result = {"path": output_path}
102-
with open(identifier) as fin:
103-
document = fin.read()
104-
full_result["document"] = document
133+
assert configs.project_root
105134

106-
structured_result.append(
107-
{str(key): full_result[str(key)] for key in configs.include}
108-
)
109-
elif QueryInclude.chunk in configs.include:
110-
chunks: GetResult = await collection.get(
111-
identifier, include=[IncludeEnum.metadatas, IncludeEnum.documents]
112-
)
113-
meta = chunks.get(
114-
"metadatas",
115-
)
116-
if meta is not None and len(meta) != 0:
117-
chunk_texts = chunks.get("documents")
118-
assert chunk_texts is not None, (
119-
"QueryResult does not contain `documents`!"
120-
)
121-
full_result: dict[str, str | int] = {
122-
"chunk": str(chunk_texts[0]),
123-
"chunk_id": identifier,
124-
}
125-
if meta[0].get("start") is not None and meta[0].get("end") is not None:
126-
path = str(meta[0].get("path"))
127-
with open(path) as fin:
128-
start: int = int(meta[0]["start"])
129-
end: int = int(meta[0]["end"])
130-
full_result["chunk"] = "".join(fin.readlines()[start : end + 1])
131-
full_result["start_line"] = start
132-
full_result["end_line"] = end
133-
if QueryInclude.path in configs.include:
134-
full_result["path"] = str(
135-
meta[0]["path"]
136-
if configs.use_absolute_path
137-
else os.path.relpath(
138-
str(meta[0]["path"]), str(configs.project_root)
139-
)
140-
)
141-
142-
structured_result.append(full_result)
143-
else: # pragma: nocover
144-
logger.error(
145-
"This collection doesn't support chunk-mode output because it lacks the necessary metadata. Please re-vectorise it.",
146-
)
135+
def make_output_path(path: str, absolute: bool) -> str:
136+
if absolute:
137+
if os.path.isabs(path):
138+
return path
139+
return os.path.abspath(os.path.join(str(configs.project_root), path))
140+
else:
141+
rel_path = os.path.relpath(path, configs.project_root)
142+
if isinstance(rel_path, bytes): # pragma: nocover
143+
# for some reasons, some python versions report that `os.path.relpath` returns a string.
144+
rel_path = rel_path.decode()
145+
return rel_path
147146

147+
structured_result = []
148+
for res in await get_query_result_files(collection, configs):
149+
if isinstance(res, str):
150+
output_path = make_output_path(res, configs.use_absolute_path)
151+
io_path = make_output_path(res, True)
152+
if not os.path.isfile(io_path):
153+
logger.warning(f"{io_path} is no longer a valid file.")
154+
continue
155+
with open(io_path) as fin:
156+
structured_result.append({"path": output_path, "document": fin.read()})
148157
else:
149-
logger.warning(
150-
f"{identifier} is no longer a valid file! Please re-run vectorcode vectorise to refresh the database.",
158+
res = cast(Chunk, res)
159+
assert res.path, f"{res} has no `path` attribute."
160+
structured_result.append(
161+
{
162+
"path": make_output_path(res.path, configs.use_absolute_path)
163+
if res.path is not None
164+
else None,
165+
"chunk": res.text,
166+
"start_line": res.start.row if res.start is not None else None,
167+
"end_line": res.end.row if res.end is not None else None,
168+
"chunk_id": res.id,
169+
}
151170
)
152171
for result in structured_result:
153172
if result.get("path") is not None:

0 commit comments

Comments
 (0)