|
| 1 | +from memos.configs.llm import DeepSeekLLMConfig |
| 2 | +from memos.llms.openai import OpenAILLM |
| 3 | +from memos.llms.utils import remove_thinking_tags |
| 4 | +from memos.log import get_logger |
| 5 | +from memos.types import MessageList |
| 6 | + |
| 7 | + |
# Module-level logger, named after this module per standard logging practice.
logger = get_logger(__name__)
| 9 | + |
| 10 | + |
class DeepSeekLLM(OpenAILLM):
    """DeepSeek LLM adapter over the OpenAI-compatible chat-completions API.

    Client construction and shared behavior come from ``OpenAILLM``; this
    subclass only supplies DeepSeek-specific request parameters and response
    handling (``reasoning_content`` streaming, optional think-tag stripping).
    """

    def __init__(self, config: DeepSeekLLMConfig):
        """Initialize from a DeepSeek config; delegates setup to OpenAILLM."""
        super().__init__(config)

    def generate(self, messages: MessageList) -> str:
        """Generate a complete (non-streaming) response from DeepSeek.

        Args:
            messages: Chat messages in OpenAI chat-completions format.

        Returns:
            The assistant message content. When ``config.remove_think_prefix``
            is set, thinking tags are stripped via ``remove_thinking_tags``.
        """
        response = self.client.chat.completions.create(
            model=self.config.model_name_or_path,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            extra_body=self.config.extra_body,
        )
        logger.info(f"Response from DeepSeek: {response.model_dump_json()}")
        # content can be None (e.g. tool-call-only replies); normalize to ""
        # so the declared `-> str` contract holds and tag-stripping is safe.
        response_content = response.choices[0].message.content or ""
        if self.config.remove_think_prefix:
            return remove_thinking_tags(response_content)
        return response_content

    def generate_stream(self, messages: MessageList, **kwargs):
        """Stream a response from DeepSeek, yielding text deltas as they arrive.

        Reasoning deltas (DeepSeek's ``reasoning_content``) are yielded as
        they stream, interleaved with regular ``content`` deltas. ``**kwargs``
        is accepted for interface compatibility and currently unused.

        Yields:
            str: Incremental chunks of reasoning and answer text.
        """
        response = self.client.chat.completions.create(
            model=self.config.model_name_or_path,
            messages=messages,
            stream=True,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            top_p=self.config.top_p,
            extra_body=self.config.extra_body,
        )
        for chunk in response:
            # OpenAI-compatible streams may emit chunks with an empty
            # `choices` list (e.g. a trailing usage-only chunk) — skip them
            # instead of raising IndexError.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta
            if getattr(delta, "reasoning_content", None):
                yield delta.reasoning_content
            if getattr(delta, "content", None):
                yield delta.content
0 commit comments