2 changes: 1 addition & 1 deletion flowsettings.py
@@ -185,7 +185,7 @@
KH_RERANKINGS["voyageai"] = {
    "spec": {
        "__type__": "kotaemon.rerankings.VoyageAIReranking",
-        "model_name": "rerank-2",
        "model_name": "rerank-2.5",
        "api_key": VOYAGE_API_KEY,
    },
    "default": False,
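The `__type__` key names the class kotaemon should instantiate for this reranker. A minimal, hypothetical sketch of that resolution step (illustrative only; `resolve_spec` is not the project's actual loader):

import importlib

def resolve_spec(spec: dict):
    # Illustrative: import the class named by __type__ and instantiate it
    # with the remaining keys as constructor arguments.
    module_path, _, class_name = spec["__type__"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**{k: v for k, v in spec.items() if k != "__type__"})

# e.g. reranker = resolve_spec(KH_RERANKINGS["voyageai"]["spec"])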
202 changes: 197 additions & 5 deletions libs/kotaemon/kotaemon/embeddings/voyageai.py
@@ -2,13 +2,42 @@
"""

import importlib
from typing import Generator, Literal, Optional

from kotaemon.base import Document, DocumentWithEmbedding, Param

from .base import BaseEmbeddings

vo = None

# Token limits per batch for each VoyageAI model
# See: https://docs.voyageai.com/docs/embeddings
VOYAGE_TOKEN_LIMITS = {
    # voyage-4 family
    "voyage-4": 320_000,
    "voyage-4-lite": 1_000_000,
    "voyage-4-large": 120_000,
    # voyage-3 family
    "voyage-3": 120_000,
    "voyage-3-lite": 120_000,
    "voyage-3-large": 120_000,
    "voyage-3.5": 320_000,
    "voyage-3.5-lite": 1_000_000,
    # Specialized models
    "voyage-code-3": 120_000,
    "voyage-finance-2": 120_000,
    "voyage-law-2": 120_000,
    "voyage-multilingual-2": 120_000,
    "voyage-large-2": 120_000,
    "voyage-large-2-instruct": 120_000,
    "voyage-code-2": 120_000,
    # Context models (use contextualized_embed API)
    "voyage-context-3": 32_000,
}

# Default token limit for unknown models
DEFAULT_TOKEN_LIMIT = 120_000
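# For example, an unrecognized model name falls back to the default budget
# (the model name below is hypothetical):
#   VOYAGE_TOKEN_LIMITS.get("voyage-99", DEFAULT_TOKEN_LIMIT)  # -> 120_000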


def _import_voyageai():
global vo
@@ -30,7 +59,10 @@ def _format_output(texts: list[str], embeddings: list[list]):


class VoyageAIEmbeddings(BaseEmbeddings):
-    """Voyage AI provides best-in-class embedding models and rerankers."""
    """Voyage AI provides best-in-class embedding models and rerankers.

    Supports token-aware batching to keep API calls within model token limits.
    """

    api_key: str = Param(None, help="Voyage API key", required=False)
    model: str = Param(
@@ -42,6 +74,24 @@ class VoyageAIEmbeddings(BaseEmbeddings):
        ),
        required=True,
    )
    batch_size: int = Param(
        128,
        help=(
            "Maximum number of texts per batch. "
            "Will be further limited by token count."
        ),
    )
    truncation: bool = Param(
        True,
        help="Whether to truncate texts that exceed the model's max token limit.",
    )
    output_dimension: Optional[Literal[256, 512, 1024, 2048]] = Param(
        None,
        help=(
            "Output embedding dimension. Only supported by voyage-4 family models. "
            "If None, uses the model's default (1024 for voyage-4 models)."
        ),
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
@@ -51,16 +101,158 @@ def __init__(self, *args, **kwargs):
        self._client = _import_voyageai().Client(api_key=self.api_key)
        self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)

    def _get_token_limit(self) -> int:
        """Get the token limit for the current model."""
        return VOYAGE_TOKEN_LIMITS.get(self.model, DEFAULT_TOKEN_LIMIT)

    def _is_context_model(self) -> bool:
        """Check if the model is a contextualized embedding model."""
        return "context" in self.model

    def _build_batches(
        self, texts: list[str]
    ) -> Generator[tuple[list[str], list[int]], None, None]:
        """Generate batches of texts respecting token limits.

        Yields:
            Tuple of (batch_texts, original_indices) for each batch.
        """
        max_tokens = self._get_token_limit()
        index = 0

        while index < len(texts):
            batch: list[str] = []
            batch_indices: list[int] = []
            batch_tokens = 0

            while index < len(texts) and len(batch) < self.batch_size:
                # Tokenize the current text to get its token count
                token_count = len(
                    self._client.tokenize([texts[index]], model=self.model)[0]
                )

                # Check if adding this text would exceed the token limit
                if batch_tokens + token_count > max_tokens and len(batch) > 0:
                    # Yield current batch and start a new one
                    break

                batch_tokens += token_count
                batch.append(texts[index])
                batch_indices.append(index)
                index += 1

            if batch:
                yield batch, batch_indices
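    # Illustrative use of the batching generator (hypothetical inputs; a real
    # call needs a valid API key, since token counts come from
    # self._client.tokenize):
    #
    #   emb = VoyageAIEmbeddings(api_key="...", model="voyage-3.5")
    #   for batch, indices in emb._build_batches(["a", "b", "c"]):
    #       ...  # each batch stays under both batch_size and the token budget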

    def _embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed a single batch of texts."""
        if self._is_context_model():
            return self._embed_context_batch(texts)
        return self._embed_regular_batch(texts)

    def _embed_regular_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed using regular embedding API."""
        kwargs = {
            "model": self.model,
            "truncation": self.truncation,
        }
        if self.output_dimension is not None:
            kwargs["output_dimension"] = self.output_dimension

        return self._client.embed(texts, **kwargs).embeddings

    def _embed_context_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed using contextualized embedding API (for voyage-context-3)."""
        if self.output_dimension is not None:
            result = self._client.contextualized_embed(
                inputs=[texts],
                model=self.model,
                output_dimension=self.output_dimension,
            )
        else:
            result = self._client.contextualized_embed(
                inputs=[texts],
                model=self.model,
            )
        return result.results[0].embeddings
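    # Note: contextualized_embed treats each inner list of `inputs` as chunks
    # of a single document, so wrapping the batch as [texts] embeds all texts
    # with shared context; results[0] then holds one embedding per input text.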

    async def _aembed_batch(self, texts: list[str]) -> list[list[float]]:
        """Async embed a single batch of texts."""
        if self._is_context_model():
            return await self._aembed_context_batch(texts)
        return await self._aembed_regular_batch(texts)

    async def _aembed_regular_batch(self, texts: list[str]) -> list[list[float]]:
        """Async embed using regular embedding API."""
        kwargs = {
            "model": self.model,
            "truncation": self.truncation,
        }
        if self.output_dimension is not None:
            kwargs["output_dimension"] = self.output_dimension

        result = await self._aclient.embed(texts, **kwargs)
        return result.embeddings

    async def _aembed_context_batch(self, texts: list[str]) -> list[list[float]]:
        """Async embed using contextualized embedding API."""
        if self.output_dimension is not None:
            result = await self._aclient.contextualized_embed(
                inputs=[texts],
                model=self.model,
                output_dimension=self.output_dimension,
            )
        else:
            result = await self._aclient.contextualized_embed(
                inputs=[texts],
                model=self.model,
            )
        return result.results[0].embeddings

    def invoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        texts = [t.content for t in self.prepare_input(text)]
-        embeddings = self._client.embed(texts, model=self.model).embeddings
-        return _format_output(texts, embeddings)

        # For small inputs, skip batching overhead
        if len(texts) <= self.batch_size:
            token_count = sum(
                len(tokens)
                for tokens in self._client.tokenize(texts, model=self.model)
            )
            if token_count <= self._get_token_limit():
                embeddings = self._embed_batch(texts)
                return _format_output(texts, embeddings)

        # Use token-aware batching for larger inputs
        all_embeddings: list[list[float]] = [[] for _ in range(len(texts))]

        for batch_texts, batch_indices in self._build_batches(texts):
            batch_embeddings = self._embed_batch(batch_texts)
            for idx, embedding in zip(batch_indices, batch_embeddings):
                all_embeddings[idx] = embedding

        return _format_output(texts, all_embeddings)

    async def ainvoke(
        self, text: str | list[str] | Document | list[Document], *args, **kwargs
    ) -> list[DocumentWithEmbedding]:
        texts = [t.content for t in self.prepare_input(text)]
-        embeddings = await self._aclient.embed(texts, model=self.model).embeddings
-        return _format_output(texts, embeddings)

        # For small inputs, skip batching overhead
        if len(texts) <= self.batch_size:
            token_count = sum(
                len(tokens)
                for tokens in self._client.tokenize(texts, model=self.model)
            )
            if token_count <= self._get_token_limit():
                embeddings = await self._aembed_batch(texts)
                return _format_output(texts, embeddings)

        # Use token-aware batching for larger inputs
        all_embeddings: list[list[float]] = [[] for _ in range(len(texts))]

        for batch_texts, batch_indices in self._build_batches(texts):
            batch_embeddings = await self._aembed_batch(batch_texts)
            for idx, embedding in zip(batch_indices, batch_embeddings):
                all_embeddings[idx] = embedding

        return _format_output(texts, all_embeddings)
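A minimal usage sketch for the updated embeddings class (assumes a valid Voyage API key; the model name and sample texts are illustrative):

from kotaemon.embeddings.voyageai import VoyageAIEmbeddings

embedder = VoyageAIEmbeddings(
    api_key="pa-...",  # assumption: your Voyage API key
    model="voyage-3.5",
    batch_size=64,
)
docs = embedder.invoke(["first passage", "second passage"])
# Each result is a DocumentWithEmbedding; inputs larger than the model's
# token budget are split into token-aware batches transparently.
print(len(docs))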
68 changes: 60 additions & 8 deletions libs/kotaemon/kotaemon/rerankings/voyageai.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import importlib
from typing import Optional

from decouple import config

@@ -19,13 +20,21 @@ def _import_voyageai():


class VoyageAIReranking(BaseReranking):
-    """VoyageAI Reranking model"""
    """VoyageAI Reranking model.

    Supports all VoyageAI reranker models, including:
    - rerank-2.5: latest flagship model with instruction-following (recommended)
    - rerank-2.5-lite: cost-effective version with instruction-following
    - rerank-2: previous-generation model
    - rerank-2-lite: previous-generation lite model
    """

    model_name: str = Param(
-        "rerank-2",
        "rerank-2.5",
        help=(
-            "ID of the model to use. You can go to [Supported Models]"
-            "(https://docs.voyageai.com/docs/reranker) to see the supported models"
            "ID of the model to use. Recommended: rerank-2.5 (best quality) or "
            "rerank-2.5-lite (cost-effective). See [Supported Models]"
            "(https://docs.voyageai.com/docs/reranker) for all options."
        ),
        required=True,
    )
@@ -34,11 +43,19 @@ class VoyageAIReranking(BaseReranking):
help="VoyageAI API key",
required=True,
)
top_k: Optional[int] = Param(
None,
help="Number of top documents to return. If None, returns all documents.",
)
truncation: bool = Param(
True,
help="Whether to truncate documents that exceed the model's context length.",
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if not self.api_key:
raise ValueError("API key must be provided for VoyageAIEmbeddings.")
raise ValueError("API key must be provided for VoyageAIReranking.")

self._client = _import_voyageai().Client(api_key=self.api_key)
self._aclient = _import_voyageai().AsyncClient(api_key=self.api_key)
@@ -52,9 +69,44 @@ def run(self, documents: list[Document], query: str) -> list[Document]:
            return compressed_docs

        _docs = [d.content for d in documents]
-        response = self._client.rerank(
-            model=self.model_name, query=query, documents=_docs
-        )

        # Build rerank kwargs
        rerank_kwargs = {
            "model": self.model_name,
            "query": query,
            "documents": _docs,
            "truncation": self.truncation,
        }
        if self.top_k is not None:
            rerank_kwargs["top_k"] = self.top_k

        response = self._client.rerank(**rerank_kwargs)
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
            compressed_docs.append(doc)

        return compressed_docs

    async def arun(self, documents: list[Document], query: str) -> list[Document]:
        """Async version of reranking."""
        compressed_docs: list[Document] = []

        if not documents:
            return compressed_docs

        _docs = [d.content for d in documents]

        rerank_kwargs = {
            "model": self.model_name,
            "query": query,
            "documents": _docs,
            "truncation": self.truncation,
        }
        if self.top_k is not None:
            rerank_kwargs["top_k"] = self.top_k

        response = await self._aclient.rerank(**rerank_kwargs)
        for r in response.results:
            doc = documents[r.index]
            doc.metadata["reranking_score"] = r.relevance_score
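And a matching sketch for the reranker (hypothetical query and documents; assumes a valid API key and that `Document` accepts `content=`, as used throughout the diff):

from kotaemon.base import Document
from kotaemon.rerankings import VoyageAIReranking

reranker = VoyageAIReranking(api_key="pa-...", model_name="rerank-2.5", top_k=2)
ranked = reranker.run(
    documents=[Document(content=t) for t in ["intro", "methods", "results"]],
    query="Which section describes the experimental setup?",
)
for doc in ranked:
    print(doc.metadata["reranking_score"], doc.content)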