
Commit 785b5a4

feat: add default truncation in embedder

1 parent 8e27d61 · commit 785b5a4

File tree

6 files changed: +107 −0 lines changed

src/memos/configs/embedder.py
src/memos/embedders/ark.py
src/memos/embedders/base.py
src/memos/embedders/ollama.py
src/memos/embedders/sentence_transformer.py
src/memos/embedders/universal_api.py

src/memos/configs/embedder.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
     embedding_dims: int | None = Field(
         default=None, description="Number of dimensions for the embedding"
     )
+    max_tokens: int | None = Field(
+        default=8192,
+        description="Maximum number of tokens per text. Texts exceeding this limit will be automatically truncated. Set to None to disable truncation.",
+    )
     headers_extra: dict[str, Any] | None = Field(
         default=None,
         description="Extra headers for the embedding model, only for universal_api backend",

src/memos/embedders/ark.py

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
             MultimodalEmbeddingContentPartTextParam,
         )
 
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.config.multi_modal:
             texts_input = [
                 MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts

src/memos/embedders/base.py

Lines changed: 91 additions & 0 deletions
@@ -1,14 +1,105 @@
+import re
+
 from abc import ABC, abstractmethod
 
 from memos.configs.embedder import BaseEmbedderConfig
 
 
+def _count_tokens_for_embedding(text: str) -> int:
+    """
+    Count tokens in text for embedding truncation.
+    Uses tiktoken if available, otherwise falls back to a heuristic.
+
+    Args:
+        text: Text to count tokens for.
+
+    Returns:
+        Number of tokens.
+    """
+    try:
+        import tiktoken
+
+        try:
+            enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        except Exception:
+            enc = tiktoken.get_encoding("cl100k_base")
+        return len(enc.encode(text or ""))
+    except Exception:
+        # Heuristic fallback: zh chars ~1 token each, others ~1 token per ~4 chars
+        if not text:
+            return 0
+        zh_chars = re.findall(r"[\u4e00-\u9fff]", text)
+        zh = len(zh_chars)
+        rest = len(text) - zh
+        return zh + max(1, rest // 4)
+
+
+def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
+    """
+    Truncate text to fit within the max_tokens limit.
+    Uses binary search to find the longest prefix that fits.
+
+    Args:
+        text: Text to truncate.
+        max_tokens: Maximum number of tokens allowed.
+
+    Returns:
+        Truncated text.
+    """
+    if not text or max_tokens is None or max_tokens <= 0:
+        return text
+
+    current_tokens = _count_tokens_for_embedding(text)
+    if current_tokens <= max_tokens:
+        return text
+
+    # Binary search for the longest prefix that stays within the budget
+    low, high = 0, len(text)
+    best_text = ""
+
+    while low < high:
+        mid = (low + high + 1) // 2  # Bias upward to avoid an infinite loop
+        truncated = text[:mid]
+        tokens = _count_tokens_for_embedding(truncated)
+
+        if tokens <= max_tokens:
+            best_text = truncated
+            low = mid
+        else:
+            high = mid - 1
+
+    return best_text if best_text else text[:1]  # Keep at least one character
+
+
 class BaseEmbedder(ABC):
     """Base class for all Embedding models."""
 
     @abstractmethod
     def __init__(self, config: BaseEmbedderConfig):
         """Initialize the embedding model with the given configuration."""
+        self.config = config
+
+    def _truncate_texts(self, texts: list[str], approx_char_per_token: float = 1.1) -> list[str]:
+        """
+        Truncate texts to fit within the max_tokens limit, if configured.
+
+        Args:
+            texts: List of texts to truncate.
+            approx_char_per_token: Rough characters-per-token ratio used as a
+                fast path to skip token counting for clearly short texts.
+
+        Returns:
+            List of truncated texts.
+        """
+        if not hasattr(self, "config") or self.config.max_tokens is None:
+            return texts
+        max_tokens = self.config.max_tokens
+
+        truncated = []
+        for t in texts:
+            if len(t) < max_tokens * approx_char_per_token:
+                truncated.append(t)
+            else:
+                truncated.append(_truncate_text_to_tokens(t, max_tokens))
+        return truncated
 
     @abstractmethod
     def embed(self, texts: list[str]) -> list[list[float]]:
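
To see the new truncation path end to end, here is a minimal sketch, not part of the commit: DummyEmbedder is a hypothetical subclass invented for illustration, and the snippet assumes BaseEmbedderConfig can be built from defaults.

from memos.configs.embedder import BaseEmbedderConfig
from memos.embedders.base import BaseEmbedder


class DummyEmbedder(BaseEmbedder):
    """Hypothetical embedder that only exercises _truncate_texts."""

    def __init__(self, config: BaseEmbedderConfig):
        super().__init__(config)  # the base __init__ stores config on self

    def embed(self, texts: list[str]) -> list[list[float]]:
        texts = self._truncate_texts(texts)      # clip over-long inputs first
        return [[float(len(t))] for t in texts]  # stand-in for a real model


embedder = DummyEmbedder(BaseEmbedderConfig(max_tokens=10))
# The first text is far below the limit and passes through untouched; the
# second is cut to roughly 10 tokens by _truncate_text_to_tokens' binary search.
embedder.embed(["short text", "long text " * 10_000])

One design note: the fast path in _truncate_texts skips token counting whenever len(t) < max_tokens * approx_char_per_token. Since both counters can assign close to one token per character (e.g., for Chinese text), the 1.1 ratio is only a rough screen and can let marginally over-long texts through; it trades exactness for speed on the common short-text case.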

src/memos/embedders/ollama.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         response = self.client.embed(
             model=self.config.model_name_or_path,
             input=texts,

src/memos/embedders/sentence_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -42,5 +42,8 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         embeddings = self.model.encode(texts, convert_to_numpy=True)
         return embeddings.tolist()

src/memos/embedders/universal_api.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
             log_extra_args={"model_name_or_path": "text-embedding-3-large"},
         )
     def embed(self, texts: list[str]) -> list[list[float]]:
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.provider == "openai" or self.provider == "azure":
             try:
                 response = self.client.embeddings.create(
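
Finally, a quick worked check of the heuristic counter that base.py falls back to when tiktoken is unavailable; the values follow directly from its regex and arithmetic (with tiktoken installed, the exact counts would differ).

from memos.embedders.base import _count_tokens_for_embedding

# Assumes the heuristic fallback path (no tiktoken on the system).
_count_tokens_for_embedding("")             # -> 0 (empty text short-circuits)
_count_tokens_for_embedding("hello world")  # -> max(1, 11 // 4) = 2 (non-zh: ~4 chars/token)
_count_tokens_for_embedding("你好世界")      # -> 4 zh chars + max(1, 0 // 4) = 5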
