
Commit 036a99f

feat: vLLM LLM support, version 0
1 parent 97fdb06 commit 036a99f

3 files changed: +292 −0 lines changed
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
"""
Simple example demonstrating how to use VLLMLLM with an existing vLLM server.
Requires a vLLM server to be running on localhost:8088.
"""

from memos.configs.llm import VLLMLLMConfig
from memos.llms.vllm import VLLMLLM
from memos.types import MessageList


def main():
    """Main function demonstrating VLLMLLM usage."""
    # Configuration for connecting to an existing vLLM server
    config = VLLMLLMConfig(
        model_name_or_path="Qwen/Qwen3-1.7B",  # Model name (for reference)
        api_key="",  # Not needed for a local server
        api_base="http://localhost:8088",  # vLLM server address
        temperature=0.7,
        max_tokens=512,
        top_p=0.9,
        top_k=50,
        model_schema="memos.configs.llm.VLLMLLMConfig",
    )

    # Initialize the vLLM-backed LLM
    print("Initializing VLLM LLM...")
    llm = VLLMLLM(config)

    # Test messages for KV cache building
    system_messages: MessageList = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Hello! Can you tell me about vLLM?"},
    ]

    # Build KV cache for the system messages
    print("Building KV cache for system messages...")
    try:
        prompt = llm.build_vllm_kv_cache(system_messages)
        print(f"✓ KV cache built successfully. Prompt length: {len(prompt)}")
    except Exception as e:
        print(f"✗ Failed to build KV cache: {e}")

    # Test with different messages
    user_messages: MessageList = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "What are the benefits of using vLLM?"},
    ]

    # Generate a response
    print("\nGenerating response...")
    try:
        response = llm.generate(user_messages)
        print(f"Response: {response}")
    except Exception as e:
        print(f"Error generating response: {e}")

    # Test KV cache with a string input
    print("\nTesting KV cache with string input...")
    try:
        string_prompt = llm.build_vllm_kv_cache("You are a helpful assistant.")
        print(f"✓ String KV cache built successfully. Prompt length: {len(string_prompt)}")
    except Exception as e:
        print(f"✗ Failed to build string KV cache: {e}")

    # Test KV cache with a list of strings
    print("\nTesting KV cache with list of strings input...")
    try:
        list_prompt = llm.build_vllm_kv_cache(["You are helpful.", "You are knowledgeable."])
        print(f"✓ List KV cache built successfully. Prompt length: {len(list_prompt)}")
    except Exception as e:
        print(f"✗ Failed to build list KV cache: {e}")


if __name__ == "__main__":
    main()
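
The example assumes an OpenAI-compatible vLLM server is already listening on localhost:8088. A minimal pre-flight check along these lines can confirm the server is reachable before running main(); the /v1 base path and the "dummy" key are assumptions about how such a server is typically exposed, not part of the commit:

import openai

# Hypothetical reachability check; adjust base_url to match your vLLM deployment.
client = openai.Client(api_key="dummy", base_url="http://localhost:8088/v1")
try:
    served = [m.id for m in client.models.list().data]
    print(f"vLLM server reachable, serving: {served}")
except Exception as exc:
    print(f"vLLM server not reachable on localhost:8088: {exc}")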

src/memos/configs/llm.py

Lines changed: 7 additions & 0 deletions
@@ -43,6 +43,12 @@ class HFLLMConfig(BaseLLMConfig):
        description="Apply generation template for the conversation",
    )

class VLLMLLMConfig(BaseLLMConfig):
    api_key: str = Field(default="", description="API key for vLLM (optional for local server)")
    api_base: str = Field(
        default="http://localhost:8088",
        description="Base URL for vLLM API",
    )

class LLMConfigFactory(BaseConfig):
    """Factory class for creating LLM configurations."""

@@ -54,6 +60,7 @@ class LLMConfigFactory(BaseConfig):
        "openai": OpenAILLMConfig,
        "ollama": OllamaLLMConfig,
        "huggingface": HFLLMConfig,
        "vllm": VLLMLLMConfig,
    }

    @field_validator("backend")
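
For reference, the new backend registers under the "vllm" key in the factory above, and the config's defaults target a local server. A minimal construction sketch, assuming the remaining BaseLLMConfig fields (model_name_or_path, temperature, and so on) either have defaults or are set as in the example:

from memos.configs.llm import VLLMLLMConfig

cfg = VLLMLLMConfig(model_name_or_path="Qwen/Qwen3-1.7B")
print(cfg.api_base)  # "http://localhost:8088" unless overridden
print(cfg.api_key)   # "" — a real key is optional for a local vLLM server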

src/memos/llms/vllm.py

Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
from typing import Optional

from transformers.cache_utils import DynamicCache

from memos.configs.llm import VLLMLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
from memos.types import MessageList


logger = get_logger(__name__)


class VLLMLLM(BaseLLM):
    """
    VLLM LLM class for connecting to an existing vLLM server.
    """

    def __init__(self, config: VLLMLLMConfig):
        """
        Initialize the VLLM LLM to connect to an existing vLLM server.
        """
        import openai

        self.config = config

        # Initialize an OpenAI-compatible client for API calls.
        # A local vLLM server does not require a real API key, so fall back to a dummy one.
        # Note: api_base should point at the server's OpenAI-compatible root
        # (vLLM usually serves it under /v1).
        self.client = openai.Client(
            api_key=self.config.api_key or "dummy",
            base_url=getattr(self.config, "api_base", "http://localhost:8088"),
        )

    def build_vllm_kv_cache(self, messages) -> str:
        """
        Build a KV cache from chat messages via one vLLM request.

        Supports the following input types:
        - str: Used as a system prompt.
        - list[str]: Concatenated and used as a system prompt.
        - list[dict]: Used directly as chat messages.

        The messages are always converted to a standard chat template.

        Raises:
            ValueError: If the resulting prompt is empty after template processing.

        Returns:
            str: The constructed prompt string for vLLM KV cache building.
        """
        # Accept multiple input types and convert to standard chat messages.
        if isinstance(messages, str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{messages}",
                }
            ]
        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
            # Handle a list of strings.
            str_messages = [str(msg) for msg in messages]
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{' '.join(str_messages)}",
                }
            ]

        # Normalize to a typed MessageList for prompt construction.
        typed_message_list: MessageList = [
            {"role": str(msg["role"]), "content": str(msg["content"])}
            for msg in messages
            if isinstance(msg, dict) and "role" in msg and "content" in msg
        ]

        prompt = self._messages_to_prompt(typed_message_list)

        if not prompt.strip():
            raise ValueError(
                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
            )

        # Send a request to the vLLM server so it processes (prefills) the prompt.
        # max_tokens is kept tiny: the goal is to populate the server-side cache,
        # not to generate meaningful output.
        if self.client is not None:
            openai_messages = [
                {"role": msg["role"], "content": msg["content"]} for msg in typed_message_list
            ]

            try:
                prefill_kwargs = {
                    "model": self.config.model_name_or_path,  # Must match the model served by vLLM
                    "messages": openai_messages,
                    "max_tokens": 2,
                    "temperature": 0.0,  # Deterministic sampling for the prefill request
                    "top_p": 1.0,
                    # top_k is a vLLM extension, so pass it via extra_body.
                    "extra_body": {"top_k": 1},
                }
                self.client.chat.completions.create(**prefill_kwargs)
                logger.info(f"vLLM KV cache prefill completed for prompt length: {len(prompt)}")
            except Exception as e:
                logger.warning(f"Failed to prefill vLLM KV cache: {e}")
                # Continue anyway, as this is not critical for functionality.

        return prompt

    def generate(self, messages: MessageList, past_key_values: Optional[DynamicCache] = None) -> str:
        """
        Generate a response from the model.

        Args:
            messages (MessageList): Chat messages for prompt construction.
            past_key_values: Accepted for interface compatibility; not used by the API-backed client.

        Returns:
            str: Model response.
        """
        if self.client is not None:
            return self._generate_with_api_client(messages)
        raise RuntimeError("API client is not available")

    def _generate_with_api_client(self, messages: MessageList) -> str:
        """
        Generate a response using the vLLM API client.
        """
        # Convert messages to the OpenAI chat format.
        openai_messages = [{"role": msg["role"], "content": msg["content"]} for msg in messages]

        # Create the completion request with proper parameter types.
        completion_kwargs = {
            "model": self.config.model_name_or_path,  # Must match the model served by vLLM
            "messages": openai_messages,
            "temperature": float(getattr(self.config, "temperature", 0.8)),
            "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
            "top_p": float(getattr(self.config, "top_p", 0.9)),
        }

        # top_k is a vLLM-specific sampling parameter, so pass it via extra_body.
        top_k = getattr(self.config, "top_k", 50)
        if top_k > 0:
            completion_kwargs["extra_body"] = {"top_k": int(top_k)}

        response = self.client.chat.completions.create(**completion_kwargs)

        response_text = response.choices[0].message.content or ""
        logger.info(f"VLLM API response: {response_text}")

        return (
            remove_thinking_tags(response_text)
            if getattr(self.config, "remove_think_prefix", False)
            else response_text
        )

    def _messages_to_prompt(self, messages: MessageList) -> str:
        """
        Convert messages to a prompt string.
        """
        # Simple conversion - can be enhanced with a proper chat template.
        prompt_parts = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]

            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")

        return "\n".join(prompt_parts)
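
The prefill request in build_vllm_kv_cache only pays off when the server can reuse cached prompt prefixes (for example, a vLLM server with prefix caching enabled). A minimal sketch of the intended flow, reusing the model and address from the example above; whether a later request actually skips recomputing the prefix depends on how the server is configured:

from memos.configs.llm import VLLMLLMConfig
from memos.llms.vllm import VLLMLLM
from memos.types import MessageList

config = VLLMLLMConfig(
    model_name_or_path="Qwen/Qwen3-1.7B",
    api_base="http://localhost:8088",
    model_schema="memos.configs.llm.VLLMLLMConfig",
)
llm = VLLMLLM(config)

# Warm the server with the shared system prefix ...
prefix: MessageList = [{"role": "system", "content": "You are a helpful AI assistant."}]
llm.build_vllm_kv_cache(prefix)

# ... then later requests starting with the same prefix can reuse the cached computation.
reply = llm.generate(prefix + [{"role": "user", "content": "What are the benefits of using vLLM?"}])
print(reply)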
