
Commit 9febb1d

feat: split chunk for pure string (#593)
* feat: split chunk for pure string
* feat: add default truncation in embedder
* feat: chunking each item after fast mode
* fix: test
1 parent: 52dfe47

File tree: 12 files changed, +253 -68 lines

src/memos/configs/embedder.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
     embedding_dims: int | None = Field(
         default=None, description="Number of dimensions for the embedding"
     )
+    max_tokens: int | None = Field(
+        default=8192,
+        description="Maximum number of tokens per text. Texts exceeding this limit will be automatically truncated. Set to None to disable truncation.",
+    )
     headers_extra: dict[str, Any] | None = Field(
         default=None,
         description="Extra headers for the embedding model, only for universal_api backend",

src/memos/embedders/ark.py

Lines changed: 3 additions & 0 deletions
@@ -49,6 +49,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
             MultimodalEmbeddingContentPartTextParam,
         )

+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.config.multi_modal:
             texts_input = [
                 MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
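The same three-line guard is repeated in each backend below (ollama, sentence_transformer, universal_api). A hedged sketch of the pattern in a hypothetical custom backend; the class name and the placeholder vectors are illustrative only, not part of the commit.

from memos.embedders.base import BaseEmbedder


class DummyEmbedder(BaseEmbedder):  # hypothetical backend, for illustration
    def __init__(self, config):
        self.config = config  # BaseEmbedder.__init__ now also stores the config

    def embed(self, texts: list[str]) -> list[list[float]]:
        # Truncate texts if max_tokens is configured (helper inherited from BaseEmbedder)
        texts = self._truncate_texts(texts)
        return [[0.0, 0.0, 0.0] for _ in texts]  # placeholder vectors, no real model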

src/memos/embedders/base.py

Lines changed: 91 additions & 0 deletions
@@ -1,14 +1,105 @@
+import re
+
 from abc import ABC, abstractmethod

 from memos.configs.embedder import BaseEmbedderConfig


+def _count_tokens_for_embedding(text: str) -> int:
+    """
+    Count tokens in text for embedding truncation.
+    Uses tiktoken if available, otherwise falls back to heuristic.
+
+    Args:
+        text: Text to count tokens for.
+
+    Returns:
+        Number of tokens.
+    """
+    try:
+        import tiktoken
+
+        try:
+            enc = tiktoken.encoding_for_model("gpt-4o-mini")
+        except Exception:
+            enc = tiktoken.get_encoding("cl100k_base")
+        return len(enc.encode(text or ""))
+    except Exception:
+        # Heuristic fallback: zh chars ~1 token, others ~1 token per ~4 chars
+        if not text:
+            return 0
+        zh_chars = re.findall(r"[\u4e00-\u9fff]", text)
+        zh = len(zh_chars)
+        rest = len(text) - zh
+        return zh + max(1, rest // 4)
+
+
+def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
+    """
+    Truncate text to fit within max_tokens limit.
+    Uses binary search to find the optimal truncation point.
+
+    Args:
+        text: Text to truncate.
+        max_tokens: Maximum number of tokens allowed.
+
+    Returns:
+        Truncated text.
+    """
+    if not text or max_tokens is None or max_tokens <= 0:
+        return text
+
+    current_tokens = _count_tokens_for_embedding(text)
+    if current_tokens <= max_tokens:
+        return text
+
+    # Binary search for the right truncation point
+    low, high = 0, len(text)
+    best_text = ""
+
+    while low < high:
+        mid = (low + high + 1) // 2  # Use +1 to avoid infinite loop
+        truncated = text[:mid]
+        tokens = _count_tokens_for_embedding(truncated)
+
+        if tokens <= max_tokens:
+            best_text = truncated
+            low = mid
+        else:
+            high = mid - 1
+
+    return best_text if best_text else text[:1]  # Fallback to at least one character
+
+
 class BaseEmbedder(ABC):
     """Base class for all Embedding models."""

     @abstractmethod
     def __init__(self, config: BaseEmbedderConfig):
         """Initialize the embedding model with the given configuration."""
+        self.config = config
+
+    def _truncate_texts(self, texts: list[str], approx_char_per_token=1.1) -> list[str]:
+        """
+        Truncate texts to fit within max_tokens limit if configured.
+
+        Args:
+            texts: List of texts to truncate.
+
+        Returns:
+            List of truncated texts.
+        """
+        if not hasattr(self, "config") or self.config.max_tokens is None:
+            return texts
+        max_tokens = self.config.max_tokens
+
+        truncated = []
+        for t in texts:
+            if len(t) < max_tokens * approx_char_per_token:
+                truncated.append(t)
+            else:
+                truncated.append(_truncate_text_to_tokens(t, max_tokens))
+        return truncated

     @abstractmethod
     def embed(self, texts: list[str]) -> list[list[float]]:
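A hedged usage sketch of the two module-level helpers added above; the exact token counts depend on whether tiktoken is installed, but the truncated result always fits the budget by construction.

from memos.embedders.base import _count_tokens_for_embedding, _truncate_text_to_tokens

text = "memory " * 10_000                        # deliberately oversized input
print(_count_tokens_for_embedding(text))         # well above 8192 with either counter
short = _truncate_text_to_tokens(text, 8192)
assert _count_tokens_for_embedding(short) <= 8192
print(f"kept {len(short)} of {len(text)} characters")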

src/memos/embedders/ollama.py

Lines changed: 3 additions & 0 deletions
@@ -67,6 +67,9 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         response = self.client.embed(
             model=self.config.model_name_or_path,
             input=texts,

src/memos/embedders/sentence_transformer.py

Lines changed: 3 additions & 0 deletions
@@ -42,5 +42,8 @@ def embed(self, texts: list[str]) -> list[list[float]]:
         Returns:
             List of embeddings, each represented as a list of floats.
         """
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         embeddings = self.model.encode(texts, convert_to_numpy=True)
         return embeddings.tolist()

src/memos/embedders/universal_api.py

Lines changed: 3 additions & 0 deletions
@@ -36,6 +36,9 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
             log_extra_args={"model_name_or_path": "text-embedding-3-large"},
         )
     def embed(self, texts: list[str]) -> list[list[float]]:
+        # Truncate texts if max_tokens is configured
+        texts = self._truncate_texts(texts)
+
         if self.provider == "openai" or self.provider == "azure":
             try:
                 response = self.client.embeddings.create(

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 80 additions & 14 deletions
@@ -47,6 +47,61 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
             direct_markdown_hostnames=direct_markdown_hostnames,
         )

+    def _split_large_memory_item(
+        self, item: TextualMemoryItem, max_tokens: int
+    ) -> list[TextualMemoryItem]:
+        """
+        Split a single memory item that exceeds max_tokens into multiple chunks.
+
+        Args:
+            item: TextualMemoryItem to split
+            max_tokens: Maximum tokens per chunk
+
+        Returns:
+            List of TextualMemoryItem chunks
+        """
+        item_text = item.memory or ""
+        if not item_text:
+            return [item]
+
+        item_tokens = self._count_tokens(item_text)
+        if item_tokens <= max_tokens:
+            return [item]
+
+        # Use chunker to split the text
+        try:
+            chunks = self.chunker.chunk(item_text)
+            split_items = []
+
+            for chunk in chunks:
+                # Chunk objects have a 'text' attribute
+                chunk_text = chunk.text
+                if not chunk_text or not chunk_text.strip():
+                    continue
+
+                # Create a new memory item for each chunk, preserving original metadata
+                split_item = self._make_memory_item(
+                    value=chunk_text,
+                    info={
+                        "user_id": item.metadata.user_id,
+                        "session_id": item.metadata.session_id,
+                        **(item.metadata.info or {}),
+                    },
+                    memory_type=item.metadata.memory_type,
+                    tags=item.metadata.tags or [],
+                    key=item.metadata.key,
+                    sources=item.metadata.sources or [],
+                    background=item.metadata.background or "",
+                )
+                split_items.append(split_item)
+
+            return split_items if split_items else [item]
+        except Exception as e:
+            logger.warning(
+                f"[MultiModalStruct] Failed to split large memory item: {e}. Returning original item."
+            )
+            return [item]
+
     def _concat_multi_modal_memories(
         self, all_memory_items: list[TextualMemoryItem], max_tokens=None, overlap=200
     ) -> list[TextualMemoryItem]:
@@ -57,35 +112,49 @@ def _concat_multi_modal_memories(
         2. Each window has overlap tokens for context continuity
         3. Aggregates items within each window into a single memory item
         4. Determines memory_type based on roles in each window
+        5. Splits single large memory items that exceed max_tokens
         """
         if not all_memory_items:
             return []

-        # If only one item, return as-is (no need to aggregate)
-        if len(all_memory_items) == 1:
-            return all_memory_items
-
         max_tokens = max_tokens or self.chat_window_max_tokens
+
+        # Split large memory items before processing
+        processed_items = []
+        for item in all_memory_items:
+            item_text = item.memory or ""
+            item_tokens = self._count_tokens(item_text)
+            if item_tokens > max_tokens:
+                # Split the large item into multiple chunks
+                split_items = self._split_large_memory_item(item, max_tokens)
+                processed_items.extend(split_items)
+            else:
+                processed_items.append(item)
+
+        # If only one item after processing, return as-is
+        if len(processed_items) == 1:
+            return processed_items
+
         windows = []
         buf_items = []
         cur_text = ""

         # Extract info from first item (all items should have same user_id, session_id)
-        first_item = all_memory_items[0]
+        first_item = processed_items[0]
         info = {
             "user_id": first_item.metadata.user_id,
             "session_id": first_item.metadata.session_id,
             **(first_item.metadata.info or {}),
         }

-        for _idx, item in enumerate(all_memory_items):
+        for _idx, item in enumerate(processed_items):
            item_text = item.memory or ""
            # Ensure line ends with newline (same format as simple_struct)
            line = item_text if item_text.endswith("\n") else f"{item_text}\n"

            # Check if adding this item would exceed max_tokens (same logic as _iter_chat_windows)
-            # Note: The `and cur_text` condition ensures that single large messages are not truncated.
-            # If cur_text is empty (new window), even if line exceeds max_tokens, it won't trigger output.
+            # Note: After splitting large items, each item should be <= max_tokens,
+            # but we still check to handle edge cases
            if self._count_tokens(cur_text + line) > max_tokens and cur_text:
                # Yield current window
                window = self._build_window_from_items(buf_items, info)
@@ -102,8 +171,7 @@ def _concat_multi_modal_memories(
                # Recalculate cur_text from remaining items
                cur_text = "".join([it.memory or "" for it in buf_items])

-            # Add item to current window (always, even if it exceeds max_tokens)
-            # This ensures single large messages are not truncated, same as simple_struct
+            # Add item to current window
            buf_items.append(item)
            # Recalculate cur_text from all items in buffer (same as _iter_chat_windows)
            cur_text = "".join([it.memory or "" for it in buf_items])
@@ -255,14 +323,12 @@ def _process_multi_modal_data(
            for msg in scene_data_info:
                items = self.multi_modal_parser.parse(msg, info, mode="fast", **kwargs)
                all_memory_items.extend(items)
-            fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
-
        else:
            # Parse as single message
-            fast_memory_items = self.multi_modal_parser.parse(
+            all_memory_items = self.multi_modal_parser.parse(
                scene_data_info, info, mode="fast", **kwargs
            )
-
+        fast_memory_items = self._concat_multi_modal_memories(all_memory_items)
        if mode == "fast":
            return fast_memory_items
        else:
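The new pre-processing pass can be pictured on its own. A hedged sketch with plain strings standing in for TextualMemoryItem and a toy word-window chunker standing in for self.chunker; the names and the whitespace token count are illustrative, not from the commit.

from typing import Callable


def split_oversized(items: list[str], max_tokens: int,
                    chunk: Callable[[str], list[str]]) -> list[str]:
    # Mirrors the new loop: items over the budget are chunked, the rest pass
    # through, so the later windowing never receives a single oversized item.
    processed: list[str] = []
    for text in items:
        if len(text.split()) > max_tokens:  # crude whitespace token count, illustrative
            processed.extend(chunk(text))
        else:
            processed.append(text)
    return processed


def toy_chunk(text: str, size: int = 50) -> list[str]:
    # Fixed-size word windows; the real code delegates to self.chunker.chunk(...)
    words = text.split()
    return [" ".join(words[i : i + size]) for i in range(0, len(words), size)]


print(len(split_oversized(["short note", "word " * 200], max_tokens=50, chunk=toy_chunk)))
# -> 5: the short note plus four 50-word chunks of the long item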

src/memos/mem_reader/read_multi_modal/base.py

Lines changed: 29 additions & 0 deletions
@@ -16,6 +16,8 @@
     TreeNodeTextualMemoryMetadata,
 )

+from .utils import get_text_splitter
+

 logger = log.get_logger(__name__)

@@ -223,3 +225,30 @@ def parse(
             return self.parse_fine(message, info, **kwargs)
         else:
             raise ValueError(f"Unknown mode: {mode}. Must be 'fast' or 'fine'")
+
+    def _split_text(self, text: str) -> list[str]:
+        """
+        Split text into chunks using text splitter from utils.
+
+        Args:
+            text: Text to split
+
+        Returns:
+            List of text chunks
+        """
+        if not text or not text.strip():
+            return []
+
+        splitter = get_text_splitter()
+        if not splitter:
+            # If text splitter is not available, return text as single chunk
+            return [text] if text.strip() else []
+
+        try:
+            chunks = splitter.split_text(text)
+            logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks")
+            return chunks
+        except Exception as e:
+            logger.error(f"[FileContentParser] Error splitting text: {e}")
+            # Fallback to single chunk
+            return [text] if text.strip() else []
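Since _split_text now lives on BaseMessageParser (and is removed from FileContentParser below), every message parser shares the same chunking fallback. A hedged, standalone restatement of that behaviour, assuming get_text_splitter and split_text work as shown in this diff; the function name and sample text are illustrative.

from memos.mem_reader.read_multi_modal.utils import get_text_splitter


def split_or_passthrough(text: str) -> list[str]:
    # Same decision tree as _split_text: empty input yields nothing, a missing
    # splitter yields the whole text as one chunk, otherwise delegate to it.
    if not text or not text.strip():
        return []
    splitter = get_text_splitter()
    if not splitter:
        return [text]
    return splitter.split_text(text)


print(len(split_or_passthrough("# Notes\n" + "A long section about chunking. " * 400)))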

src/memos/mem_reader/read_multi_modal/file_content_parser.py

Lines changed: 1 addition & 28 deletions
@@ -16,7 +16,7 @@
 from memos.types.openai_chat_completion_types import File

 from .base import BaseMessageParser, _derive_key
-from .utils import get_parser, get_text_splitter
+from .utils import get_parser


 logger = get_logger(__name__)
@@ -108,33 +108,6 @@ def __init__(
         else:
             self.direct_markdown_hostnames = []

-    def _split_text(self, text: str) -> list[str]:
-        """
-        Split text into chunks using text splitter from utils.
-
-        Args:
-            text: Text to split
-
-        Returns:
-            List of text chunks
-        """
-        if not text or not text.strip():
-            return []
-
-        splitter = get_text_splitter()
-        if not splitter:
-            # If text splitter is not available, return text as single chunk
-            return [text] if text.strip() else []
-
-        try:
-            chunks = splitter.split_text(text)
-            logger.debug(f"[FileContentParser] Split text into {len(chunks)} chunks")
-            return chunks
-        except Exception as e:
-            logger.error(f"[FileContentParser] Error splitting text: {e}")
-            # Fallback to single chunk
-            return [text] if text.strip() else []
-
     def create_source(
         self,
         message: File,
