Skip to content

Commit 74e0247

Browse files
authored
feat: add headers for embedding and reranker (#519)
1 parent 8003c2f commit 74e0247

File tree

5 files changed

+14
-6
lines changed

5 files changed

+14
-6
lines changed

src/memos/api/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def get_reranker_config() -> dict[str, Any]:
381381
"url": os.getenv("MOS_RERANKER_URL"),
382382
"model": os.getenv("MOS_RERANKER_MODEL", "bge-reranker-v2-m3"),
383383
"timeout": 10,
384-
"headers_extra": os.getenv("MOS_RERANKER_HEADERS_EXTRA"),
384+
"headers_extra": json.loads(os.getenv("MOS_RERANKER_HEADERS_EXTRA", "{}")),
385385
"rerank_source": os.getenv("MOS_RERANK_SOURCE"),
386386
"reranker_strategy": os.getenv("MOS_RERANKER_STRATEGY", "single_turn"),
387387
},
@@ -407,6 +407,7 @@ def get_embedder_config() -> dict[str, Any]:
407407
"provider": os.getenv("MOS_EMBEDDER_PROVIDER", "openai"),
408408
"api_key": os.getenv("MOS_EMBEDDER_API_KEY", "sk-xxxx"),
409409
"model_name_or_path": os.getenv("MOS_EMBEDDER_MODEL", "text-embedding-3-large"),
410+
"headers_extra": json.loads(os.getenv("MOS_EMBEDDER_HEADERS_EXTRA", "{}")),
410411
"base_url": os.getenv("MOS_EMBEDDER_API_BASE", "http://openai.com"),
411412
},
412413
}

src/memos/configs/embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ class BaseEmbedderConfig(BaseConfig):
1212
embedding_dims: int | None = Field(
1313
default=None, description="Number of dimensions for the embedding"
1414
)
15+
headers_extra: dict[str, Any] | None = Field(
16+
default=None,
17+
description="Extra headers for the embedding model, only for universal_api backend",
18+
)
1519

1620

1721
class OllamaEmbedderConfig(BaseEmbedderConfig):

src/memos/embedders/universal_api.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
1616
self.config = config
1717

1818
if self.provider == "openai":
19-
self.client = OpenAIClient(api_key=config.api_key, base_url=config.base_url)
19+
self.client = OpenAIClient(
20+
api_key=config.api_key,
21+
base_url=config.base_url,
22+
default_headers=config.headers_extra if config.headers_extra else None,
23+
)
2024
elif self.provider == "azure":
2125
self.client = AzureClient(
2226
azure_endpoint=config.base_url,

tests/configs/test_embedder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_base_embedder_config():
1717
required_fields=[
1818
"model_name_or_path",
1919
],
20-
optional_fields=["embedding_dims"],
20+
optional_fields=["embedding_dims", "headers_extra"],
2121
)
2222

2323
check_config_instantiation_valid(
@@ -36,7 +36,7 @@ def test_ollama_embedder_config():
3636
required_fields=[
3737
"model_name_or_path",
3838
],
39-
optional_fields=["embedding_dims", "api_base"],
39+
optional_fields=["embedding_dims", "headers_extra", "api_base"],
4040
)
4141

4242
check_config_instantiation_valid(

tests/embedders/test_universal_api.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ def test_embed_single_text(self, mock_openai_client):
2828

2929
# Assert OpenAIClient was created with proper args
3030
mock_openai_client.assert_called_once_with(
31-
api_key="fake-api-key",
32-
base_url="https://api.openai.com/v1",
31+
api_key="fake-api-key", base_url="https://api.openai.com/v1", default_headers=None
3332
)
3433

3534
# Assert embeddings.create called with correct params

0 commit comments

Comments
 (0)