Skip to content

Commit 546a426

Browse files
authored
add support for azure backend (#88)
## Description <!-- Please include a summary of the changes below; Fill in the issue number that this PR addresses (if applicable); Mention the person who will review this PR (if you know who it is); Replace (summary), (issue), and (reviewer) with the appropriate information (No parentheses). 请在下方填写更改的摘要; 填写此 PR 解决的问题编号(如果适用); 提及将审查此 PR 的人(如果您知道是谁); 替换 (summary)、(issue) 和 (reviewer) 为适当的信息(不带括号)。 --> Summary: Add azure backend support Fix: #(issue) Reviewer: @(reviewer) ## Checklist: - [x] I have performed a self-review of my own code | 我已自行检查了自己的代码 - [x] I have commented my code in hard-to-understand areas | 我已在难以理解的地方对代码进行了注释 - [x] I have added tests that prove my fix is effective or that my feature works | 我已添加测试以证明我的修复有效或功能正常 - [ ] I have added necessary documentation (if applicable) | 我已添加必要的文档(如果适用) - [ ] I have linked the issue to this PR (if applicable) | 我已将 issue 链接到此 PR(如果适用) - [ ] I have mentioned the person who will review this PR | 我已提及将审查此 PR 的人
2 parents 0dde160 + 6e0a123 commit 546a426

File tree

8 files changed

+61
-13
lines changed

8 files changed

+61
-13
lines changed

src/memos/configs/llm.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ class OpenAILLMConfig(BaseLLMConfig):
2626
)
2727

2828

29+
class AzureLLMConfig(BaseLLMConfig):
    """Configuration for the Azure OpenAI LLM backend.

    NOTE(review): real Azure OpenAI endpoints are per-resource
    (``https://<resource-name>.openai.azure.com/``); the default ``base_url``
    below looks like a placeholder that deployments must override — confirm
    against the Azure resource's actual endpoint.
    """

    # Endpoint of the Azure OpenAI resource; consumed as ``azure_endpoint``
    # by ``openai.AzureOpenAI`` in the matching AzureLLM class.
    base_url: str = Field(
        default="https://api.openai.azure.com/",
        description="Base URL for Azure OpenAI API",
    )
    # REST API version string required by Azure OpenAI requests.
    api_version: str = Field(
        default="2024-03-01-preview",
        description="API version for Azure OpenAI",
    )
    # Required field (``...``): pydantic raises a validation error if omitted.
    api_key: str = Field(..., description="API key for Azure OpenAI")
39+
40+
2941
class OllamaLLMConfig(BaseLLMConfig):
3042
api_base: str = Field(
3143
default="http://localhost:11434",
@@ -61,6 +73,7 @@ class LLMConfigFactory(BaseConfig):
6173
backend_to_class: ClassVar[dict[str, Any]] = {
6274
"openai": OpenAILLMConfig,
6375
"ollama": OllamaLLMConfig,
76+
"azure": AzureLLMConfig,
6477
"huggingface": HFLLMConfig,
6578
"vllm": VLLMLLMConfig,
6679
"huggingface_singleton": HFLLMConfig, # Add singleton support

src/memos/llms/factory.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from memos.llms.hf import HFLLM
66
from memos.llms.hf_singleton import HFSingletonLLM
77
from memos.llms.ollama import OllamaLLM
8-
from memos.llms.openai import OpenAILLM
8+
from memos.llms.openai import AzureLLM, OpenAILLM
99
from memos.llms.vllm import VLLMLLM
1010

1111

@@ -14,6 +14,7 @@ class LLMFactory(BaseLLM):
1414

1515
backend_to_class: ClassVar[dict[str, Any]] = {
1616
"openai": OpenAILLM,
17+
"azure": AzureLLM,
1718
"ollama": OllamaLLM,
1819
"huggingface": HFLLM,
1920
"huggingface_singleton": HFSingletonLLM, # Add singleton version

src/memos/llms/openai.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import openai
22

3-
from memos.configs.llm import OpenAILLMConfig
3+
from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig
44
from memos.llms.base import BaseLLM
55
from memos.llms.utils import remove_thinking_tags
66
from memos.log import get_logger
@@ -32,3 +32,31 @@ def generate(self, messages: MessageList) -> str:
3232
return remove_thinking_tags(response_content)
3333
else:
3434
return response_content
35+
36+
37+
class AzureLLM(BaseLLM):
    """Azure OpenAI LLM backed by ``openai.AzureOpenAI``."""

    def __init__(self, config: AzureLLMConfig):
        """Build the Azure client from the given configuration.

        Args:
            config: Azure settings — endpoint (``base_url``), ``api_version``,
                ``api_key``, plus the sampling fields inherited from the base
                LLM config (model, temperature, max_tokens, top_p).
        """
        self.config = config
        self.client = openai.AzureOpenAI(
            azure_endpoint=config.base_url,
            api_version=config.api_version,
            api_key=config.api_key,
        )

    def generate(self, messages: MessageList) -> str:
        """Generate a chat completion for ``messages`` and return its text.

        Mirrors ``OpenAILLM.generate``: when ``config.remove_think_prefix``
        is set, thinking tags are stripped from the returned content.
        """
        response = self.client.chat.completions.create(
            model=self.config.model_name_or_path,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
        )
        # Lazy %-formatting: model_dump_json() is only evaluated when INFO
        # logging is enabled (the original f-string always paid this cost).
        logger.info("Response from Azure OpenAI: %s", response.model_dump_json())
        response_content = response.choices[0].message.content
        if self.config.remove_think_prefix:
            return remove_thinking_tags(response_content)
        return response_content

src/memos/memories/textual/general.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from memos.configs.memory import GeneralTextMemoryConfig
1010
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
11-
from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM
11+
from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM, AzureLLM
1212
from memos.log import get_logger
1313
from memos.memories.textual.base import BaseTextMemory
1414
from memos.memories.textual.item import TextualMemoryItem
@@ -26,7 +26,9 @@ class GeneralTextMemory(BaseTextMemory):
2626
def __init__(self, config: GeneralTextMemoryConfig):
2727
"""Initialize memory with the given configuration."""
2828
self.config: GeneralTextMemoryConfig = config
29-
self.extractor_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.extractor_llm)
29+
self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
30+
config.extractor_llm
31+
)
3032
self.vector_db: QdrantVecDB = VecDBFactory.from_config(config.vector_db)
3133
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
3234

src/memos/memories/textual/tree.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from memos.configs.memory import TreeTextMemoryConfig
1111
from memos.embedders.factory import EmbedderFactory, OllamaEmbedder
1212
from memos.graph_dbs.factory import GraphStoreFactory, Neo4jGraphDB
13-
from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM
13+
from memos.llms.factory import LLMFactory, OllamaLLM, OpenAILLM, AzureLLM
1414
from memos.log import get_logger
1515
from memos.memories.textual.base import BaseTextMemory
1616
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
@@ -31,8 +31,12 @@ class TreeTextMemory(BaseTextMemory):
3131
def __init__(self, config: TreeTextMemoryConfig):
3232
"""Initialize memory with the given configuration."""
3333
self.config: TreeTextMemoryConfig = config
34-
self.extractor_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.extractor_llm)
35-
self.dispatcher_llm: OpenAILLM | OllamaLLM = LLMFactory.from_config(config.dispatcher_llm)
34+
self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
35+
config.extractor_llm
36+
)
37+
self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config(
38+
config.dispatcher_llm
39+
)
3640
self.embedder: OllamaEmbedder = EmbedderFactory.from_config(config.embedder)
3741
self.graph_store: Neo4jGraphDB = GraphStoreFactory.from_config(config.graph_db)
3842
self.is_reorganize = config.reorganize

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from memos.embedders.factory import OllamaEmbedder
77
from memos.graph_dbs.neo4j import Neo4jGraphDB
8-
from memos.llms.factory import OllamaLLM, OpenAILLM
8+
from memos.llms.factory import OllamaLLM, OpenAILLM, AzureLLM
99
from memos.log import get_logger
1010
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
1111
from memos.memories.textual.tree_text_memory.organize.reorganizer import (
@@ -22,7 +22,7 @@ def __init__(
2222
self,
2323
graph_store: Neo4jGraphDB,
2424
embedder: OllamaEmbedder,
25-
llm: OpenAILLM | OllamaLLM,
25+
llm: OpenAILLM | OllamaLLM | AzureLLM,
2626
memory_size: dict | None = None,
2727
threshold: float | None = 0.80,
2828
merged_threshold: float | None = 0.92,

src/memos/memories/textual/tree_text_memory/retrieve/reranker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22

33
from memos.embedders.factory import OllamaEmbedder
4-
from memos.llms.factory import OllamaLLM, OpenAILLM
4+
from memos.llms.factory import OllamaLLM, OpenAILLM, AzureLLM
55
from memos.memories.textual.item import TextualMemoryItem
66
from memos.memories.textual.tree_text_memory.retrieve.retrieval_mid_structs import ParsedTaskGoal
77

@@ -41,7 +41,7 @@ class MemoryReranker:
4141
Rank retrieved memory cards by structural priority and contextual similarity.
4242
"""
4343

44-
def __init__(self, llm: OpenAILLM | OllamaLLM, embedder: OllamaEmbedder):
44+
def __init__(self, llm: OpenAILLM | OllamaLLM | AzureLLM, embedder: OllamaEmbedder):
4545
self.llm = llm
4646
self.embedder = embedder
4747

src/memos/memories/textual/tree_text_memory/retrieve/searcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from memos.embedders.factory import OllamaEmbedder
77
from memos.graph_dbs.factory import Neo4jGraphDB
8-
from memos.llms.factory import OllamaLLM, OpenAILLM
8+
from memos.llms.factory import OllamaLLM, OpenAILLM, AzureLLM
99
from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem
1010

1111
from .internet_retriever_factory import InternetRetrieverFactory
@@ -18,7 +18,7 @@
1818
class Searcher:
1919
def __init__(
2020
self,
21-
dispatcher_llm: OpenAILLM | OllamaLLM,
21+
dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM,
2222
graph_store: Neo4jGraphDB,
2323
embedder: OllamaEmbedder,
2424
internet_retriever: InternetRetrieverFactory | None = None,

0 commit comments

Comments
 (0)