Commit 54c99fb

feat:vllm llm support version0.5 (#68)
## Description

Summary: (summary)

Fix: #(issue)

Reviewer: @(reviewer)

## Checklist:

- [ ] I have performed a self-review of my own code
- [ ] I have commented my code in hard-to-understand areas
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have added necessary documentation (if applicable)
- [ ] I have linked the issue to this PR (if applicable)
- [ ] I have mentioned the person who will review this PR
2 parents a51e1a3 + 370623b commit 54c99fb

3 files changed: +54 −163 lines
Lines changed: 8 additions & 34 deletions
@@ -1,29 +1,24 @@
 #!/usr/bin/env python3
 """
-Simple example demonstrating how to use VLLMLLM with existing vLLM server.
-Requires a vLLM server to be running on localhost:8088.
+Simple example demonstrating how to use VLLMLLM with an existing vLLM server.
+Requires a vLLM server to be running.
 """

-import asyncio
-import sys
-
 from memos.configs.llm import VLLMLLMConfig
 from memos.llms.vllm import VLLMLLM
 from memos.types import MessageList

-
 def main():
     """Main function demonstrating VLLMLLM usage."""

     # Configuration for connecting to existing vLLM server
     config = VLLMLLMConfig(
-        model_name_or_path="Qwen/Qwen3-1.7B",  # Model name (for reference)
+        model_name_or_path="/mnt/afs/models/hf_models/Qwen2.5-7B",  # MUST MATCH the --model arg of vLLM server
         api_key="",  # Not needed for local server
-        api_base="http://localhost:8088",  # vLLM server address
+        api_base="http://localhost:8088/v1",  # vLLM server address with /v1
         temperature=0.7,
         max_tokens=512,
         top_p=0.9,
-        top_k=50,
         model_schema="memos.configs.llm.VLLMLLMConfig",
     )

@@ -32,49 +27,28 @@ def main():
     llm = VLLMLLM(config)

     # Test messages for KV cache building
+    print("\nBuilding KV cache for system messages...")
     system_messages: MessageList = [
         {"role": "system", "content": "You are a helpful AI assistant."},
         {"role": "user", "content": "Hello! Can you tell me about vLLM?"}
     ]
-
-    # Build KV cache for system messages
-    print("Building KV cache for system messages...")
     try:
         prompt = llm.build_vllm_kv_cache(system_messages)
-        print(f"✓ KV cache built successfully. Prompt length: {len(prompt)}")
+        print(f"✓ KV cache built successfully for prompt: '{prompt[:100]}...'")
     except Exception as e:
         print(f"✗ Failed to build KV cache: {e}")

-    # Test with different messages
+    # Test with different messages for generation
+    print("\nGenerating response...")
     user_messages: MessageList = [
         {"role": "system", "content": "You are a helpful AI assistant."},
         {"role": "user", "content": "What are the benefits of using vLLM?"}
     ]
-
-    # Generate response
-    print("\nGenerating response...")
     try:
         response = llm.generate(user_messages)
         print(f"Response: {response}")
     except Exception as e:
         print(f"Error generating response: {e}")
-
-    # Test with string input for KV cache
-    print("\nTesting KV cache with string input...")
-    try:
-        string_prompt = llm.build_vllm_kv_cache("You are a helpful assistant.")
-        print(f"✓ String KV cache built successfully. Prompt length: {len(string_prompt)}")
-    except Exception as e:
-        print(f"✗ Failed to build string KV cache: {e}")
-
-    # Test with list of strings input for KV cache
-    print("\nTesting KV cache with list of strings input...")
-    try:
-        list_prompt = llm.build_vllm_kv_cache(["You are helpful.", "You are knowledgeable."])
-        print(f"✓ List KV cache built successfully. Prompt length: {len(list_prompt)}")
-    except Exception as e:
-        print(f"✗ Failed to build list KV cache: {e}")
-

 if __name__ == "__main__":
     main()
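The slimmed-down example only exercises MessageList input, but build_vllm_kv_cache in src/memos/llms/vllm.py (diffed below) still accepts plain strings and lists of strings. A minimal usage sketch, reusing the `llm` object from the example above; the first string is illustrative content, not part of this commit:

# Sketch only: `llm` is the VLLMLLM instance configured in the example above.
# str and list[str] inputs are wrapped into a single system message prefixed with
# "Below is some information about the user."; a MessageList is passed through as-is.
prompt_from_str = llm.build_vllm_kv_cache("The user prefers concise answers.")
prompt_from_list = llm.build_vllm_kv_cache(["You are helpful.", "You are knowledgeable."])
prompt_from_messages = llm.build_vllm_kv_cache(
    [{"role": "system", "content": "You are a helpful AI assistant."}]
)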

src/memos/configs/llm.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class HFLLMConfig(BaseLLMConfig):
 class VLLMLLMConfig(BaseLLMConfig):
     api_key: str = Field(default="", description="API key for vLLM (optional for local server)")
     api_base: str = Field(
-        default="http://localhost:8088",
+        default="http://localhost:8088/v1",
         description="Base URL for vLLM API",
     )
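For context on why the default gained the /v1 suffix: vLLM's OpenAI-compatible server exposes its chat route under /v1, and the OpenAI client appends /chat/completions to whatever base_url it is given. A minimal sketch of the client construction this default is meant to support (port and dummy key mirror the defaults in this commit):

import openai

# Sketch: mirrors how VLLMLLM builds its client (see src/memos/llms/vllm.py below).
# Without the /v1 prefix the request would miss vLLM's /v1/chat/completions route.
client = openai.Client(
    api_key="dummy",  # a local vLLM server does not validate the key
    base_url="http://localhost:8088/v1",
)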

src/memos/llms/vllm.py

Lines changed: 45 additions & 128 deletions
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Optional, Dict, Any
+from typing import Optional, Any, cast

 import torch
 from transformers.cache_utils import DynamicCache
@@ -27,117 +27,63 @@ def __init__(self, config: VLLMLLMConfig):

         # Initialize OpenAI client for API calls
         self.client = None
-        if hasattr(self.config, "api_key") and self.config.api_key:
-            import openai
-            self.client = openai.Client(
-                api_key=self.config.api_key,
-                base_url=getattr(self.config, "api_base", "http://localhost:8088")
-            )
-        else:
-            # Create client without API key for local servers
-            import openai
-            self.client = openai.Client(
-                api_key="dummy",  # vLLM local server doesn't require real API key
-                base_url=getattr(self.config, "api_base", "http://localhost:8088")
-            )
+        api_key = getattr(self.config, "api_key", "dummy")
+        if not api_key:
+            api_key = "dummy"
+
+        import openai
+        self.client = openai.Client(
+            api_key=api_key,
+            base_url=getattr(self.config, "api_base", "http://localhost:8088/v1")
+        )

-    def build_vllm_kv_cache(self, messages) -> str:
+    def build_vllm_kv_cache(self, messages: Any) -> str:
         """
         Build a KV cache from chat messages via one vLLM request.
-        Supports the following input types:
-        - str: Used as a system prompt.
-        - list[str]: Concatenated and used as a system prompt.
-        - list[dict]: Used directly as chat messages.
-        The messages are always converted to a standard chat template.
-        Raises:
-            ValueError: If the resulting prompt is empty after template processing.
-        Returns:
-            str: The constructed prompt string for vLLM KV cache building.
+        Handles str, list[str], and MessageList formats.
         """
-        # Accept multiple input types and convert to standard chat messages
-        if isinstance(messages, str):
-            messages = [
-                {
-                    "role": "system",
-                    "content": f"Below is some information about the user.\n{messages}",
-                }
-            ]
-        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
-            # Handle list of strings
-            str_messages = [str(msg) for msg in messages]
-            messages = [
-                {
-                    "role": "system",
-                    "content": f"Below is some information about the user.\n{' '.join(str_messages)}",
-                }
-            ]
-
-        # Convert messages to prompt string using the same logic as HFLLM
-        # Convert to MessageList format for _messages_to_prompt
+        # 1. Normalize input to a MessageList
+        processed_messages: MessageList = []
         if isinstance(messages, str):
-            message_list = [{"role": "system", "content": messages}]
-        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
-            str_messages = [str(msg) for msg in messages]
-            message_list = [{"role": "system", "content": " ".join(str_messages)}]
-        else:
-            message_list = messages  # Assume it's already in MessageList format
-
-        # Convert to proper MessageList type
-        from memos.types import MessageList
-        typed_message_list: MessageList = []
-        for msg in message_list:
-            if isinstance(msg, dict) and "role" in msg and "content" in msg:
-                typed_message_list.append({
-                    "role": str(msg["role"]),
-                    "content": str(msg["content"])
-                })
-
-        prompt = self._messages_to_prompt(typed_message_list)
+            processed_messages = [{"role": "system", "content": f"Below is some information about the user.\n{messages}"}]
+        elif isinstance(messages, list):
+            if not messages:
+                pass  # Empty list
+            elif isinstance(messages[0], str):
+                str_content = " ".join(str(msg) for msg in messages)
+                processed_messages = [{"role": "system", "content": f"Below is some information about the user.\n{str_content}"}]
+            elif isinstance(messages[0], dict):
+                processed_messages = cast(MessageList, messages)
+
+        # 2. Convert to prompt for logging/return value.
+        prompt = self._messages_to_prompt(processed_messages)

         if not prompt.strip():
-            raise ValueError(
-                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
-            )
+            raise ValueError("Prompt is empty, cannot build KV cache.")

-        # Send a request to vLLM server to preload the KV cache
-        # This is done by sending a completion request with max_tokens=0
-        # which will cause vLLM to process the input but not generate any output
-        if self.client is not None:
-            # Convert messages to OpenAI format
-            openai_messages = []
-            for msg in messages:
-                openai_messages.append({
-                    "role": msg["role"],
-                    "content": msg["content"]
-                })
-
-            # Send prefill request to vLLM
+        # 3. Send request to vLLM server to preload the KV cache
+        if self.client:
             try:
+                # Use the processed messages for the API call
                 prefill_kwargs = {
-                    "model": "default",  # vLLM uses "default" as model name
-                    "messages": openai_messages,
-                    "max_tokens": 2,  # Don't generate any tokens, just prefill
-                    "temperature": 0.0,  # Use deterministic sampling for prefill
+                    "model": self.config.model_name_or_path,
+                    "messages": processed_messages,
+                    "max_tokens": 2,
+                    "temperature": 0.0,
                     "top_p": 1.0,
-                    "top_k": 1,
                 }
-                prefill_response = self.client.chat.completions.create(**prefill_kwargs)
-                logger.info(f"vLLM KV cache prefill completed for prompt length: {len(prompt)}")
+                self.client.chat.completions.create(**prefill_kwargs)
+                logger.info(f"vLLM KV cache prefill completed for prompt: '{prompt[:100]}...'")
             except Exception as e:
                 logger.warning(f"Failed to prefill vLLM KV cache: {e}")
-                # Continue anyway, as this is not critical for functionality

         return prompt

-    def generate(self, messages: MessageList, past_key_values: Optional[DynamicCache] = None) -> str:
+    def generate(self, messages: MessageList) -> str:
         """
         Generate a response from the model.
-        Args:
-            messages (MessageList): Chat messages for prompt construction.
-        Returns:
-            str: Model response.
         """
-        if self.client is not None:
+        if self.client:
             return self._generate_with_api_client(messages)
         else:
             raise RuntimeError("API client is not available")
@@ -146,60 +92,31 @@ def _generate_with_api_client(self, messages: MessageList) -> str:
         """
         Generate response using vLLM API client.
         """
-        # Convert messages to OpenAI format
-        openai_messages = []
-        for msg in messages:
-            openai_messages.append({
-                "role": msg["role"],
-                "content": msg["content"]
-            })
-
-        # Generate response
-        if self.client is not None:
-            # Create completion request with proper parameter types
+        if self.client:
             completion_kwargs = {
-                "model": "default",  # vLLM uses "default" as model name
-                "messages": openai_messages,
+                "model": self.config.model_name_or_path,
+                "messages": messages,
                 "temperature": float(getattr(self.config, "temperature", 0.8)),
                 "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
                 "top_p": float(getattr(self.config, "top_p", 0.9)),
             }

-            # Add top_k only if it's greater than 0
-            top_k = getattr(self.config, "top_k", 50)
-            if top_k > 0:
-                completion_kwargs["top_k"] = int(top_k)
-
             response = self.client.chat.completions.create(**completion_kwargs)
+            response_text = response.choices[0].message.content or ""
+            logger.info(f"VLLM API response: {response_text}")
+            return remove_thinking_tags(response_text) if getattr(self.config, "remove_think_prefix", False) else response_text
         else:
             raise RuntimeError("API client is not available")
-
-        response_text = response.choices[0].message.content or ""
-        logger.info(f"VLLM API response: {response_text}")
-
-        return (
-            remove_thinking_tags(response_text)
-            if getattr(self.config, "remove_think_prefix", False)
-            else response_text
-        )

     def _messages_to_prompt(self, messages: MessageList) -> str:
         """
         Convert messages to prompt string.
         """
-        # Simple conversion - can be enhanced with proper chat template
         prompt_parts = []
         for msg in messages:
             role = msg["role"]
             content = msg["content"]
-
-            if role == "system":
-                prompt_parts.append(f"System: {content}")
-            elif role == "user":
-                prompt_parts.append(f"User: {content}")
-            elif role == "assistant":
-                prompt_parts.append(f"Assistant: {content}")
-
+            prompt_parts.append(f"{role.capitalize()}: {content}")
         return "\n".join(prompt_parts)
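As the diff above shows, build_vllm_kv_cache does not return a cache object; it issues one chat-completion call with max_tokens=2 so the server processes the prompt ahead of the real generation request, then returns the prompt string. A standalone sketch of that request shape, with the model path, port, and message content taken as placeholders from the example file in this commit:

import openai

# Sketch of the warm-up ("prefill") request issued inside build_vllm_kv_cache.
# The model value must match the --model the vLLM server was launched with.
client = openai.Client(api_key="dummy", base_url="http://localhost:8088/v1")
client.chat.completions.create(
    model="/mnt/afs/models/hf_models/Qwen2.5-7B",  # placeholder, from the example above
    messages=[{"role": "system", "content": "You are a helpful AI assistant."}],
    max_tokens=2,      # generate next to nothing; the point is processing the prompt
    temperature=0.0,
    top_p=1.0,
)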