@@ -1,5 +1,5 @@
 import asyncio
-from typing import Optional, Dict, Any
+from typing import Optional, Any, cast

 import torch
 from transformers.cache_utils import DynamicCache
@@ -27,117 +27,63 @@ def __init__(self, config: VLLMLLMConfig):

         # Initialize OpenAI client for API calls
         self.client = None
-        if hasattr(self.config, "api_key") and self.config.api_key:
-            import openai
-            self.client = openai.Client(
-                api_key=self.config.api_key,
-                base_url=getattr(self.config, "api_base", "http://localhost:8088")
-            )
-        else:
-            # Create client without API key for local servers
-            import openai
-            self.client = openai.Client(
-                api_key="dummy",  # vLLM local server doesn't require real API key
-                base_url=getattr(self.config, "api_base", "http://localhost:8088")
-            )
+        api_key = getattr(self.config, "api_key", "dummy")
+        if not api_key:
+            api_key = "dummy"
+
+        import openai
+        self.client = openai.Client(
+            api_key=api_key,
+            base_url=getattr(self.config, "api_base", "http://localhost:8088/v1")
+        )

-    def build_vllm_kv_cache(self, messages) -> str:
+    def build_vllm_kv_cache(self, messages: Any) -> str:
         """
         Build a KV cache from chat messages via one vLLM request.
-        Supports the following input types:
-        - str: Used as a system prompt.
-        - list[str]: Concatenated and used as a system prompt.
-        - list[dict]: Used directly as chat messages.
-        The messages are always converted to a standard chat template.
-        Raises:
-            ValueError: If the resulting prompt is empty after template processing.
-        Returns:
-            str: The constructed prompt string for vLLM KV cache building.
+        Handles str, list[str], and MessageList formats.
         """
-        # Accept multiple input types and convert to standard chat messages
-        if isinstance(messages, str):
-            messages = [
-                {
-                    "role": "system",
-                    "content": f"Below is some information about the user.\n{messages}",
-                }
-            ]
-        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
-            # Handle list of strings
-            str_messages = [str(msg) for msg in messages]
-            messages = [
-                {
-                    "role": "system",
-                    "content": f"Below is some information about the user.\n{' '.join(str_messages)}",
-                }
-            ]
-
-        # Convert messages to prompt string using the same logic as HFLLM
-        # Convert to MessageList format for _messages_to_prompt
+        # 1. Normalize input to a MessageList
+        processed_messages: MessageList = []
         if isinstance(messages, str):
-            message_list = [{"role": "system", "content": messages}]
-        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
-            str_messages = [str(msg) for msg in messages]
-            message_list = [{"role": "system", "content": " ".join(str_messages)}]
-        else:
-            message_list = messages  # Assume it's already in MessageList format
-
-        # Convert to proper MessageList type
-        from memos.types import MessageList
-        typed_message_list: MessageList = []
-        for msg in message_list:
-            if isinstance(msg, dict) and "role" in msg and "content" in msg:
-                typed_message_list.append({
-                    "role": str(msg["role"]),
-                    "content": str(msg["content"])
-                })
-
-        prompt = self._messages_to_prompt(typed_message_list)
+            processed_messages = [{"role": "system", "content": f"Below is some information about the user.\n{messages}"}]
+        elif isinstance(messages, list):
+            if not messages:
+                pass  # Empty list
+            elif isinstance(messages[0], str):
+                str_content = " ".join(str(msg) for msg in messages)
+                processed_messages = [{"role": "system", "content": f"Below is some information about the user.\n{str_content}"}]
+            elif isinstance(messages[0], dict):
+                processed_messages = cast(MessageList, messages)
+
+        # 2. Convert to prompt for logging/return value.
+        prompt = self._messages_to_prompt(processed_messages)

         if not prompt.strip():
-            raise ValueError(
-                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
-            )
+            raise ValueError("Prompt is empty, cannot build KV cache.")

-        # Send a request to vLLM server to preload the KV cache
-        # This is done by sending a completion request with max_tokens=0
-        # which will cause vLLM to process the input but not generate any output
-        if self.client is not None:
-            # Convert messages to OpenAI format
-            openai_messages = []
-            for msg in messages:
-                openai_messages.append({
-                    "role": msg["role"],
-                    "content": msg["content"]
-                })
-
-            # Send prefill request to vLLM
+        # 3. Send request to vLLM server to preload the KV cache
+        if self.client:
             try:
+                # Use the processed messages for the API call
                 prefill_kwargs = {
-                    "model": "default",  # vLLM uses "default" as model name
-                    "messages": openai_messages,
-                    "max_tokens": 2,  # Don't generate any tokens, just prefill
-                    "temperature": 0.0,  # Use deterministic sampling for prefill
+                    "model": self.config.model_name_or_path,
+                    "messages": processed_messages,
+                    "max_tokens": 2,
+                    "temperature": 0.0,
                     "top_p": 1.0,
-                    "top_k": 1,
                 }
-                prefill_response = self.client.chat.completions.create(**prefill_kwargs)
-                logger.info(f"vLLM KV cache prefill completed for prompt length: {len(prompt)}")
+                self.client.chat.completions.create(**prefill_kwargs)
+                logger.info(f"vLLM KV cache prefill completed for prompt: '{prompt[:100]}...'")
             except Exception as e:
                 logger.warning(f"Failed to prefill vLLM KV cache: {e}")
-                # Continue anyway, as this is not critical for functionality

         return prompt

-    def generate(self, messages: MessageList, past_key_values: Optional[DynamicCache] = None) -> str:
+    def generate(self, messages: MessageList) -> str:
         """
         Generate a response from the model.
-        Args:
-            messages (MessageList): Chat messages for prompt construction.
-        Returns:
-            str: Model response.
         """
-        if self.client is not None:
+        if self.client:
             return self._generate_with_api_client(messages)
         else:
             raise RuntimeError("API client is not available")
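
Reviewer note: the prefill path above leans on vLLM's prefix caching — once the server has processed a prompt, later requests sharing that prefix can reuse the cached KV blocks instead of recomputing them. A minimal standalone sketch of the same call, assuming a local vLLM server (e.g. started with --enable-prefix-caching) and placeholder endpoint/model values:

import openai

# Placeholders: the real class reads these from self.config.
client = openai.Client(api_key="dummy", base_url="http://localhost:8088/v1")

client.chat.completions.create(
    model="my-served-model",  # hypothetical served model name
    messages=[{"role": "system", "content": "Below is some information about the user.\n..."}],
    max_tokens=2,  # near-zero generation: the point is the prefill, not the output
    temperature=0.0,
    top_p=1.0,
)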
@@ -146,60 +92,31 @@ def _generate_with_api_client(self, messages: MessageList) -> str:
         """
         Generate response using vLLM API client.
         """
-        # Convert messages to OpenAI format
-        openai_messages = []
-        for msg in messages:
-            openai_messages.append({
-                "role": msg["role"],
-                "content": msg["content"]
-            })
-
-        # Generate response
-        if self.client is not None:
-            # Create completion request with proper parameter types
+        if self.client:
             completion_kwargs = {
-                "model": "default",  # vLLM uses "default" as model name
-                "messages": openai_messages,
+                "model": self.config.model_name_or_path,
+                "messages": messages,
                 "temperature": float(getattr(self.config, "temperature", 0.8)),
                 "max_tokens": int(getattr(self.config, "max_tokens", 1024)),
                 "top_p": float(getattr(self.config, "top_p", 0.9)),
             }

-            # Add top_k only if it's greater than 0
-            top_k = getattr(self.config, "top_k", 50)
-            if top_k > 0:
-                completion_kwargs["top_k"] = int(top_k)
-
             response = self.client.chat.completions.create(**completion_kwargs)
+            response_text = response.choices[0].message.content or ""
+            logger.info(f"VLLM API response: {response_text}")
+            return remove_thinking_tags(response_text) if getattr(self.config, "remove_think_prefix", False) else response_text
         else:
             raise RuntimeError("API client is not available")
-
-        response_text = response.choices[0].message.content or ""
-        logger.info(f"VLLM API response: {response_text}")
-
-        return (
-            remove_thinking_tags(response_text)
-            if getattr(self.config, "remove_think_prefix", False)
-            else response_text
-        )

     def _messages_to_prompt(self, messages: MessageList) -> str:
         """
         Convert messages to prompt string.
         """
-        # Simple conversion - can be enhanced with proper chat template
         prompt_parts = []
         for msg in messages:
             role = msg["role"]
             content = msg["content"]
-
-            if role == "system":
-                prompt_parts.append(f"System: {content}")
-            elif role == "user":
-                prompt_parts.append(f"User: {content}")
-            elif role == "assistant":
-                prompt_parts.append(f"Assistant: {content}")
-
+            prompt_parts.append(f"{role.capitalize()}: {content}")
         return "\n".join(prompt_parts)

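
Reviewer note: the collapsed _messages_to_prompt body relies on str.capitalize() reproducing the labels the old role-specific branches hard-coded. A quick standalone check (a sketch, not part of the commit):

def messages_to_prompt(messages):
    # Same logic as the new method body: one "Role: content" line per message.
    return "\n".join(f"{m['role'].capitalize()}: {m['content']}" for m in messages)

print(messages_to_prompt([
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]))
# System: You are helpful.
# User: Hi
# Assistant: Hello!

One behavioral difference worth flagging: the old version silently dropped any role other than system/user/assistant, while the new one emits a line for every role (e.g. "Tool: ...").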