Skip to content

Commit ca4f347

Browse files
committed
feat(cli)!: implement treesitter-based chunking.
1 parent 3b69e92 commit ca4f347

File tree

6 files changed

+158
-23
lines changed

6 files changed

+158
-23
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,12 @@ and chat plugin available on VSCode and JetBrains products.
5656

5757
## TODOs
5858
- [x] query by ~file path~ excluded paths;
59-
- [ ] chunking support;
59+
- [x] chunking support;
6060
- [x] add metadata for files;
6161
- [x] chunk-size configuration;
62-
- [ ] smarter chunking (semantics/syntax based);
62+
- [x] smarter chunking (semantics/syntax based), implemented with
63+
[py-tree-sitter](https://github.com/tree-sitter/py-tree-sitter) and
64+
[tree-sitter-language-pack](https://github.com/Goldziher/tree-sitter-language-pack);
6365
- [x] configurable document selection from query results.
6466
- [x] ~NeoVim Lua API with cache to skip the retrieval when a project has not
6567
been indexed~ Returns empty array instead;

docs/cli.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,15 @@ The JSON configuration file may hold the following values:
187187
- `overlap_ratio`: float between 0 and 1, the ratio of overlapping/shared content
188188
between 2 adjacent chunks. A larger ratio improves the coherence of chunks,
189189
but at the cost of increasing number of entries in the database and hence
190-
slowing down the search. Default: `0.2`;
190+
slowing down the search. Default: `0.2`. _Starting from 0.4.11, VectorCode
191+
will use treesitter to parse languages that it can automatically detect. It
192+
uses [pygments](https://github.com/pygments/pygments) to guess the language
193+
from filename, and
194+
[tree-sitter-language-pack](https://github.com/Goldziher/tree-sitter-language-pack)
195+
to fetch the correct parser. `overlap_ratio` has no effect when treesitter
196+
works. If VectorCode fails to find an appropriate parser, it'll fall back to
197+
the legacy naive parser, in which case `overlap_ratio` works exactly in the
198+
same way as before;_
191199
- `query_multplier`: integer, when you use the `query` command to retrieve `n` documents,
192200
VectorCode will check `n * query_multplier` chunks and return at most `n`
193201
documents. A larger value of `query_multplier`

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ dependencies = [
1212
"numpy",
1313
"psutil",
1414
"httpx",
15+
"tree-sitter",
16+
"tree-sitter-language-pack",
17+
"pygments",
1518
]
1619
requires-python = ">=3.11,<3.14"
1720
readme = "README.md"

src/vectorcode/chunking.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1+
import os
12
from abc import abstractmethod
3+
from functools import cache
24
from io import TextIOWrapper
3-
from typing import Generator
5+
from typing import Generator, Optional
6+
7+
from pygments.lexer import Lexer
8+
from pygments.lexers import guess_lexer_for_filename
9+
from pygments.util import ClassNotFound
10+
from tree_sitter import Node
11+
from tree_sitter_language_pack import get_parser
412

513

614
class ChunkerBase:
@@ -59,3 +67,68 @@ def chunk(self, data: TextIOWrapper) -> Generator[str, None, None]:
5967
yield output
6068
if len(new_chars) < step_size:
6169
return
70+
71+
72+
class TreeSitterChunker(ChunkerBase):
    """Chunk a source file along syntax-tree boundaries using tree-sitter.

    The language is guessed with pygments from the file name and content, and
    a matching parser is requested from tree-sitter-language-pack. When no
    parser can be found, the file is chunked by ``StringChunker`` instead.
    """

    def __init__(self, chunk_size: int = -1, overlap_ratio: float = 0.2):
        """
        chunk_size: maximum length of a chunk, in characters. A non-positive
            value cannot be honoured by the syntax-aware path, so in that case
            the content is handed to ``StringChunker``.
        overlap_ratio: only forwarded to the ``StringChunker`` fallback; it
            has no effect on tree-sitter based chunking.
        """
        super().__init__()
        assert isinstance(chunk_size, int), "chunk_size parameter must be an integer"
        assert 0 <= overlap_ratio < 1, (
            "Overlap ratio has to be a float between 0 (inclusive) and 1 (exclusive)."
        )
        self.__chunk_size = chunk_size
        self.__overlap_ratio = overlap_ratio

    def __chunk_node(self, node: Node, source: bytes) -> Generator[str, None, None]:
        """Yield chunks of at most ``chunk_size`` characters from ``node``.

        node: syntax-tree node whose children are grouped into chunks.
        source: the UTF-8 encoded bytes that were given to the parser. Node
            offsets are byte offsets, so slicing must happen on the bytes and
            not on the decoded string.

        Consecutive children are merged while they fit. A merged chunk is the
        contiguous source span from its first to its last child, so the
        whitespace between siblings is preserved.
        """
        limit = self.__chunk_size
        pending_start: Optional[int] = None  # byte offset where the pending chunk begins
        pending_end = 0  # byte offset where the pending chunk ends
        for child in node.children:
            if child.end_byte <= child.start_byte:
                # zero-width node: nothing to emit
                continue
            child_text = source[child.start_byte : child.end_byte].decode("utf-8")
            if len(child_text) > limit:
                # This child cannot share a chunk with anything: flush first.
                if pending_start is not None:
                    yield source[pending_start:pending_end].decode("utf-8")
                    pending_start = None
                if child.children:
                    yield from self.__chunk_node(child, source)
                else:
                    # An oversized leaf has no syntax left to split on, so it
                    # is cut into fixed-size pieces instead of being dropped.
                    for offset in range(0, len(child_text), limit):
                        yield child_text[offset : offset + limit]
                continue
            if pending_start is not None:
                merged = source[pending_start : child.end_byte].decode("utf-8")
                if len(merged) <= limit:
                    pending_end = child.end_byte
                    continue
                yield source[pending_start:pending_end].decode("utf-8")
            pending_start, pending_end = child.start_byte, child.end_byte
        if pending_start is not None:
            yield source[pending_start:pending_end].decode("utf-8")

    def __guess_type(self, path: str, content: str) -> Optional[Lexer]:
        """Return the pygments lexer guessed for the file, or None if unknown."""
        # Deliberately not memoised: the key would contain the whole file
        # content (and ``self``), keeping every indexed file alive in memory.
        try:
            return guess_lexer_for_filename(path, content)
        except ClassNotFound:
            return None

    def __find_parser(self, path: str, content: str):
        """Return a tree-sitter parser for the file, or None if none is found."""
        lexer = self.__guess_type(path, content)
        if lexer is None:
            return None
        # Try the lexer's name first, then each of its aliases.
        for name in [lexer.name, *lexer.aliases]:
            try:
                return get_parser(name.lower())
            except LookupError:
                continue
        return None

    def chunk(self, data: str) -> Generator[str, None, None]:
        """
        data: path to the file
        """
        assert os.path.isfile(data), f"{data} is not a file"
        with open(data) as fin:
            content = fin.read()

        parser = None
        if self.__chunk_size > 0:
            parser = self.__find_parser(data, content)

        if parser is None:
            # fall back to naive chunking
            yield from StringChunker(self.__chunk_size, self.__overlap_ratio).chunk(
                content
            )
            return

        content_bytes = content.encode("utf-8")
        tree = parser.parse(content_bytes)
        yield from self.__chunk_node(tree.root_node, content_bytes)

src/vectorcode/subcommands/vectorise.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from chromadb.api.models.AsyncCollection import AsyncCollection
1414
from chromadb.api.types import IncludeEnum
1515

16-
from vectorcode.chunking import FileChunker
16+
from vectorcode.chunking import TreeSitterChunker
1717
from vectorcode.cli_utils import Config, expand_globs, expand_path
1818
from vectorcode.common import get_client, get_collection, verify_ef
1919

@@ -54,24 +54,23 @@ async def chunked_add(
5454

5555
try:
5656
async with semaphore:
57-
with open(full_path_str) as fin:
58-
chunks = list(
59-
FileChunker(configs.chunk_size, configs.overlap_ratio).chunk(fin)
57+
chunks = list(
58+
TreeSitterChunker(configs.chunk_size, configs.overlap_ratio).chunk(
59+
full_path_str
6060
)
61-
if len(chunks) == 0 or (len(chunks) == 1 and chunks[0] == ""):
62-
# empty file
63-
return
64-
chunks.append(str(os.path.relpath(full_path_str, configs.project_root)))
65-
async with collection_lock:
66-
for idx in range(0, len(chunks), max_batch_size):
67-
inserted_chunks = chunks[idx : idx + max_batch_size]
68-
await collection.add(
69-
ids=[get_uuid() for _ in inserted_chunks],
70-
documents=inserted_chunks,
71-
metadatas=[
72-
{"path": full_path_str} for _ in inserted_chunks
73-
],
74-
)
61+
)
62+
if len(chunks) == 0 or (len(chunks) == 1 and chunks[0] == ""):
63+
# empty file
64+
return
65+
chunks.append(str(os.path.relpath(full_path_str, configs.project_root)))
66+
async with collection_lock:
67+
for idx in range(0, len(chunks), max_batch_size):
68+
inserted_chunks = chunks[idx : idx + max_batch_size]
69+
await collection.add(
70+
ids=[get_uuid() for _ in inserted_chunks],
71+
documents=inserted_chunks,
72+
metadatas=[{"path": full_path_str} for _ in inserted_chunks],
73+
)
7574
except UnicodeDecodeError:
7675
# probably binary. skip it.
7776
return

tests/test_chunking.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from vectorcode.chunking import FileChunker, StringChunker
1+
import os
2+
import tempfile
3+
4+
from vectorcode.chunking import FileChunker, StringChunker, TreeSitterChunker
25

36

47
class TestChunking:
@@ -46,3 +49,50 @@ def test_file_chunker(self):
4649
)
4750
for string_chunk, file_chunk in zip(string_chunks, file_chunks):
4851
assert string_chunk == file_chunk
52+
53+
54+
def test_treesitter_chunker():
    """Test TreeSitterChunker with a sample file using tempfile."""
    chunker = TreeSitterChunker(chunk_size=30)
    # Two small top-level functions, each shorter than the 30-character limit
    # but too long to share one chunk.
    test_content = r"""
def foo():
    return "foo"

def bar():
    return "bar"
"""

    with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".py") as tmp_file:
        tmp_file.write(test_content)
        test_file = tmp_file.name

    try:
        chunks = list(chunker.chunk(test_file))
        assert len(chunks) == 2
        assert all(len(i) <= 30 for i in chunks)
    finally:
        # Clean up even when an assertion above fails.
        os.remove(test_file)
74+
75+
76+
def test_treesitter_chunker_fallback():
    """Test that TreeSitterChunker falls back to StringChunker when no parser is found."""
    chunk_size = 30
    overlap_ratio = 0.2
    tree_sitter_chunker = TreeSitterChunker(
        chunk_size=chunk_size, overlap_ratio=overlap_ratio
    )
    string_chunker = StringChunker(chunk_size=chunk_size, overlap_ratio=overlap_ratio)

    test_content = "This is a test string."

    with tempfile.NamedTemporaryFile(
        mode="w", delete=False, suffix=".xyz"
    ) as tmp_file:  # Use an uncommon extension
        tmp_file.write(test_content)
        test_file = tmp_file.name

    try:
        tree_sitter_chunks = list(tree_sitter_chunker.chunk(test_file))
        string_chunks = list(string_chunker.chunk(test_content))

        assert tree_sitter_chunks == string_chunks
    finally:
        # Clean up even when the comparison above fails.
        os.remove(test_file)

0 commit comments

Comments
 (0)