import re

from abc import ABC, abstractmethod

from memos.configs.embedder import BaseEmbedderConfig


def _count_tokens_for_embedding(text: str) -> int:
    """
    Count tokens in text for embedding truncation.
    Uses tiktoken if available; otherwise falls back to a character heuristic.

    Args:
        text: Text to count tokens for.

    Returns:
        Number of tokens.
    """
    try:
        import tiktoken

        try:
            enc = tiktoken.encoding_for_model("gpt-4o-mini")
        except Exception:
            enc = tiktoken.get_encoding("cl100k_base")
        return len(enc.encode(text or ""))
    except Exception:
        # Heuristic fallback: CJK characters count as ~1 token each;
        # other text averages ~1 token per ~4 characters.
        if not text:
            return 0
        zh_chars = re.findall(r"[\u4e00-\u9fff]", text)
        zh = len(zh_chars)
        rest = len(text) - zh
        return zh + max(1, rest // 4)

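# A minimal usage sketch (illustrative only, not part of this module). The
# expected counts in the comments assume the heuristic fallback path, i.e.
# that tiktoken is not installed; with tiktoken present the exact numbers
# will differ, but the function's contract is the same.
def _demo_token_counting() -> None:  # hypothetical helper, for illustration
    mixed = "hello world, 你好世界"  # 13 non-CJK chars + 4 CJK chars
    # Heuristic estimate: 4 (CJK) + 13 // 4 = 7 tokens.
    print(_count_tokens_for_embedding(mixed))
    print(_count_tokens_for_embedding(""))  # 0: empty text short-circuits
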
def _truncate_text_to_tokens(text: str, max_tokens: int) -> str:
    """
    Truncate text to fit within the max_tokens limit.
    Uses binary search to find the longest prefix that fits.

    Args:
        text: Text to truncate.
        max_tokens: Maximum number of tokens allowed.

    Returns:
        Truncated text.
    """
    if not text or max_tokens is None or max_tokens <= 0:
        return text

    current_tokens = _count_tokens_for_embedding(text)
    if current_tokens <= max_tokens:
        return text

    # Binary search for the longest prefix within the token limit.
    low, high = 0, len(text)
    best_text = ""

    while low < high:
        mid = (low + high + 1) // 2  # Round up so `low` always advances.
        truncated = text[:mid]
        tokens = _count_tokens_for_embedding(truncated)

        if tokens <= max_tokens:
            best_text = truncated
            low = mid
        else:
            high = mid - 1

    return best_text if best_text else text[:1]  # Keep at least one character.

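# A small sketch of the truncation behavior (illustrative only). The binary
# search keeps the longest prefix whose token count fits, so the result is
# always a prefix of the input and re-counting it stays within the limit.
def _demo_truncation() -> None:  # hypothetical helper, for illustration
    text = "word " * 100  # roughly 100-125 tokens under either counting path
    short = _truncate_text_to_tokens(text, max_tokens=10)
    assert text.startswith(short)  # truncation only ever drops a suffix
    assert _count_tokens_for_embedding(short) <= 10
    print(f"{len(text)} chars -> {len(short)} chars")
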
class BaseEmbedder(ABC):
    """Base class for all Embedding models."""

    @abstractmethod
    def __init__(self, config: BaseEmbedderConfig):
        """Initialize the embedding model with the given configuration."""
        self.config = config

    def _truncate_texts(self, texts: list[str], approx_char_per_token: float = 1.1) -> list[str]:
        """
        Truncate texts to fit within the configured max_tokens limit, if any.

        Args:
            texts: List of texts to truncate.
            approx_char_per_token: Conservative lower bound on characters per
                token, used to skip exact counting for clearly short texts.

        Returns:
            List of truncated texts.
        """
        if not hasattr(self, "config") or self.config.max_tokens is None:
            return texts
        max_tokens = self.config.max_tokens

        truncated = []
        for t in texts:
            if len(t) < max_tokens * approx_char_per_token:
                # Fast path: assuming >= ~1.1 characters per token
                # (conservative for most text), this text fits under the
                # limit, so skip the exact token count.
                truncated.append(t)
            else:
                truncated.append(_truncate_text_to_tokens(t, max_tokens))
        return truncated

    @abstractmethod
    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts into embedding vectors."""
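
# A minimal sketch of a concrete subclass (hypothetical, not part of memos):
# it shows where _truncate_texts hooks into embed. The zero-vector output is
# a stand-in for a real model call; an actual subclass would invoke its
# embedding client after truncation. It relies only on the optional
# config.max_tokens field already referenced by _truncate_texts above.
class DummyEmbedder(BaseEmbedder):  # hypothetical, for illustration only
    def __init__(self, config: BaseEmbedderConfig):
        super().__init__(config)  # stores config for _truncate_texts

    def embed(self, texts: list[str]) -> list[list[float]]:
        # Truncate first so oversized inputs cannot exceed the model window.
        texts = self._truncate_texts(texts)
        return [[0.0] * 4 for _ in texts]  # placeholder 4-dim vectors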