Skip to content

Commit ba74bf1

Browse files
authored
Add more backends (#99)
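* 1. 豆包支持非多模态模型 (Doubao/Ark embedder: support non-multimodal, text-only models) 2. embedder 支持 Azure backend (embedder: support the Azure backend)
* fix typo
* add example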
1 parent a9ffe4f commit ba74bf1

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

examples/basic_modules/embedder.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
print("Scenario 3 HF embedding shape:", len(embedding_hf[0]))
4949
print("==" * 20)
5050

51-
# === Scenario 4: Using UniversalAPIEmbedder ===
51+
# === Scenario 4: Using UniversalAPIEmbedder(OpenAI) ===
5252

5353
config_api = EmbedderConfigFactory.model_validate(
5454
{
@@ -66,3 +66,22 @@
6666
embedding_api = embedder_api.embed([text_api])
6767
print("Scenario 4: OpenAI API embedding vector length:", len(embedding_api[0]))
6868
print("Embedding preview:", embedding_api[0][:10])
69+
70+
# === Scenario 5: Using UniversalAPIEmbedder(Azure) ===
71+
72+
config_api = EmbedderConfigFactory.model_validate(
73+
{
74+
"backend": "universal_api",
75+
"config": {
76+
"provider": "azure",
77+
"api_key": "<YOUR_AZURE_KEY>",
78+
"model_name_or_path": "text-embedding-3-large",
79+
"base_url": "https://open.azure.com/openapi/online/v2/",
80+
},
81+
}
82+
)
83+
embedder_api = EmbedderFactory.from_config(config_api)
84+
text_api = "This is a sample text for embedding generation using Azure API."
85+
embedding_api = embedder_api.embed([text_api])
86+
print("Scenario 5: Azure API embedding vector length:", len(embedding_api[0]))
87+
print("Embedding preview:", embedding_api[0][:10])

src/memos/configs/embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class ArkEmbedderConfig(BaseEmbedderConfig):
2424
default="https://ark.cn-beijing.volces.com/api/v3/", description="Base URL for Ark API"
2525
)
2626
chunk_size: int = Field(default=1, description="Chunk size for Ark API")
27+
multi_modal: bool = Field(
28+
default=False,
29+
description="Whether to use multi-modal embedding (text + image) with Ark",
30+
)
2731

2832

2933
class SenTranEmbedderConfig(BaseEmbedderConfig):

src/memos/embedders/ark.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,26 @@ def embed(self, texts: list[str]) -> list[list[float]]:
4444
Returns:
4545
List of embeddings, each represented as a list of floats.
4646
"""
47-
texts_input = [
48-
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
49-
]
50-
return self.multimodal_embeddings(texts_input, chunk_size=self.config.chunk_size)
47+
if self.config.multi_modal:
48+
texts_input = [
49+
MultimodalEmbeddingContentPartTextParam(text=text, type="text") for text in texts
50+
]
51+
return self.multimodal_embeddings(inputs=texts_input, chunk_size=self.config.chunk_size)
52+
return self.text_embedding(texts, chunk_size=self.config.chunk_size)
53+
54+
def text_embedding(self, inputs: list[str], chunk_size: int | None = None) -> list[list[float]]:
55+
chunk_size_ = chunk_size or self.config.chunk_size
56+
embeddings: list[list[float]] = []
57+
for i in range(0, len(inputs), chunk_size_):
58+
response = self.client.embeddings.create(
59+
model=self.config.model_name_or_path,
60+
input=inputs[i : i + chunk_size_],
61+
)
62+
63+
data = [response.data] if isinstance(response.data, dict) else response.data
64+
embeddings.extend(r.embedding for r in data)
65+
66+
return embeddings
5167

5268
def multimodal_embeddings(
5369
self, inputs: list[EmbeddingInputParam], chunk_size: int | None = None

src/memos/embedders/universal_api.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from openai import OpenAI as OpenAIClient
2+
from openai import AzureOpenAI as AzureClient
23

34
from memos.configs.embedder import UniversalAPIEmbedderConfig
45
from memos.embedders.base import BaseEmbedder
@@ -11,11 +12,17 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
1112

1213
if self.provider == "openai":
1314
self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url)
15+
elif self.provider == "azure":
16+
self.client = AzureClient(
17+
azure_endpoint=config.base_url,
18+
api_version="2024-03-01-preview",
19+
api_key=config.api_key,
20+
)
1421
else:
1522
raise ValueError(f"Unsupported provider: {self.provider}")
1623

1724
def embed(self, texts: list[str]) -> list[list[float]]:
18-
if self.provider == "openai":
25+
if self.provider == "openai" or self.provider == "azure":
1926
response = self.client.embeddings.create(
2027
model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"),
2128
input=texts,

0 commit comments

Comments
 (0)