Commit 3c4fcd4

feat: add deepseek llm
1 parent faa4c52 commit 3c4fcd4

8 files changed: 182 additions, 1 deletion

examples/basic_modules/llm.py

Lines changed: 62 additions & 1 deletion
@@ -94,6 +94,14 @@
 
 
 # Scenario 5: Using LLMFactory with Qwen (DashScope Compatible API)
+# Note:
+# This example works for any model that supports the OpenAI-compatible Chat Completion API,
+# including but not limited to:
+# - Qwen models: qwen-plus, qwen-max-2025-01-25
+# - DeepSeek models: deepseek-chat, deepseek-coder, deepseek-v3
+# - Other compatible providers: MiniMax, Fireworks, Groq, OpenRouter, etc.
+#
+# Just set the correct `api_key`, `api_base`, and `model_name_or_path`.
 
 config = LLMConfigFactory.model_validate(
     {
@@ -111,8 +119,61 @@
 )
 llm = LLMFactory.from_config(config)
 messages = [
-    {"role": "user", "content": "Can you speak Chinese?"},
+    {"role": "user", "content": "Hello, who are you"},
 ]
 response = llm.generate(messages)
 print("Scenario 5:", response)
 print("==" * 20)
+
+print("Scenario 5:\n")
+for chunk in llm.generate_stream(messages):
+    print(chunk, end="")
+print("==" * 20)
+
+# Scenario 6: Using LLMFactory with Deepseek-chat
+
+cfg = LLMConfigFactory.model_validate(
+    {
+        "backend": "deepseek",
+        "config": {
+            "model_name_or_path": "deepseek-chat",
+            "api_key": "sk-xxx",
+            "api_base": "https://api.deepseek.com",
+            "temperature": 0.6,
+            "max_tokens": 512,
+            "remove_think_prefix": False,
+        },
+    }
+)
+llm = LLMFactory.from_config(cfg)
+messages = [{"role": "user", "content": "Hello, who are you"}]
+resp = llm.generate(messages)
+print("Scenario 6:", resp)
+
+
+# Scenario 7: Using LLMFactory with Deepseek-chat + reasoning + CoT + streaming
+
+cfg2 = LLMConfigFactory.model_validate(
+    {
+        "backend": "deepseek",
+        "config": {
+            "model_name_or_path": "deepseek-reasoner",
+            "api_key": "sk-xxx",
+            "api_base": "https://api.deepseek.com",
+            "temperature": 0.2,
+            "max_tokens": 1024,
+            "remove_think_prefix": False,
+        },
+    }
+)
+llm = LLMFactory.from_config(cfg2)
+messages = [
+    {
+        "role": "user",
+        "content": "Explain how to solve this problem step-by-step. Be explicit in your thinking process. Question: If a train travels from city A to city B at 60 mph and returns at 40 mph, what is its average speed for the entire trip? Let's think step by step.",
+    },
+]
+print("Scenario 7:\n")
+for chunk in llm.generate_stream(messages):
+    print(chunk, end="")
+print("==" * 20)

src/memos/configs/llm.py

Lines changed: 13 additions & 0 deletions
@@ -37,6 +37,18 @@ class QwenLLMConfig(BaseLLMConfig):
     model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'")
 
 
+class DeepSeekLLMConfig(BaseLLMConfig):
+    api_key: str = Field(..., description="API key for DeepSeek")
+    api_base: str = Field(
+        default="https://api.deepseek.com",
+        description="Base URL for DeepSeek OpenAI-compatible API",
+    )
+    extra_body: Any = Field(default=None, description="Extra options for API")
+    model_name_or_path: str = Field(
+        ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'"
+    )
+
+
 class AzureLLMConfig(BaseLLMConfig):
     base_url: str = Field(
         default="https://api.openai.azure.com/",
@@ -89,6 +101,7 @@ class LLMConfigFactory(BaseConfig):
         "vllm": VLLMLLMConfig,
         "huggingface_singleton": HFLLMConfig,  # Add singleton support
         "qwen": QwenLLMConfig,
+        "deepseek": DeepSeekLLMConfig,
     }
 
     @field_validator("backend")
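
Editor's note (not part of the commit): a minimal validation sketch for the new backend entry, assuming the inherited BaseLLMConfig fields (temperature, max_tokens, top_p, remove_think_prefix) declare defaults:

from memos.configs.llm import DeepSeekLLMConfig, LLMConfigFactory

# Direct construction: api_key and model_name_or_path are the only DeepSeek-specific
# required fields; api_base falls back to the declared default.
cfg = DeepSeekLLMConfig(api_key="sk-xxx", model_name_or_path="deepseek-chat")
assert cfg.api_base == "https://api.deepseek.com"

# Through the factory: the "deepseek" backend string now resolves to DeepSeekLLMConfig.
factory_cfg = LLMConfigFactory.model_validate(
    {"backend": "deepseek", "config": {"api_key": "sk-xxx", "model_name_or_path": "deepseek-chat"}}
)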

src/memos/llms/base.py

Lines changed: 9 additions & 0 deletions
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 
 from memos.configs.llm import BaseLLMConfig
 from memos.types import MessageList
@@ -14,3 +15,11 @@ def __init__(self, config: BaseLLMConfig):
     @abstractmethod
     def generate(self, messages: MessageList, **kwargs) -> str:
         """Generate a response from the LLM."""
+
+    @abstractmethod
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """
+        (Optional) Generate a streaming response from the LLM.
+        Subclasses should override this if they support streaming.
+        By default, this raises NotImplementedError.
+        """

src/memos/llms/deepseek.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+from memos.configs.llm import DeepSeekLLMConfig
+from memos.llms.openai import OpenAILLM
+from memos.llms.utils import remove_thinking_tags
+from memos.log import get_logger
+from memos.types import MessageList
+
+
+logger = get_logger(__name__)
+
+
+class DeepSeekLLM(OpenAILLM):
+    """DeepSeek LLM via OpenAI-compatible API."""
+
+    def __init__(self, config: DeepSeekLLMConfig):
+        super().__init__(config)
+
+    def generate(self, messages: MessageList) -> str:
+        """Generate a response from DeepSeek."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+        logger.info(f"Response from DeepSeek: {response.model_dump_json()}")
+        response_content = response.choices[0].message.content
+        if self.config.remove_think_prefix:
+            return remove_thinking_tags(response_content)
+        else:
+            return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs):
+        """Stream response from DeepSeek."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            stream=True,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+        # Streaming chunks of text
+        reasoning_parts = ""
+        answer_parts = ""
+        for chunk in response:
+            delta = chunk.choices[0].delta
+            if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                reasoning_parts += delta.reasoning_content
+                yield delta.reasoning_content
+
+            if hasattr(delta, "content") and delta.content:
+                answer_parts += delta.content
+                yield delta.content
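
Editor's note (not part of the commit): unlike the Qwen streaming path further down, DeepSeekLLM.generate_stream yields reasoning_content and content as plain text without <think> markers, so reasoning and answer arrive concatenated in order. A minimal consumption sketch, assuming a configured DeepSeekLLM instance named llm built as in the example file:

# Print the stream as it arrives and keep the full text for later use.
pieces = []
for chunk in llm.generate_stream([{"role": "user", "content": "Hello, who are you"}]):
    print(chunk, end="", flush=True)
    pieces.append(chunk)
full_text = "".join(pieces)  # reasoning (if any) plus answer, in arrival order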

src/memos/llms/factory.py

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,7 @@
 
 from memos.configs.llm import LLMConfigFactory
 from memos.llms.base import BaseLLM
+from memos.llms.deepseek import DeepSeekLLM
 from memos.llms.hf import HFLLM
 from memos.llms.hf_singleton import HFSingletonLLM
 from memos.llms.ollama import OllamaLLM
@@ -21,6 +22,7 @@ class LLMFactory(BaseLLM):
         "huggingface_singleton": HFSingletonLLM,  # Add singleton version
         "vllm": VLLMLLM,
         "qwen": QwenLLM,
+        "deepseek": DeepSeekLLM,
     }
 
     @classmethod

src/memos/llms/ollama.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from typing import Any
 
 from ollama import Client
@@ -80,3 +81,6 @@ def generate(self, messages: MessageList) -> Any:
             return remove_thinking_tags(str_response)
         else:
             return str_response
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        raise NotImplementedError

src/memos/llms/openai.py

Lines changed: 5 additions & 0 deletions
@@ -1,3 +1,5 @@
+from collections.abc import Generator
+
 import openai
 
 from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig
@@ -61,3 +63,6 @@ def generate(self, messages: MessageList) -> str:
             return remove_thinking_tags(response_content)
         else:
             return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        raise NotImplementedError

src/memos/llms/qwen.py

Lines changed: 31 additions & 0 deletions
@@ -1,3 +1,5 @@
+from collections.abc import Generator
+
 from memos.configs.llm import QwenLLMConfig
 from memos.llms.openai import OpenAILLM
 from memos.llms.utils import remove_thinking_tags
@@ -30,3 +32,32 @@ def generate(self, messages: MessageList) -> str:
             return remove_thinking_tags(response_content)
         else:
             return response_content
+
+    def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]:
+        """Stream response from Qwen LLM."""
+        response = self.client.chat.completions.create(
+            model=self.config.model_name_or_path,
+            messages=messages,
+            stream=True,
+            temperature=self.config.temperature,
+            max_tokens=self.config.max_tokens,
+            top_p=self.config.top_p,
+            extra_body=self.config.extra_body,
+        )
+
+        reasoning_started = False
+        for chunk in response:
+            delta = chunk.choices[0].delta
+
+            # Some models may have separate `reasoning_content` vs `content`
+            # For Qwen (DashScope), likely only `content` is used
+            if hasattr(delta, "reasoning_content") and delta.reasoning_content:
+                if not reasoning_started and not self.config.remove_think_prefix:
+                    yield "<think>"
+                    reasoning_started = True
+                yield delta.reasoning_content
+            elif hasattr(delta, "content") and delta.content:
+                if reasoning_started and not self.config.remove_think_prefix:
+                    yield "</think>"
+                    reasoning_started = False
+                yield delta.content
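
Editor's note (not part of the commit): the Qwen streaming path wraps any reasoning_content in <think>...</think> unless remove_think_prefix is set, so streamed output for a reasoning-capable model roughly looks like "<think>...reasoning...</think>answer". A client-side sketch, assuming a QwenLLM instance llm and a messages list as in the example file, that drops the reasoning block after the fact:

# Join the streamed chunks and strip a leading <think>...</think> block if present.
text = "".join(llm.generate_stream(messages))
if text.startswith("<think>") and "</think>" in text:
    text = text.split("</think>", 1)[1]
print(text)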
