
Commit c90f241

feat: add ark/volcengine embedder support (#57)
* feat: add volcengine embedding support
* test: add volcengine embedding test
* feat(embedder): add support for ArkEmbedder
* fixed: minor typo fix and remove unnecessary function
* test: add image embedding test
* fixed: remove images embed support and test
* fixed: change chunk_size default value
1 parent 2189062 commit c90f241

7 files changed: 166 additions & 3 deletions

poetry.lock

Lines changed: 21 additions & 1 deletion
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ sentence-transformers = "^4.1.0"
 sqlalchemy = "^2.0.41"
 redis = "^6.2.0"
 schedule = "^1.2.2"
+volcengine-python-sdk = "^4.0.4"
 
 [tool.poetry.group.dev]
 optional = false

src/memos/configs/embedder.py

Lines changed: 9 additions & 0 deletions
@@ -18,6 +18,14 @@ class OllamaEmbedderConfig(BaseEmbedderConfig):
     api_base: str = Field(default="http://localhost:11434", description="Base URL for Ollama API")
 
 
+class ArkEmbedderConfig(BaseEmbedderConfig):
+    api_key: str = Field(..., description="Ark API key")
+    api_base: str = Field(
+        default="https://ark.cn-beijing.volces.com/api/v3/", description="Base URL for Ark API"
+    )
+    chunk_size: int = Field(default=1, description="Chunk size for Ark API")
+
+
 class SenTranEmbedderConfig(BaseEmbedderConfig):
     """Configuration class for Sentence Transformer embeddings."""
 
@@ -36,6 +44,7 @@ class EmbedderConfigFactory(BaseConfig):
     backend_to_class: ClassVar[dict[str, Any]] = {
        "ollama": OllamaEmbedderConfig,
        "sentence_transformer": SenTranEmbedderConfig,
+       "ark": ArkEmbedderConfig,
     }
 
     @field_validator("backend")
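
Note (not part of the commit): a minimal sketch of constructing the new config class directly, showing the defaults added above; the model ID is taken from the tests further down and the key is a placeholder.

from memos.configs.embedder import ArkEmbedderConfig

# Placeholder values; api_base and chunk_size fall back to the defaults defined above.
cfg = ArkEmbedderConfig(
    model_name_or_path="doubao-embedding-vision-250615",  # model ID used in the tests
    api_key="your-api-key",                               # placeholder
)
print(cfg.api_base)    # https://ark.cn-beijing.volces.com/api/v3/
print(cfg.chunk_size)  # 1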

src/memos/embedders/ark.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+from volcenginesdkarkruntime import Ark
+from volcenginesdkarkruntime.types.multimodal_embedding import (
+    EmbeddingInputParam,
+    MultimodalEmbeddingContentPartTextParam,
+    MultimodalEmbeddingResponse,
+)
+
+from memos.configs.embedder import ArkEmbedderConfig
+from memos.embedders.base import BaseEmbedder
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class ArkEmbedder(BaseEmbedder):
+    """Ark Embedder class."""
+
+    def __init__(self, config: ArkEmbedderConfig):
+        self.config = config
+
+        if self.config.embedding_dims is not None:
+            logger.warning(
+                "Ark does not support specifying embedding dimensions. "
+                "The embedding dimensions is determined by the model."
+                "`embedding_dims` will be set to None."
+            )
+            self.config.embedding_dims = None
+
+        # Default model if not specified
+        if not self.config.model_name_or_path:
+            self.config.model_name_or_path = "doubao-embedding-vision-250615"
+
+        # Initialize ark client
+        self.client = Ark(api_key=self.config.api_key, base_url=self.config.api_base)
+
+    def embed(self, texts: list[str]) -> list[list[float]]:
+        """
+        Generate embeddings for the given texts.
+
+        Args:
+            texts: List of texts to embed.
+
+        Returns:
+            List of embeddings, each represented as a list of floats.
+        """
+        texts_input = [
+            MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
+        ]
+        return self.multimodal_embeddings(texts_input, chunk_size=self.config.chunk_size)
+
+    def multimodal_embeddings(
+        self, inputs: list[EmbeddingInputParam], chunk_size: int | None = None
+    ) -> list[list[float]]:
+        chunk_size_ = chunk_size or self.config.chunk_size
+        embeddings: list[list[float]] = []
+
+        for i in range(0, len(inputs), chunk_size_):
+            response: MultimodalEmbeddingResponse = self.client.multimodal_embeddings.create(
+                model=self.config.model_name_or_path,
+                input=inputs[i : i + chunk_size_],
+            )
+
+            data = [response.data] if isinstance(response.data, dict) else response.data
+            embeddings.extend(r["embedding"] for r in data)
+
+        return embeddings
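
Note (not part of the commit): embed() wraps each text in a MultimodalEmbeddingContentPartTextParam and forwards it to multimodal_embeddings, which issues one API call per chunk_size items (with the default of 1, one request per text). A rough usage sketch, assuming a valid Ark API key; all values here are placeholders.

from memos.configs.embedder import ArkEmbedderConfig
from memos.embedders.ark import ArkEmbedder

# Requires a real Ark API key to actually reach the service.
config = ArkEmbedderConfig(
    model_name_or_path="doubao-embedding-vision-250615",
    api_key="your-api-key",
    chunk_size=2,  # two texts per multimodal_embeddings.create request
)
embedder = ArkEmbedder(config)

vectors = embedder.embed(["first text", "second text", "third text"])
print(len(vectors))  # 3 -- one embedding per input text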

src/memos/embedders/factory.py

Lines changed: 2 additions & 0 deletions
@@ -1,6 +1,7 @@
 from typing import Any, ClassVar
 
 from memos.configs.embedder import EmbedderConfigFactory
+from memos.embedders.ark import ArkEmbedder
 from memos.embedders.base import BaseEmbedder
 from memos.embedders.ollama import OllamaEmbedder
 from memos.embedders.sentence_transformer import SenTranEmbedder
@@ -12,6 +13,7 @@ class EmbedderFactory(BaseEmbedder):
     backend_to_class: ClassVar[dict[str, Any]] = {
        "ollama": OllamaEmbedder,
        "sentence_transformer": SenTranEmbedder,
+       "ark": ArkEmbedder,
     }
 
     @classmethod
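
Note (not part of the commit): end to end, the factory path mirrors the tests below; a sketch with placeholder values showing the "ark" backend resolving to ArkEmbedder.

from memos.configs.embedder import EmbedderConfigFactory
from memos.embedders.factory import ArkEmbedder, EmbedderFactory

config = EmbedderConfigFactory.model_validate(
    {
        "backend": "ark",
        "config": {
            "model_name_or_path": "doubao-embedding-vision-250615",
            "api_key": "your-api-key",  # placeholder
        },
    }
)
embedder = EmbedderFactory.from_config(config)
assert isinstance(embedder, ArkEmbedder)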

src/memos/memories/textual/general.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from tenacity import retry, retry_if_exception_type, stop_after_attempt
 
 from memos.configs.memory import GeneralTextMemoryConfig
-from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
+from memos.embedders.factory import ArkEmbedder, EmbedderFactory, OllamaEmbedder
 from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM
 from memos.log import get_logger
 from memos.memories.textual.base import BaseTextMemory
@@ -28,7 +28,7 @@ def __init__(self, config: GeneralTextMemoryConfig):
         self.config: GeneralTextMemoryConfig = config
         self.extractor_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.extractor_llm)
         self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db)
-        self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
+        self.embedder: OllamaEmbedder | ArkEmbedder = EmbedderFactory.from_config(config.embedder)
 
     @retry(
         stop=stop_after_attempt(3),

tests/embedders/test_ark.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+import unittest
+
+from unittest.mock import patch
+
+from memos.configs.embedder import EmbedderConfigFactory
+from memos.embedders.factory import ArkEmbedder, EmbedderFactory
+
+
+class TestEmbedderFactory(unittest.TestCase):
+    @patch.object(ArkEmbedder, "embed")
+    def test_embed_single_text(self, mock_embed):
+        """Test embedding a single text."""
+        mock_embed.return_value = [[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]
+
+        config = EmbedderConfigFactory.model_validate(
+            {
+                "backend": "ark",
+                "config": {
+                    "model_name_or_path": "doubao-embedding-vision-250615",
+                    "embedding_dims": 2048,
+                    "api_key": "your-api-key",
+                    "api_base": "https://ark.cn-beijing.volces.com/api/v3",
+                },
+            }
+        )
+        embedder = EmbedderFactory.from_config(config)
+        text = "This is a sample text for embedding generation."
+        result = embedder.embed([text])
+
+        mock_embed.assert_called_once_with([text])
+        self.assertEqual(len(result[0]), 6)
+
+    @patch.object(ArkEmbedder, "embed")
+    def test_embed_batch_text(self, mock_embed):
+        """Test embedding multiple texts at once."""
+        mock_embed.return_value = [
+            [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
+            [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
+            [0.3, 0.4, 0.5, 0.6, 0.1, 0.2],
+        ]
+
+        config = EmbedderConfigFactory.model_validate(
+            {
+                "backend": "ark",
+                "config": {
+                    "model_name_or_path": "doubao-embedding-vision-250615",
+                    "embedding_dims": 2048,
+                    "api_key": "your-api-key",
+                    "api_base": "https://ark.cn-beijing.volces.com/api/v3",
+                },
+            }
+        )
+        embedder = EmbedderFactory.from_config(config)
+        texts = [
+            "First sample text for batch embedding.",
+            "Second sample text for batch embedding.",
+            "Third sample text for batch embedding.",
+        ]
+
+        result = embedder.embed(texts)
+
+        mock_embed.assert_called_once_with(texts)
+        self.assertEqual(len(result), 3)
+        self.assertEqual(len(result[0]), 6)
