Skip to content

Commit f997701

Browse files
authored
feat(cli): Allow truncating embeddings to specified dimensions. (#265)
* feat(cli): Truncate embeddings to specified dimensions * fix(cli): Truncate query embeddings to specified dimensions * tests(cli): Add tests for truncated embeddings * docs(cli): Document embedding_dims configuration option * Auto generate docs --------- Co-authored-by: Davidyz <[email protected]>
1 parent e6d3f6b commit f997701

File tree

7 files changed

+82
-21
lines changed

7 files changed

+82
-21
lines changed

doc/VectorCode-cli.txt

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -322,17 +322,23 @@ embedding function takes. For `OllamaEmbeddingFunction`, if you set
322322
"model_name": "nomic-embed-text" }` Then the embedding function object will be
323323
initialised as
324324
`OllamaEmbeddingFunction(url="http://127.0.0.1:11434/api/embeddings",
325-
model_name="nomic-embed-text")`. Default: `{}`; - `db_url`string, the url that
326-
points to the Chromadb server. VectorCode will start an HTTP server for
327-
Chromadb at a randomly picked free port on `localhost` if your configured
328-
`http://host:port` is not accessible. Default: `http://127.0.0.1:8000`; -
329-
`db_path`string, Path to local persistent database. If you didn’t set up a
330-
standalone Chromadb server, this is where the files for your database will be
331-
stored. Default: `~/.local/share/vectorcode/chromadb/`; - `db_log_path`string,
332-
path to the _directory_ where the built-in chromadb server will write the log
333-
to. Default: `~/.local/share/vectorcode/`; - `chunk_size`integer, the maximum
334-
number of characters per chunk. A larger value reduces the number of items in
335-
the database, and hence accelerates the search, but at the cost of potentially
325+
model_name="nomic-embed-text")`. Default: `{}`; - `embedding_dims`integer or
326+
`null`, the number of dimensions to truncate the embeddings to. _Make sure your
327+
model supports Matryoshka Representation Learning (MRL) before using this._
328+
Learn more about MRL here
329+
<https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings>.
330+
When set to `null` (or unset), the embeddings won’t be truncated; -
331+
`db_url`string, the url that points to the Chromadb server. VectorCode will
332+
start an HTTP server for Chromadb at a randomly picked free port on `localhost`
333+
if your configured `http://host:port` is not accessible. Default:
334+
`http://127.0.0.1:8000`; - `db_path`string, Path to local persistent database.
335+
If you didn’t set up a standalone Chromadb server, this is where the files
336+
for your database will be stored. Default:
337+
`~/.local/share/vectorcode/chromadb/`; - `db_log_path`string, path to the
338+
_directory_ where the built-in chromadb server will write the log to. Default:
339+
`~/.local/share/vectorcode/`; - `chunk_size`integer, the maximum number of
340+
characters per chunk. A larger value reduces the number of items in the
341+
database, and hence accelerates the search, but at the cost of potentially
336342
truncated data and lost information. Default: `2500`. To disable chunking, set
337343
it to a negative number; - `overlap_ratio`float between 0 and 1, the ratio of
338344
overlapping/shared content between 2 adjacent chunks. A larger ratio improves

docs/cli.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,10 @@ The JSON configuration file may hold the following values:
275275
Then the embedding function object will be initialised as
276276
`OllamaEmbeddingFunction(url="http://127.0.0.1:11434/api/embeddings",
277277
model_name="nomic-embed-text")`. Default: `{}`;
278+
- `embedding_dims`: integer or `null`, the number of dimensions to truncate the embeddings
279+
to. _Make sure your model supports Matryoshka Representation Learning (MRL)
280+
before using this._ Learn more about MRL [here](https://sbert.net/examples/sentence_transformer/training/matryoshka/README.html#matryoshka-embeddings).
281+
When set to `null` (or unset), the embeddings won't be truncated;
278282
- `db_url`: string, the url that points to the Chromadb server. VectorCode will start an
279283
HTTP server for Chromadb at a randomly picked free port on `localhost` if your
280284
configured `http://host:port` is not accessible. Default: `http://127.0.0.1:8000`;

src/vectorcode/cli_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class Config:
8989
db_url: str = "http://127.0.0.1:8000"
9090
embedding_function: str = "SentenceTransformerEmbeddingFunction" # This should fallback to whatever the default is.
9191
embedding_params: dict[str, Any] = field(default_factory=(lambda: {}))
92+
embedding_dims: Optional[int] = None
9293
n_result: int = 1
9394
force: bool = False
9495
db_path: Optional[str] = "~/.local/share/vectorcode/chromadb/"
@@ -139,6 +140,9 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
139140
"embedding_params": config_dict.get(
140141
"embedding_params", default_config.embedding_params
141142
),
143+
"embedding_dims": config_dict.get(
144+
"embedding_dims", default_config.embedding_dims
145+
),
142146
"db_url": config_dict.get("db_url", default_config.db_url),
143147
"db_path": db_path,
144148
"db_log_path": os.path.expanduser(

src/vectorcode/subcommands/query/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ async def get_query_result_files(
6767
await collection.count(),
6868
)
6969
logger.info(f"Querying {num_query} chunks for reranking.")
70+
query_embeddings = get_embedding_function(configs)(query_chunks)
71+
if isinstance(configs.embedding_dims, int) and configs.embedding_dims > 0:
72+
query_embeddings = [e[: configs.embedding_dims] for e in query_embeddings]
7073
results = await collection.query(
71-
query_embeddings=get_embedding_function(configs)(query_chunks),
74+
query_embeddings=query_embeddings,
7275
n_results=num_query,
7376
include=[
7477
IncludeEnum.metadatas,

src/vectorcode/subcommands/vectorise.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,12 +146,21 @@ async def chunked_add(
146146
async with collection_lock:
147147
for idx in range(0, len(chunks), max_batch_size):
148148
inserted_chunks = chunks[idx : idx + max_batch_size]
149+
embeddings = embedding_function(
150+
list(str(c) for c in inserted_chunks)
151+
)
152+
if (
153+
isinstance(configs.embedding_dims, int)
154+
and configs.embedding_dims > 0
155+
):
156+
logger.debug(
157+
f"Truncating embeddings to {configs.embedding_dims} dimensions."
158+
)
159+
embeddings = [e[: configs.embedding_dims] for e in embeddings]
149160
await collection.add(
150161
ids=[get_uuid() for _ in inserted_chunks],
151162
documents=[str(i) for i in inserted_chunks],
152-
embeddings=embedding_function(
153-
list(str(c) for c in inserted_chunks)
154-
),
163+
embeddings=embeddings,
155164
metadatas=metas,
156165
)
157166
except (UnicodeDecodeError, UnicodeError): # pragma: nocover

tests/subcommands/query/test_query.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,11 @@ async def test_get_query_result_files_chunking(mock_collection, mock_config):
327327
async def test_get_query_result_files_multiple_queries(mock_collection, mock_config):
328328
# Set multiple query terms
329329
mock_config.query = ["term1", "term2", "term3"]
330-
mock_embedding_function = MagicMock()
330+
mock_config.embedding_dims = 10
331+
331332
with (
332333
patch("vectorcode.subcommands.query.StringChunker") as MockChunker,
333334
patch("vectorcode.subcommands.query.reranker.NaiveReranker") as MockReranker,
334-
patch(
335-
"vectorcode.subcommands.query.get_embedding_function",
336-
return_value=mock_embedding_function,
337-
),
338335
):
339336
# Set up MockChunker to return the query terms as is
340337
mock_chunker_instance = MagicMock()
@@ -354,7 +351,7 @@ async def test_get_query_result_files_multiple_queries(mock_collection, mock_con
354351
# Check query was called with all query terms
355352
mock_collection.query.assert_called_once()
356353
_, kwargs = mock_collection.query.call_args
357-
mock_embedding_function.assert_called_once_with(["term1", "term2", "term3"])
354+
assert all(len(i) == 10 for i in kwargs["query_embeddings"])
358355

359356
# Check the result
360357
assert result == ["file1.py", "file2.py"]

tests/subcommands/test_vectorise.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,44 @@ async def test_chunked_add():
103103
assert collection.add.call_count == 1
104104

105105

106+
@pytest.mark.asyncio
107+
async def test_chunked_add_truncated():
108+
file_path = "test_file.py"
109+
collection = AsyncMock()
110+
collection_lock = asyncio.Lock()
111+
stats = VectoriseStats()
112+
stats_lock = asyncio.Lock()
113+
configs = Config(
114+
chunk_size=100, overlap_ratio=0.2, project_root=".", embedding_dims=10
115+
)
116+
max_batch_size = 50
117+
semaphore = asyncio.Semaphore(1)
118+
119+
with (
120+
patch("vectorcode.chunking.TreeSitterChunker.chunk") as mock_chunk,
121+
patch("vectorcode.subcommands.vectorise.hash_file") as mock_hash_file,
122+
):
123+
mock_hash_file.return_value = "hash1"
124+
mock_chunk.return_value = [Chunk("chunk1", Point(1, 0), Point(1, 5)), "chunk2"]
125+
await chunked_add(
126+
file_path,
127+
collection,
128+
collection_lock,
129+
stats,
130+
stats_lock,
131+
configs,
132+
max_batch_size,
133+
semaphore,
134+
)
135+
136+
assert stats.add == 1
137+
assert stats.update == 0
138+
collection.add.assert_called()
139+
assert collection.add.call_count == 1
140+
141+
assert all(len(i) == 10 for i in collection.add.call_args.kwargs["embeddings"])
142+
143+
106144
@pytest.mark.asyncio
107145
async def test_chunked_add_with_existing():
108146
file_path = "test_file.py"

0 commit comments

Comments
 (0)