 import torch
 import logging
 import warnings
-from contextlib import contextmanager

 from genlm.backend.llm.base import AsyncLM
 from genlm.backend.cache import OutputCache
     from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
     from vllm.utils import Counter
     from vllm.inputs import TokensPrompt
-    from vllm.model_executor.layers.sampler import SamplerOutput
-    from vllm.sequence import SequenceOutput, CompletionSequenceGroupOutput, Logprob

     from vllm.distributed.parallel_state import (
         destroy_model_parallel,
@@ -43,16 +40,27 @@ def from_name(cls, *args, **kwargs): # pragma: no cover
 else:
     logging.getLogger("vllm.engine.async_llm_engine").setLevel(logging.WARNING)

-class AsyncVirtualLM(AsyncLM):
-    """A wrapper around vLLM's `AsyncLLMEngine` for asynchronous next token log probability computations.
+class PassThroughLogitsProcessor:
+    """A logits processor that stores the logprobs and passes the logits through."""
+
+    def __init__(self):
+        self.log_probs = None

-    This class provides an asynchronous interface for computing log probabilities using vLLM's engine.
-    It is optimized for next token log probability computations and supports caching of results (outputs and KV).
-    """
+    def __call__(self, past_token_ids, logits):
+        assert self.log_probs is None, (
+            "Log probs already set. This should never happen."
+        )
+        self.log_probs = torch.log_softmax(logits, dim=-1, dtype=logits.dtype)
+        return logits

-    default_params = SamplingParams(
-        max_tokens=1, n=1, logprobs=1, detokenize=False, stop=None, ignore_eos=True
-    )
+class AsyncVirtualLM(AsyncLM):
+    default_params = {
+        "max_tokens": 1,
+        "n": 1,
+        "detokenize": False,
+        "stop": None,
+        "ignore_eos": True,
+    }
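+    # `logprobs` is not requested in the sampling params; the full next-token
+    # distribution is captured by a per-request PassThroughLogitsProcessor
+    # attached at call time.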

     def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
         """Initialize an `AsyncVirtualLM` instance.
@@ -68,8 +76,6 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
         self.async_llm_engine = async_llm_engine
         self.tokenizer = async_llm_engine.engine.get_tokenizer()
         self.request_counter = Counter()
-        self.custom_sampler = DeferredSampler()
-        self.original_sampler = self.underlying_model.sampler
         self.cache = (
             OutputCache(maxsize=cache_size, **cache_opts)
             if cache_size > 0
@@ -108,10 +114,7 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
         engine_opts = {
             "enable_prefix_caching": True,
             "disable_log_requests": True,
-            "disable_async_output_proc": True,
-            # Need to disable chunked prefill to avoid issues
-            # with our custom sampler.
-            "enable_chunked_prefill": False,
+            # This forces vLLM to use the v0 engine, which is currently what we want.
+            "disable_async_output_proc": True,
             **(engine_opts or {}),
         }

@@ -163,16 +166,21 @@ async def _next_token_logprobs(self, token_ids):
         prompt = TokensPrompt(prompt_token_ids=token_ids)

         outputs = []
-        with self._temporarily_set_sampler(self.custom_sampler):
-            async for output in self.async_llm_engine.generate(
-                prompt=prompt,
-                sampling_params=self.default_params,
-                request_id=req_id,
-            ):
-                if output.finished:
-                    outputs.append(output)
-
-        return self._validate_outputs(outputs)
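+        # A fresh processor is attached to this request only; vLLM applies per-request
+        # logits processors solely to the request they belong to, so concurrent
+        # requests do not interfere with one another's captured distributions.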
+        processor = PassThroughLogitsProcessor()
+        async for output in self.async_llm_engine.generate(
+            prompt=prompt,
+            sampling_params=SamplingParams(
+                **self.default_params, logits_processors=[processor]
+            ),
+            request_id=req_id,
+        ):
+            if output.finished:
+                outputs.append(output)
+
+        assert processor.log_probs is not None, (
+            "Log probs should be set by the logits processor."
+        )
+        return processor.log_probs

     def next_token_logprobs_sync(self, token_ids):
         """Request log probabilities of next token synchronously.
@@ -196,69 +204,31 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
             (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
         """
         req_ids = []
+        req_id2processors = {}
         for token_ids in token_ids_list:
             req_id = str(next(self.request_counter))
             req_ids.append(req_id)
+            processor = PassThroughLogitsProcessor()
+            req_id2processors[req_id] = processor
             self.async_llm_engine.engine.add_request(
                 prompt=TokensPrompt(prompt_token_ids=token_ids),
-                params=self.default_params,
+                params=SamplingParams(
+                    **self.default_params, logits_processors=[processor]
+                ),
                 request_id=req_id,
             )

-        req_id2outputs = {}
-        with self._temporarily_set_sampler(self.custom_sampler):
-            while self.async_llm_engine.engine.has_unfinished_requests():
-                output = self.async_llm_engine.engine.step()
-                for out in output:
-                    if out.finished:
-                        assert out.request_id not in req_id2outputs, (
-                            f"Duplicate outputs for request {out.request_id}"
-                        )
-                        assert out.request_id in req_ids, (
-                            f"{out.request_id} not in requested IDs"
-                        )
-                        req_id2outputs[out.request_id] = out
-
-        logprobs = [
-            self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
-        ]
-
-        return torch.stack(logprobs)
-
-    @contextmanager
-    def _temporarily_set_sampler(self, sampler):
-        """Context manager for temporarily setting a custom sampler."""
-        original_sampler = self.underlying_model.sampler
-        try:
-            self.underlying_model.sampler = sampler
-            yield
-        finally:
-            self.underlying_model.sampler = original_sampler
-
-    def _validate_outputs(self, outputs):
-        """Validate and extract logprobs from a vLLM output.
-
-        Args:
-            outputs: List of sequence group outputs from vLLM generation
-
-        Returns:
-            Tensor of log probabilities for the next token
-
-        Raises:
-            AssertionError: If output structure doesn't match expected format
-        """
-        assert len(outputs) == 1, "Expected exactly one sequence group"
-        seq_group = outputs[0]
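+        # Drive the synchronous engine until every queued request has finished, so that
+        # each request's logits processor has recorded its next-token distribution.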
+        while self.async_llm_engine.engine.has_unfinished_requests():
+            output = self.async_llm_engine.engine.step()
+            for out in output:
+                if out.finished:
+                    assert out.request_id in req_id2processors, (
+                        f"{out.request_id} not in requested IDs"
+                    )

-        assert len(seq_group.outputs) == 1, (
-            "Expected exactly one sequence in output"
+        return torch.stack(
+            [req_id2processors[req_id].log_probs for req_id in req_ids]
         )
-        sequence = seq_group.outputs[0]
-
-        assert len(sequence.logprobs) == 1, "Expected exactly one set of logprobs"
-        token_logprobs = sequence.logprobs[0].logprobs
-
-        return token_logprobs

     def clear_cache(self):
         """Clear output cache."""
@@ -296,141 +266,22 @@ async def sample(
         Returns:
             (list[int]): The sampled token IDs.
         """
-        with self._temporarily_set_sampler(self.original_sampler):
-            async for output in self.async_llm_engine.generate(
-                prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
-                sampling_params=SamplingParams(
-                    n=1,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    seed=seed,
-                    stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
-                ),
-                request_id=str(next(self.request_counter)),
-            ):
-                if output.finished:
-                    assert len(output.outputs) == 1, (
-                        "Expected exactly one sequence group"
-                    )
-                    token_ids = list(output.outputs[0].token_ids)
-                    if token_ids[-1] in eos_token_ids:
-                        token_ids = token_ids[:-1]
-                    return token_ids
-
-
-class DeferredSampler(torch.nn.Module):
-    """A custom vLLM sampler optimized for efficient next-token probability calculations.
-
-    This sampler replaces vLLM's default sampling mechanism to optimize for scenarios
-    where we only need the next token probabilities without actually sampling tokens.
-
-    Note:
-        While this sampler implements vLLM's expected interface, it intentionally
-        avoids actual token sampling to optimize for probability calculation use cases.
-        It should not be used in scenarios where actual token generation is needed.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, logits, sampling_metadata):
-        """Process model logits to create vLLM-compatible sampling outputs.
-
-        This method implements the required vLLM sampler interface but optimizes for
-        probability requests.
-
-        Args:
-            logits (torch.Tensor): Raw model logits with shape (num_tokens, vocab_size).
-            sampling_metadata: vLLM metadata containing sequence grouping information.
-
-        Returns:
-            SamplerOutput: A vLLM-compatible output structure containing:
-                - Sequence group outputs with lazy probability dictionaries
-                - Placeholder values for unused sampling fields
-                - No actual sampled tokens (uses dummy token_id=0)
-
-        Note:
-            The sampler uses token_id=0 as a placeholder.
-        """
-        assert logits is not None
-
-        logprobs = logits.log_softmax(dim=-1, dtype=torch.float)
-
-        sample_idx = 0
-        sampler_output = []
-        for seq_group in sampling_metadata.seq_groups:
-            seq_ids = seq_group.seq_ids
-            num_parent_seqs = len(seq_ids)
-            logprobs_by_seq = logprobs[sample_idx : sample_idx + num_parent_seqs]
-
-            if not seq_group.do_sample:
-                sampler_output.append(
-                    CompletionSequenceGroupOutput(samples=[], prompt_logprobs=[])
-                )
-            else:
-                assert len(logprobs_by_seq) == len(seq_ids)
-                seq_outputs = []
-                for seq_id, seq_logprobs in zip(seq_ids, logprobs_by_seq):
-                    seq_outputs.append(
-                        SequenceOutput(seq_id, 0, LazyLogprobDict(seq_logprobs))
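+        # Sampling now goes through vLLM's default sampler; no sampler swapping is
+        # needed since log-probability extraction is handled by a logits processor.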
+        async for output in self.async_llm_engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=SamplingParams(
+                n=1,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                seed=seed,
+                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
+            ),
+            request_id=str(next(self.request_counter)),
+        ):
+            if output.finished:
+                assert len(output.outputs) == 1, (
+                    "Expected exactly one sequence group"
                 )
-
-                sampler_output.append(
-                    CompletionSequenceGroupOutput(
-                        samples=seq_outputs, prompt_logprobs=[]
-                    )
-                )
-
-            sample_idx += 1
-
-        sampler_outputs = SamplerOutput(
-            outputs=sampler_output,
-            sampled_token_probs=None,
-            sampled_token_ids=None,
-            logprobs=None,
-            deferred_sample_results_args=None,
-        )
-
-        return sampler_outputs
-
-
-class LazyLogprobDict:
-    """An efficient dictionary-like interface required by vLLM's output processing.
-
-    vLLM's output processor expects token probabilities to be provided as a dictionary
-    mapping token IDs to Logprob objects. However, creating this full dictionary is
-    computationally expensive, especially when dealing with large vocabulary sizes
-    (often 50k+ tokens).
-
-    This class provides a compatible interface that satisfies vLLM's requirements while
-    avoiding the overhead.
-    """
-
-    def __init__(self, logprobs):
-        self.logprobs = logprobs
-
-    def __getitem__(self, key):
-        if 0 <= key < len(self.logprobs):
-            return Logprob(self.logprobs[key])
-        raise KeyError(key)
-
-    def __contains__(self, key):
-        return 0 <= key < len(self.logprobs)
-
-    def __len__(self):
-        return len(self.logprobs)
-
-    def items(self):
-        return ((i, Logprob(prob)) for i, prob in enumerate(self.logprobs))
-
-    def keys(self):
-        return range(len(self.logprobs))
-
-    def values(self):
-        return iter(map(Logprob, self.logprobs))
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            return default
+            token_ids = list(output.outputs[0].token_ids)
+            if token_ids[-1] in eos_token_ids:
+                token_ids = token_ids[:-1]
+            return token_ids