Skip to content

Commit dca23e2

Browse files
authored
Merge pull request #2523 from danielaskdd/embedding-max-token
feat: Add Automatic Text Truncation Support for Embedding Functions
2 parents b31b910 + e2a95ab commit dca23e2

File tree

8 files changed

+152
-27
lines changed

8 files changed

+152
-27
lines changed

docs/OfflineDeployment.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ LightRAG provides flexible dependency groups for different use cases:
7575
### Available Dependency Groups
7676

7777
| Group | Description | Use Case |
78-
|-------|-------------|----------|
78+
| ----- | ----------- | -------- |
7979
| `api` | API server + document processing | FastAPI server with PDF, DOCX, PPTX, XLSX support |
8080
| `offline-storage` | Storage backends | Redis, Neo4j, MongoDB, PostgreSQL, etc. |
8181
| `offline-llm` | LLM providers | OpenAI, Anthropic, Ollama, etc. |
@@ -120,7 +120,7 @@ Tiktoken downloads BPE encoding models on first use. In offline environments, yo
120120
After installing LightRAG, use the built-in command:
121121

122122
```bash
123-
# Download to default location (~/.tiktoken_cache)
123+
# Download to default location (see output for exact path)
124124
lightrag-download-cache
125125

126126
# Download to specific directory

lightrag/api/lightrag_server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -966,7 +966,9 @@ async def bedrock_model_complete(
966966
f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
967967
)
968968
else:
969-
logger.info("Embedding max_token_size: not set (90% token warning disabled)")
969+
logger.info(
970+
"Embedding max_token_size: None (Embedding token limit is disabled)."
971+
)
970972

971973
# Configure rerank function based on args.rerank_binding parameter
972974
rerank_model_func = None

lightrag/llm/gemini.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ async def gemini_embed(
476476
base_url: str | None = None,
477477
api_key: str | None = None,
478478
embedding_dim: int | None = None,
479+
max_token_size: int | None = None,
479480
task_type: str = "RETRIEVAL_DOCUMENT",
480481
timeout: int | None = None,
481482
token_tracker: Any | None = None,
@@ -497,6 +498,11 @@ async def gemini_embed(
497498
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator
498499
or the EMBEDDING_DIM environment variable.
499500
Supported range: 128-3072. Recommended values: 768, 1536, 3072.
501+
max_token_size: Maximum tokens per text. This parameter is automatically
502+
injected by the EmbeddingFunc wrapper when the underlying function
503+
signature supports it (via inspect.signature check). Gemini API will
504+
automatically truncate texts exceeding this limit (autoTruncate=True
505+
by default), so no client-side truncation is needed.
500506
task_type: Task type for embedding optimization. Default is "RETRIEVAL_DOCUMENT".
501507
Supported types: SEMANTIC_SIMILARITY, CLASSIFICATION, CLUSTERING,
502508
RETRIEVAL_DOCUMENT, RETRIEVAL_QUERY, CODE_RETRIEVAL_QUERY,
@@ -516,7 +522,11 @@ async def gemini_embed(
516522
- For dimension 3072: Embeddings are already normalized by the API
517523
- For dimensions < 3072: Embeddings are L2-normalized after retrieval
518524
- Normalization ensures accurate semantic similarity via cosine distance
525+
- Gemini API automatically truncates texts exceeding max_token_size (autoTruncate=True)
519526
"""
527+
# Note: max_token_size is received but not used for client-side truncation.
528+
# Gemini API handles truncation automatically with autoTruncate=True (default).
529+
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
520530
loop = asyncio.get_running_loop()
521531

522532
key = _ensure_api_key(api_key)

lightrag/llm/ollama.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,33 @@ async def ollama_model_complete(
176176
embedding_dim=1024, max_token_size=8192, model_name="bge-m3:latest"
177177
)
178178
async def ollama_embed(
179-
texts: list[str], embed_model: str = "bge-m3:latest", **kwargs
179+
texts: list[str],
180+
embed_model: str = "bge-m3:latest",
181+
max_token_size: int | None = None,
182+
**kwargs,
180183
) -> np.ndarray:
184+
"""Generate embeddings using Ollama's API.
185+
186+
Args:
187+
texts: List of texts to embed.
188+
embed_model: The Ollama embedding model to use. Default is "bge-m3:latest".
189+
max_token_size: Maximum tokens per text. This parameter is automatically
190+
injected by the EmbeddingFunc wrapper when the underlying function
191+
signature supports it (via inspect.signature check). Ollama will
192+
automatically truncate texts exceeding the model's context length
193+
(num_ctx), so no client-side truncation is needed.
194+
**kwargs: Additional arguments passed to the Ollama client.
195+
196+
Returns:
197+
A numpy array of embeddings, one per input text.
198+
199+
Note:
200+
- Ollama API automatically truncates texts exceeding the model's context length
201+
- The max_token_size parameter is received but not used for client-side truncation
202+
"""
203+
# Note: max_token_size is received but not used for client-side truncation.
204+
# Ollama API handles truncation automatically based on the model's num_ctx setting.
205+
_ = max_token_size # Acknowledge parameter to avoid unused variable warning
181206
api_key = kwargs.pop("api_key", None)
182207
if not api_key:
183208
api_key = os.getenv("OLLAMA_API_KEY")

lightrag/llm/openai.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import AsyncIterator
66

77
import pipmaster as pm
8+
import tiktoken
89

910
# install specific modules
1011
if not pm.is_installed("openai"):
@@ -74,6 +75,30 @@ class InvalidResponseError(Exception):
7475
pass
7576

7677

78+
# Module-level cache for tiktoken encodings
79+
_TIKTOKEN_ENCODING_CACHE: dict[str, Any] = {}
80+
81+
82+
def _get_tiktoken_encoding_for_model(model: str) -> Any:
83+
"""Get tiktoken encoding for the specified model with caching.
84+
85+
Args:
86+
model: The model name to get encoding for.
87+
88+
Returns:
89+
The tiktoken encoding for the model.
90+
"""
91+
if model not in _TIKTOKEN_ENCODING_CACHE:
92+
try:
93+
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.encoding_for_model(model)
94+
except KeyError:
95+
logger.debug(
96+
f"Encoding for model '{model}' not found, falling back to cl100k_base"
97+
)
98+
_TIKTOKEN_ENCODING_CACHE[model] = tiktoken.get_encoding("cl100k_base")
99+
return _TIKTOKEN_ENCODING_CACHE[model]
100+
101+
77102
def create_openai_async_client(
78103
api_key: str | None = None,
79104
base_url: str | None = None,
@@ -695,15 +720,17 @@ async def openai_embed(
695720
base_url: str | None = None,
696721
api_key: str | None = None,
697722
embedding_dim: int | None = None,
723+
max_token_size: int | None = None,
698724
client_configs: dict[str, Any] | None = None,
699725
token_tracker: Any | None = None,
700726
use_azure: bool = False,
701727
azure_deployment: str | None = None,
702728
api_version: str | None = None,
703729
) -> np.ndarray:
704-
"""Generate embeddings for a list of texts using OpenAI's API.
730+
"""Generate embeddings for a list of texts using OpenAI's API with automatic text truncation.
705731
706-
This function supports both standard OpenAI and Azure OpenAI services.
732+
This function supports both standard OpenAI and Azure OpenAI services. It automatically
733+
truncates texts that exceed the model's token limit to prevent API errors.
707734
708735
Args:
709736
texts: List of texts to embed.
@@ -719,6 +746,10 @@ async def openai_embed(
719746
The dimension is controlled by the @wrap_embedding_func_with_attrs decorator.
720747
Manually passing a different value will trigger a warning and be ignored.
721748
When provided (by EmbeddingFunc), it will be passed to the OpenAI API for dimension reduction.
749+
max_token_size: Maximum tokens per text. Texts exceeding this limit will be truncated.
750+
**IMPORTANT**: This parameter is automatically injected by the EmbeddingFunc wrapper
751+
when the underlying function signature supports it (via inspect.signature check).
752+
The value is controlled by the @wrap_embedding_func_with_attrs decorator.
722753
client_configs: Additional configuration options for the AsyncOpenAI/AsyncAzureOpenAI client.
723754
These will override any default configurations but will be overridden by
724755
explicit parameters (api_key, base_url). Supports proxy configuration,
@@ -740,6 +771,35 @@ async def openai_embed(
740771
RateLimitError: If the OpenAI API rate limit is exceeded.
741772
APITimeoutError: If the OpenAI API request times out.
742773
"""
774+
# Apply text truncation if max_token_size is provided
775+
if max_token_size is not None and max_token_size > 0:
776+
encoding = _get_tiktoken_encoding_for_model(model)
777+
truncated_texts = []
778+
truncation_count = 0
779+
780+
for text in texts:
781+
if not text:
782+
truncated_texts.append(text)
783+
continue
784+
785+
tokens = encoding.encode(text)
786+
if len(tokens) > max_token_size:
787+
truncated_tokens = tokens[:max_token_size]
788+
truncated_texts.append(encoding.decode(truncated_tokens))
789+
truncation_count += 1
790+
logger.debug(
791+
f"Text truncated from {len(tokens)} to {max_token_size} tokens"
792+
)
793+
else:
794+
truncated_texts.append(text)
795+
796+
if truncation_count > 0:
797+
logger.info(
798+
f"Truncated {truncation_count}/{len(texts)} texts to fit token limit ({max_token_size})"
799+
)
800+
801+
texts = truncated_texts
802+
743803
# Create the OpenAI client (supports both OpenAI and Azure)
744804
openai_async_client = create_openai_async_client(
745805
api_key=api_key,

lightrag/operate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,12 @@ async def _summarize_descriptions(
365365
if embedding_token_limit is not None and summary:
366366
tokenizer = global_config["tokenizer"]
367367
summary_token_count = len(tokenizer.encode(summary))
368-
threshold = int(embedding_token_limit * 0.9)
368+
threshold = int(embedding_token_limit)
369369

370370
if summary_token_count > threshold:
371371
logger.warning(
372-
f"Summary tokens ({summary_token_count}) exceeds 90% of embedding limit "
373-
f"({embedding_token_limit}) for {description_type}: {description_name}"
372+
f"Summary tokens ({summary_token_count}) exceed embedding_token_limit ({embedding_token_limit}) "
373+
f"for {description_type}: {description_name}"
374374
)
375375

376376
return summary

lightrag/tools/download_cache.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,35 +10,51 @@
1010
from pathlib import Path
1111

1212

13+
# Known tiktoken encoding names (not model names)
14+
# These need to be loaded with tiktoken.get_encoding() instead of tiktoken.encoding_for_model()
15+
TIKTOKEN_ENCODING_NAMES = {"cl100k_base", "p50k_base", "r50k_base", "o200k_base"}
16+
17+
1318
def download_tiktoken_cache(cache_dir: str = None, models: list = None):
1419
"""Download tiktoken models to local cache
1520
1621
Args:
17-
cache_dir: Directory to store the cache files. If None, uses default location.
18-
models: List of model names to download. If None, downloads common models.
22+
cache_dir: Directory to store the cache files. If None, uses tiktoken's default location.
23+
models: List of model names or encoding names to download. If None, downloads common ones.
1924
2025
Returns:
21-
Tuple of (success_count, failed_models)
26+
Tuple of (success_count, failed_models, actual_cache_dir)
2227
"""
23-
try:
24-
import tiktoken
25-
except ImportError:
26-
print("Error: tiktoken is not installed.")
27-
print("Install with: pip install tiktoken")
28-
sys.exit(1)
28+
# If user specified a cache directory, set it BEFORE importing tiktoken
29+
# tiktoken reads TIKTOKEN_CACHE_DIR at import time
30+
user_specified_cache = cache_dir is not None
2931

30-
# Set cache directory if provided
31-
if cache_dir:
32+
if user_specified_cache:
3233
cache_dir = os.path.abspath(cache_dir)
3334
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir
3435
cache_path = Path(cache_dir)
3536
cache_path.mkdir(parents=True, exist_ok=True)
36-
print(f"Using cache directory: {cache_dir}")
37+
print(f"Using specified cache directory: {cache_dir}")
3738
else:
38-
cache_dir = os.environ.get(
39-
"TIKTOKEN_CACHE_DIR", str(Path.home() / ".tiktoken_cache")
40-
)
41-
print(f"Using default cache directory: {cache_dir}")
39+
# Check if TIKTOKEN_CACHE_DIR is already set in environment
40+
env_cache_dir = os.environ.get("TIKTOKEN_CACHE_DIR")
41+
if env_cache_dir:
42+
cache_dir = env_cache_dir
43+
print(f"Using TIKTOKEN_CACHE_DIR from environment: {cache_dir}")
44+
else:
45+
# Use tiktoken's default location (tempdir/data-gym-cache)
46+
import tempfile
47+
48+
cache_dir = os.path.join(tempfile.gettempdir(), "data-gym-cache")
49+
print(f"Using tiktoken default cache directory: {cache_dir}")
50+
51+
# Now import tiktoken (it will use the cache directory we determined)
52+
try:
53+
import tiktoken
54+
except ImportError:
55+
print("Error: tiktoken is not installed.")
56+
print("Install with: pip install tiktoken")
57+
sys.exit(1)
4258

4359
# Common models used by LightRAG and OpenAI
4460
if models is None:
@@ -50,6 +66,7 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
5066
"text-embedding-ada-002", # Legacy embedding model
5167
"text-embedding-3-small", # Small embedding model
5268
"text-embedding-3-large", # Large embedding model
69+
"cl100k_base", # Default encoding for LightRAG
5370
]
5471

5572
print(f"\nDownloading {len(models)} tiktoken models...")
@@ -61,13 +78,17 @@ def download_tiktoken_cache(cache_dir: str = None, models: list = None):
6178
for i, model in enumerate(models, 1):
6279
try:
6380
print(f"[{i}/{len(models)}] Downloading {model}...", end=" ", flush=True)
64-
encoding = tiktoken.encoding_for_model(model)
81+
# Use get_encoding for encoding names, encoding_for_model for model names
82+
if model in TIKTOKEN_ENCODING_NAMES:
83+
encoding = tiktoken.get_encoding(model)
84+
else:
85+
encoding = tiktoken.encoding_for_model(model)
6586
# Trigger download by encoding a test string
6687
encoding.encode("test")
6788
print("✓ Done")
6889
success_count += 1
6990
except KeyError as e:
70-
print(f"✗ Failed: Unknown model '{model}'")
91+
print(f"✗ Failed: Unknown model or encoding '{model}'")
7192
failed_models.append((model, str(e)))
7293
except Exception as e:
7394
print(f"✗ Failed: {e}")

lightrag/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import html
88
import csv
9+
import inspect
910
import json
1011
import logging
1112
import logging.handlers
@@ -492,6 +493,12 @@ async def __call__(self, *args, **kwargs) -> np.ndarray:
492493
# Inject embedding_dim from decorator
493494
kwargs["embedding_dim"] = self.embedding_dim
494495

496+
# Check if underlying function supports max_token_size and inject if not provided
497+
if self.max_token_size is not None and "max_token_size" not in kwargs:
498+
sig = inspect.signature(self.func)
499+
if "max_token_size" in sig.parameters:
500+
kwargs["max_token_size"] = self.max_token_size
501+
495502
# Call the actual embedding function
496503
result = await self.func(*args, **kwargs)
497504

0 commit comments

Comments
 (0)