from typing import Optional

from transformers.cache_utils import DynamicCache

from memos.configs.llm import VLLMLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
from memos.types import MessageList


logger = get_logger(__name__)


class VLLMLLM(BaseLLM):
    """
    LLM backend that connects to an existing vLLM server through its
    OpenAI-compatible API.
    """

    def __init__(self, config: VLLMLLMConfig):
        """
        Initialize the client that talks to an already-running vLLM server.
        """
        self.config = config

        # Initialize the OpenAI-compatible client used for all API calls.
        # A local vLLM server does not validate the API key, so fall back to a
        # placeholder when no real key is configured.
        import openai

        self.client = openai.Client(
            api_key=getattr(self.config, "api_key", None) or "dummy",
            base_url=getattr(self.config, "api_base", "http://localhost:8088"),
        )

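    # Deployment note (illustrative assumption, not enforced by this class):
    # the client above expects an OpenAI-compatible vLLM server to already be
    # listening at `api_base`, e.g. one launched with
    #   vllm serve <model-name-or-path> --port 8088
    # where the model name and port are placeholders that should match your
    # config rather than values this class prescribes.
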
    def build_vllm_kv_cache(self, messages) -> str:
        """
        Build (prefill) the server-side KV cache from chat messages via one vLLM request.
        Supports the following input types:
        - str: Used as a system prompt.
        - list[str]: Concatenated and used as a system prompt.
        - list[dict]: Used directly as chat messages.
        The messages are always converted to a standard chat template.
        Raises:
            ValueError: If the resulting prompt is empty after template processing.
        Returns:
            str: The constructed prompt string used for the KV cache prefill.
        """
        # Accept multiple input types and normalize to standard chat messages
        if isinstance(messages, str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{messages}",
                }
            ]
        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
            # Handle a list of strings
            str_messages = [str(msg) for msg in messages]
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{' '.join(str_messages)}",
                }
            ]

        # At this point `messages` is a list of role/content dicts; coerce it
        # into a well-typed MessageList before building the prompt string.
        typed_message_list: MessageList = []
        for msg in messages:
            if isinstance(msg, dict) and "role" in msg and "content" in msg:
                typed_message_list.append(
                    {"role": str(msg["role"]), "content": str(msg["content"])}
                )

        prompt = self._messages_to_prompt(typed_message_list)

        if not prompt.strip():
            raise ValueError(
                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
            )

        # Send one chat-completion request so the vLLM server processes
        # (prefills) the prompt and populates its cache. max_tokens is kept
        # minimal because the prefill is the point, not the generated tokens;
        # some servers reject max_tokens=0, so a small positive value is used.
        if self.client is not None:
            # Convert messages to the OpenAI chat format
            openai_messages = [
                {"role": msg["role"], "content": msg["content"]} for msg in typed_message_list
            ]

            try:
                prefill_kwargs = {
                    # Prefer the configured model name if the config provides one;
                    # it must match the name the vLLM server serves the model under
                    "model": getattr(self.config, "model_name_or_path", "default"),
                    "messages": openai_messages,
                    "max_tokens": 2,  # minimal generation; the request exists only to prefill
                    "temperature": 0.0,  # deterministic sampling for the prefill request
                    "top_p": 1.0,
                    # top_k is not a standard OpenAI parameter; vLLM accepts it via extra_body
                    "extra_body": {"top_k": 1},
                }
                self.client.chat.completions.create(**prefill_kwargs)
                logger.info(f"vLLM KV cache prefill completed for prompt length: {len(prompt)}")
            except Exception as e:
                logger.warning(f"Failed to prefill vLLM KV cache: {e}")
                # Continue anyway, as this is not critical for functionality

        return prompt

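    # Illustrative call shapes for build_vllm_kv_cache (assumed usage; nothing
    # in this module exercises them):
    #   llm.build_vllm_kv_cache("The user prefers concise answers.")
    #   llm.build_vllm_kv_cache(["fact one", "fact two"])
    #   llm.build_vllm_kv_cache([{"role": "system", "content": "..."}])
    # Each form is normalized to a role/content message list, rendered to a
    # prompt string, and sent once so the server can prefill and cache it.
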
    def generate(self, messages: MessageList, past_key_values: Optional[DynamicCache] = None) -> str:
        """
        Generate a response from the model.
        Args:
            messages (MessageList): Chat messages for prompt construction.
            past_key_values (DynamicCache, optional): Accepted for interface
                compatibility with other backends; unused here because the KV
                cache is held by the vLLM server.
        Returns:
            str: Model response.
        """
        if self.client is None:
            raise RuntimeError("API client is not available")
        return self._generate_with_api_client(messages)

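    # Example call (illustrative): with `llm = VLLMLLM(config)` connected to a
    # running server,
    #   llm.generate([{"role": "user", "content": "Summarize my preferences."}])
    # returns the assistant reply as a plain string.
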
    def _generate_with_api_client(self, messages: MessageList) -> str:
        """
        Generate a response using the vLLM OpenAI-compatible API client.
        """
        if self.client is None:
            raise RuntimeError("API client is not available")

        # Convert messages to the OpenAI chat format
        openai_messages = [
            {"role": msg["role"], "content": msg["content"]} for msg in messages
        ]

        # Create the completion request with properly typed sampling parameters
        completion_kwargs = {
            # Prefer the configured model name if the config provides one;
            # it must match the name the vLLM server serves the model under
            "model": getattr(self.config, "model_name_or_path", "default"),
            "messages": openai_messages,
            "temperature": float(getattr(self.config, "temperature", 0.8)),
            "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
            "top_p": float(getattr(self.config, "top_p", 0.9)),
        }

        # top_k is a vLLM-specific sampling parameter, so it is passed through
        # extra_body rather than as a top-level OpenAI argument.
        top_k = getattr(self.config, "top_k", 50)
        if top_k > 0:
            completion_kwargs["extra_body"] = {"top_k": int(top_k)}

        response = self.client.chat.completions.create(**completion_kwargs)

        response_text = response.choices[0].message.content or ""
        logger.info(f"VLLM API response: {response_text}")

        return (
            remove_thinking_tags(response_text)
            if getattr(self.config, "remove_think_prefix", False)
            else response_text
        )

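    # Design note: standard OpenAI parameters (temperature, top_p, max_tokens)
    # are sent as top-level arguments, while vLLM-only knobs such as top_k ride
    # along in extra_body, e.g. (illustrative request shape)
    #   {"model": "...", "messages": [...], "top_p": 0.9, "extra_body": {"top_k": 50}}
    # so the same client code also works against other OpenAI-compatible endpoints.
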
    def _messages_to_prompt(self, messages: MessageList) -> str:
        """
        Convert chat messages to a flat prompt string.
        Messages with roles other than system/user/assistant are ignored.
        """
        # Simple conversion - could be replaced with the model's chat template
        prompt_parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        return "\n".join(prompt_parts)
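
    # For example (illustrative), [{"role": "user", "content": "Hi"},
    # {"role": "assistant", "content": "Hello"}] renders as:
    #   User: Hi
    #   Assistant: Hello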
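

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The VLLMLLMConfig fields shown are
# assumptions inferred from the getattr() fallbacks above, not a confirmed
# schema, and the server address is a placeholder.
#
#   config = VLLMLLMConfig(api_base="http://localhost:8088")
#   llm = VLLMLLM(config)
#   llm.build_vllm_kv_cache("The user's name is Alice and she lives in Paris.")
#   reply = llm.generate([{"role": "user", "content": "Where does the user live?"}])
# ---------------------------------------------------------------------------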