Skip to content

Commit 244e748

Browse files
CaralHsi, fridayL, and Copilot
authored
feat: add universal api embedder (#81)
* feat: add api embedder
* test: add unittest for universal api embedder
* Update src/memos/embedders/universal_api.py
  Co-authored-by: Copilot <[email protected]>
* Update src/memos/configs/embedder.py
  Co-authored-by: Copilot <[email protected]>
* style: reformat
* style: reformat

---------

Co-authored-by: chunyu li <[email protected]>
Co-authored-by: Copilot <[email protected]>
1 parent 339c756 commit 244e748

File tree

6 files changed

+138
-1
lines changed

6 files changed

+138
-1
lines changed

evaluation/scripts/locomo/openai_memory_locomo_eval_guide.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ Can you please extract relevant information from this conversation and create me
8181
* Click on **Manage** in the memory confirmation to view the newly generated memories.
8282
* Create a new local `.txt` file with the same name as the input file (e.g., `0-D1.txt`).
8383
* Copy each memory entry from ChatGPT and paste it into the new file, with each memory on a new line.
84-
5. **Reset Memories for the Next Conversation:**
84+
5. **Reset Memories for the Next Conversation:**
8585
* Once all sessions for a conversation are complete, it is essential to **delete all memories to ensure a clean state for the next conversation**. Navigate to Settings -> Personalization -> Manage and click Delete all.
8686

8787
**Example Memory Output (`0-D9.txt`):**

examples/basic_modules/embedder.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,23 @@
4646
text_hf = "This is a sample text for Hugging Face embedding generation."
4747
embedding_hf = embedder_hf.embed([text_hf])
4848
print("Scenario 3 HF embedding shape:", len(embedding_hf[0]))
49+
print("==" * 20)
50+
51+
# === Scenario 4: Using UniversalAPIEmbedder ===

# Settings for the "universal_api" backend. Replace the placeholder key and
# point base_url at your own endpoint before running this scenario.
api_backend_settings = {
    "provider": "openai",
    "api_key": "<YOUR_KEY>",
    "model_name_or_path": "text-embedding-3-large",
    "base_url": "https://api.myproxy.com/v1",
}
config_api = EmbedderConfigFactory.model_validate(
    {"backend": "universal_api", "config": api_backend_settings}
)
embedder_api = EmbedderFactory.from_config(config_api)

# Embed a single sentence and show the vector size plus its first few values.
text_api = "This is a sample text for embedding generation using OpenAI API."
embedding_api = embedder_api.embed([text_api])
first_vector = embedding_api[0]
print("Scenario 4: OpenAI API embedding vector length:", len(first_vector))
print("Embedding preview:", first_vector[:10])

src/memos/configs/embedder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ class SenTranEmbedderConfig(BaseEmbedderConfig):
3535
)
3636

3737

38+
class UniversalAPIEmbedderConfig(BaseEmbedderConfig):
    """Settings for embedders backed by a hosted embedding API such as OpenAI.

    Extends ``BaseEmbedderConfig`` with the provider name, its API key, and
    an optional endpoint override.
    """

    # Which remote provider to talk to; required.
    provider: str = Field(
        ...,
        description="Provider name, e.g., 'openai'",
    )
    # Credential for that provider; required.
    api_key: str = Field(
        ...,
        description="API key for the embedding provider",
    )
    # Stays None unless a custom or proxied endpoint is configured.
    base_url: str | None = Field(
        description="Optional base URL for custom or proxied endpoint",
        default=None,
    )
49+
50+
3851
class EmbedderConfigFactory(BaseConfig):
3952
"""Factory class for creating embedder configurations."""
4053

@@ -45,6 +58,7 @@ class EmbedderConfigFactory(BaseConfig):
4558
"ollama": OllamaEmbedderConfig,
4659
"sentence_transformer": SenTranEmbedderConfig,
4760
"ark": ArkEmbedderConfig,
61+
"universal_api": UniversalAPIEmbedderConfig,
4862
}
4963

5064
@field_validator("backend")

src/memos/embedders/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from memos.embedders.base import BaseEmbedder
66
from memos.embedders.ollama import OllamaEmbedder
77
from memos.embedders.sentence_transformer import SenTranEmbedder
8+
from memos.embedders.universal_api import UniversalAPIEmbedder
89

910

1011
class EmbedderFactory(BaseEmbedder):
@@ -14,6 +15,7 @@ class EmbedderFactory(BaseEmbedder):
1415
"ollama": OllamaEmbedder,
1516
"sentence_transformer": SenTranEmbedder,
1617
"ark": ArkEmbedder,
18+
"universal_api": UniversalAPIEmbedder,
1719
}
1820

1921
@classmethod

src/memos/embedders/universal_api.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from openai import OpenAI as OpenAIClient
2+
3+
from memos.configs.embedder import UniversalAPIEmbedderConfig
4+
from memos.embedders.base import BaseEmbedder
5+
6+
7+
class UniversalAPIEmbedder(BaseEmbedder):
8+
def __init__(self, config: UniversalAPIEmbedderConfig):
9+
self.provider = config.provider
10+
self.config = config
11+
12+
if self.provider == "openai":
13+
self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url)
14+
else:
15+
raise ValueError(f"Unsupported provider: {self.provider}")
16+
17+
def embed(self, texts: list[str]) -> list[list[float]]:
18+
if self.provider == "openai":
19+
response = self.client.embeddings.create(
20+
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
21+
input=texts,
22+
)
23+
return [r.embedding for r in response.data]
24+
else:
25+
raise ValueError(f"Unsupported provider: {self.provider}")
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import unittest
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
from memos.configs.embedder import UniversalAPIEmbedderConfig
6+
from memos.embedders.universal_api import UniversalAPIEmbedder
7+
8+
9+
class TestUniversalAPIEmbedder(unittest.TestCase):
    """Exercise UniversalAPIEmbedder against a patched OpenAI client."""

    @staticmethod
    def _stub_create(client_cls, vectors):
        """Make the patched client's ``embeddings.create`` return *vectors*."""
        fake_response = MagicMock()
        fake_response.data = [MagicMock(embedding=vec) for vec in vectors]
        client_cls.return_value.embeddings.create.return_value = fake_response

    @staticmethod
    def _build_embedder():
        """Build an embedder from the OpenAI settings shared by every test."""
        settings = UniversalAPIEmbedderConfig(
            provider="openai",
            api_key="fake-api-key",
            base_url="https://api.openai.com/v1",
            model_name_or_path="text-embedding-3-large",
        )
        return UniversalAPIEmbedder(settings)

    @patch("memos.embedders.universal_api.OpenAIClient")
    def test_embed_single_text(self, mock_openai_client):
        """Test embedding a single text with OpenAI provider."""
        self._stub_create(mock_openai_client, [[0.1, 0.2, 0.3, 0.4]])

        embedder = self._build_embedder()
        text = ["Test input for embedding."]
        result = embedder.embed(text)

        # The client must be built from the configured key and endpoint.
        mock_openai_client.assert_called_once_with(
            api_key="fake-api-key",
            base_url="https://api.openai.com/v1",
        )

        # The request must carry the configured model and the raw input list.
        embedder.client.embeddings.create.assert_called_once_with(
            model="text-embedding-3-large",
            input=text,
        )

        self.assertEqual(len(result[0]), 4)

    @patch("memos.embedders.universal_api.OpenAIClient")
    def test_embed_batch_text(self, mock_openai_client):
        """Test embedding multiple texts at once with OpenAI provider."""
        self._stub_create(mock_openai_client, [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])

        embedder = self._build_embedder()
        texts = ["First text.", "Second text.", "Third text."]
        result = embedder.embed(texts)

        # All three texts go out in a single request.
        embedder.client.embeddings.create.assert_called_once_with(
            model="text-embedding-3-large",
            input=texts,
        )

        self.assertEqual(len(result), 3)
        self.assertEqual(result[0], [0.1, 0.2])


if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)