
Commit 86e3e39

Improve vllm compatibility by using LogitProcessors to extract logprobs (#40)
* Use logit processor to extract logprobs.
* Activate env at each step of coverage.yml workflow.
* Try lower vllm version.
* Fix triton version to handle vllm error.
* Add sample method to AsyncLMs.
* Update coverage.yml.
* Update docstrings.
* Remove type annotation.
* Specify args in test.
* Rename and remove dead code.
* Set temp higher to avoid vllm warning.
* Fix merge mistake.
* Remove unused import.
* Increase tolerance in llm tests.
* tol
1 parent 7c52d56 commit 86e3e39
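
For context, the core of this change is the logits-processor pattern sketched below. This is a minimal illustration, not code from the commit: the class name, model name, and prompt are placeholders, and it assumes a vLLM v0 build in which `SamplingParams` still accepts `logits_processors`.

# Minimal sketch of the pattern this commit adopts (placeholder model/prompt;
# assumes a vLLM v0 build where SamplingParams accepts `logits_processors`).
import torch
from vllm import LLM, SamplingParams

class CaptureLogprobs:
    """Records next-token log-probabilities and passes the logits through unchanged."""

    def __init__(self):
        self.log_probs = None

    def __call__(self, past_token_ids, logits):
        # Store normalized log-probs for the next token; do not modify the logits.
        self.log_probs = torch.log_softmax(logits, dim=-1)
        return logits

llm = LLM(model="gpt2")  # placeholder model name
capture = CaptureLogprobs()
params = SamplingParams(max_tokens=1, logits_processors=[capture])
llm.generate(["The capital of France is"], params)
print(capture.log_probs.shape)  # one log-probability per vocabulary entry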

File tree: 5 files changed (+81 additions, -261 deletions)

.github/workflows/coverage.yml

Lines changed: 5 additions & 1 deletion
@@ -22,12 +22,16 @@ jobs:
           python-version: 3.11.5
           cache: 'pip'

-      - name: Run Tests
+      - name: Install dependencies
         run: |
           python -m venv venv
           source venv/bin/activate
           pip install -e .[test]
           pip install -r requirements-dev.txt
+
+      - name: Run tests
+        run: |
+          source venv/bin/activate
           coverage run --source=genlm/backend -m pytest --benchmark-disable
           coverage json --omit "*/test*"
           coverage report --omit "*/test*"

genlm/backend/llm/vllm.py

Lines changed: 68 additions & 217 deletions
@@ -1,7 +1,6 @@
 import torch
 import logging
 import warnings
-from contextlib import contextmanager

 from genlm.backend.llm.base import AsyncLM
 from genlm.backend.cache import OutputCache
@@ -10,8 +9,6 @@
 from vllm import AsyncLLMEngine, SamplingParams, AsyncEngineArgs
 from vllm.utils import Counter
 from vllm.inputs import TokensPrompt
-from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import SequenceOutput, CompletionSequenceGroupOutput, Logprob

 from vllm.distributed.parallel_state import (
     destroy_model_parallel,
@@ -43,16 +40,27 @@ def from_name(cls, *args, **kwargs):  # pragma: no cover
 else:
     logging.getLogger("vllm.engine.async_llm_engine").setLevel(logging.WARNING)

-class AsyncVirtualLM(AsyncLM):
-    """A wrapper around vLLM's `AsyncLLMEngine` for asynchronous next token log probability computations.
+class PassThroughLogitsProcessor:
+    """A logits processor that stores the logprobs and passes the logits through."""
+
+    def __init__(self):
+        self.log_probs = None

-    This class provides an asynchronous interface for computing log probabilities using vLLM's engine.
-    It is optimized for next token log probability computations and supports caching of results (outputs and KV).
-    """
+    def __call__(self, past_token_ids, logits):
+        assert self.log_probs is None, (
+            "Log probs already set. This should never happen."
+        )
+        self.log_probs = torch.log_softmax(logits, dim=-1, dtype=logits.dtype)
+        return logits

-    default_params = SamplingParams(
-        max_tokens=1, n=1, logprobs=1, detokenize=False, stop=None, ignore_eos=True
-    )
+class AsyncVirtualLM(AsyncLM):
+    default_params = {
+        "max_tokens": 1,
+        "n": 1,
+        "detokenize": False,
+        "stop": None,
+        "ignore_eos": True,
+    }

     def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
         """Initialize an `AsyncVirtualLM` instance.
@@ -68,8 +76,6 @@ def __init__(self, async_llm_engine, cache_size=0, cache_opts={}):
         self.async_llm_engine = async_llm_engine
         self.tokenizer = async_llm_engine.engine.get_tokenizer()
         self.request_counter = Counter()
-        self.custom_sampler = DeferredSampler()
-        self.original_sampler = self.underlying_model.sampler
         self.cache = (
             OutputCache(maxsize=cache_size, **cache_opts)
             if cache_size > 0
@@ -108,10 +114,7 @@ def from_name(cls, model_name, engine_opts=None, **kwargs):
         engine_opts = {
             "enable_prefix_caching": True,
             "disable_log_requests": True,
-            "disable_async_output_proc": True,
-            # Need to disable chunked prefill to avoid issues
-            # with our custom sampler.
-            "enable_chunked_prefill": False,
+            "disable_async_output_proc": True,  # This parameter forces vLLM to use v0, which is currently what we want to do.
             **(engine_opts or {}),
         }

@@ -163,16 +166,21 @@ async def _next_token_logprobs(self, token_ids):
         prompt = TokensPrompt(prompt_token_ids=token_ids)

         outputs = []
-        with self._temporarily_set_sampler(self.custom_sampler):
-            async for output in self.async_llm_engine.generate(
-                prompt=prompt,
-                sampling_params=self.default_params,
-                request_id=req_id,
-            ):
-                if output.finished:
-                    outputs.append(output)
-
-        return self._validate_outputs(outputs)
+        processor = PassThroughLogitsProcessor()
+        async for output in self.async_llm_engine.generate(
+            prompt=prompt,
+            sampling_params=SamplingParams(
+                **self.default_params, logits_processors=[processor]
+            ),
+            request_id=req_id,
+        ):
+            if output.finished:
+                outputs.append(output)
+
+        assert processor.log_probs is not None, (
+            "Log probs should be set by the logits processor."
+        )
+        return processor.log_probs

     def next_token_logprobs_sync(self, token_ids):
         """Request log probabilities of next token synchronously.
@@ -196,69 +204,31 @@ def batch_next_token_logprobs_sync(self, token_ids_list):
            (torch.Tensor): A tensor of normalized log probability tensors, one for each prompt in the input list.
         """
         req_ids = []
+        req_id2processors = {}
         for token_ids in token_ids_list:
             req_id = str(next(self.request_counter))
             req_ids.append(req_id)
+            processor = PassThroughLogitsProcessor()
+            req_id2processors[req_id] = processor
             self.async_llm_engine.engine.add_request(
                 prompt=TokensPrompt(prompt_token_ids=token_ids),
-                params=self.default_params,
+                params=SamplingParams(
+                    **self.default_params, logits_processors=[processor]
+                ),
                 request_id=req_id,
             )

-        req_id2outputs = {}
-        with self._temporarily_set_sampler(self.custom_sampler):
-            while self.async_llm_engine.engine.has_unfinished_requests():
-                output = self.async_llm_engine.engine.step()
-                for out in output:
-                    if out.finished:
-                        assert out.request_id not in req_id2outputs, (
-                            f"Duplicate outputs for request {out.request_id}"
-                        )
-                        assert out.request_id in req_ids, (
-                            f"{out.request_id} not in requested IDs"
-                        )
-                        req_id2outputs[out.request_id] = out
-
-        logprobs = [
-            self._validate_outputs([req_id2outputs[req_id]]) for req_id in req_ids
-        ]
-
-        return torch.stack(logprobs)
-
-    @contextmanager
-    def _temporarily_set_sampler(self, sampler):
-        """Context manager for temporarily setting a custom sampler."""
-        original_sampler = self.underlying_model.sampler
-        try:
-            self.underlying_model.sampler = sampler
-            yield
-        finally:
-            self.underlying_model.sampler = original_sampler
-
-    def _validate_outputs(self, outputs):
-        """Validate and extract logprobs from a vLLM output.
-
-        Args:
-            outputs: List of sequence group outputs from vLLM generation
-
-        Returns:
-            Tensor of log probabilities for the next token
-
-        Raises:
-            AssertionError: If output structure doesn't match expected format
-        """
-        assert len(outputs) == 1, "Expected exactly one sequence group"
-        seq_group = outputs[0]
+        while self.async_llm_engine.engine.has_unfinished_requests():
+            output = self.async_llm_engine.engine.step()
+            for out in output:
+                if out.finished:
+                    assert out.request_id in req_id2processors, (
+                        f"{out.request_id} not in requested IDs"
+                    )

-        assert len(seq_group.outputs) == 1, (
-            "Expected exactly one sequence in output"
+        return torch.stack(
+            [req_id2processors[req_id].log_probs for req_id in req_ids]
         )
-        sequence = seq_group.outputs[0]
-
-        assert len(sequence.logprobs) == 1, "Expected exactly one set of logprobs"
-        token_logprobs = sequence.logprobs[0].logprobs
-
-        return token_logprobs

     def clear_cache(self):
         """Clear output cache."""
@@ -296,141 +266,22 @@ async def sample(
         Returns:
             (list[int]): The sampled token IDs.
         """
-        with self._temporarily_set_sampler(self.original_sampler):
-            async for output in self.async_llm_engine.generate(
-                prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
-                sampling_params=SamplingParams(
-                    n=1,
-                    max_tokens=max_tokens,
-                    temperature=temperature,
-                    seed=seed,
-                    stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
-                ),
-                request_id=str(next(self.request_counter)),
-            ):
-                if output.finished:
-                    assert len(output.outputs) == 1, (
-                        "Expected exactly one sequence group"
-                    )
-                    token_ids = list(output.outputs[0].token_ids)
-                    if token_ids[-1] in eos_token_ids:
-                        token_ids = token_ids[:-1]
-                    return token_ids
-
-
-class DeferredSampler(torch.nn.Module):
-    """A custom vLLM sampler optimized for efficient next-token probability calculations.
-
-    This sampler replaces vLLM's default sampling mechanism to optimize for scenarios
-    where we only need the next token probabilities without actually sampling tokens.
-
-    Note:
-        While this sampler implements vLLM's expected interface, it intentionally
-        avoids actual token sampling to optimize for probability calculation use cases.
-        It should not be used in scenarios where actual token generation is needed.
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, logits, sampling_metadata):
-        """Process model logits to create vLLM-compatible sampling outputs.
-
-        This method implements the required vLLM sampler interface but optimizes for
-        probability requests.
-
-        Args:
-            logits (torch.Tensor): Raw model logits with shape (num_tokens, vocab_size).
-            sampling_metadata: vLLM metadata containing sequence grouping information.
-
-        Returns:
-            SamplerOutput: A vLLM-compatible output structure containing:
-                - Sequence group outputs with lazy probability dictionaries
-                - Placeholder values for unused sampling fields
-                - No actual sampled tokens (uses dummy token_id=0)
-
-        Note:
-            The sampler uses token_id=0 as a placeholder.
-        """
-        assert logits is not None
-
-        logprobs = logits.log_softmax(dim=-1, dtype=torch.float)
-
-        sample_idx = 0
-        sampler_output = []
-        for seq_group in sampling_metadata.seq_groups:
-            seq_ids = seq_group.seq_ids
-            num_parent_seqs = len(seq_ids)
-            logprobs_by_seq = logprobs[sample_idx : sample_idx + num_parent_seqs]
-
-            if not seq_group.do_sample:
-                sampler_output.append(
-                    CompletionSequenceGroupOutput(samples=[], prompt_logprobs=[])
-                )
-            else:
-                assert len(logprobs_by_seq) == len(seq_ids)
-                seq_outputs = []
-                for seq_id, seq_logprobs in zip(seq_ids, logprobs_by_seq):
-                    seq_outputs.append(
-                        SequenceOutput(seq_id, 0, LazyLogprobDict(seq_logprobs))
+        async for output in self.async_llm_engine.generate(
+            prompt=TokensPrompt(prompt_token_ids=prompt_token_ids),
+            sampling_params=SamplingParams(
+                n=1,
+                max_tokens=max_tokens,
+                temperature=temperature,
+                seed=seed,
+                stop=[self.byte_vocab[i].decode() for i in eos_token_ids],
+            ),
+            request_id=str(next(self.request_counter)),
+        ):
+            if output.finished:
+                assert len(output.outputs) == 1, (
+                    "Expected exactly one sequence group"
                 )
-
-                sampler_output.append(
-                    CompletionSequenceGroupOutput(
-                        samples=seq_outputs, prompt_logprobs=[]
-                    )
-                )
-
-            sample_idx += 1
-
-        sampler_outputs = SamplerOutput(
-            outputs=sampler_output,
-            sampled_token_probs=None,
-            sampled_token_ids=None,
-            logprobs=None,
-            deferred_sample_results_args=None,
-        )
-
-        return sampler_outputs
-
-
-class LazyLogprobDict:
-    """An efficient dictionary-like interface required by vLLM's output processing.
-
-    vLLM's output processor expects token probabilities to be provided as a dictionary
-    mapping token IDs to Logprob objects. However, creating this full dictionary is
-    computationally expensive, especially when dealing with large vocabulary sizes
-    (often 50k+ tokens).
-
-    This class provides a compatible interface that satisfies vLLM's requirements while
-    avoiding the overhead.
-    """
-
-    def __init__(self, logprobs):
-        self.logprobs = logprobs
-
-    def __getitem__(self, key):
-        if 0 <= key < len(self.logprobs):
-            return Logprob(self.logprobs[key])
-        raise KeyError(key)
-
-    def __contains__(self, key):
-        return 0 <= key < len(self.logprobs)
-
-    def __len__(self):
-        return len(self.logprobs)
-
-    def items(self):
-        return ((i, Logprob(prob)) for i, prob in enumerate(self.logprobs))
-
-    def keys(self):
-        return range(len(self.logprobs))
-
-    def values(self):
-        return iter(map(Logprob, self.logprobs))
-
-    def get(self, key, default=None):
-        try:
-            return self[key]
-        except KeyError:
-            return default
+                token_ids = list(output.outputs[0].token_ids)
+                if token_ids[-1] in eos_token_ids:
+                    token_ids = token_ids[:-1]
+                return token_ids
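
A hedged usage sketch of the updated class (not part of the commit): the model name is a placeholder, the import path is inferred from the file path above, and the call sites are assumptions based on the methods visible in this diff (`from_name`, `next_token_logprobs_sync`, `batch_next_token_logprobs_sync`).

# Hypothetical usage of the new code path (placeholder model; method names
# taken from this diff, everything else is assumed).
from genlm.backend.llm.vllm import AsyncVirtualLM

llm = AsyncVirtualLM.from_name("gpt2")  # placeholder model name
token_ids = llm.tokenizer.encode("The capital of France is")

# Single prompt: a tensor of next-token log-probabilities, now extracted by
# the PassThroughLogitsProcessor rather than the old DeferredSampler.
logprobs = llm.next_token_logprobs_sync(token_ids)

# Batched prompts: one stacked tensor, one row per prompt.
batch_logprobs = llm.batch_next_token_logprobs_sync([token_ids, token_ids[:-1]])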

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -15,7 +15,8 @@ dependencies = [
     "accelerate",
     "bitsandbytes",
     "numba",
-    "vllm>=0.6.6,<0.8.5; sys_platform == 'linux'",
+    "vllm>=0.6.6,<=0.10.0; sys_platform == 'linux'",
+    "triton==3.2.0"
 ]

 [project.optional-dependencies]

tests/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ def from_name(cls, model_name, llm_opts=None):
         llm_opts = {
             "enable_prefix_caching": True,
             "disable_log_stats": True,
+            "dtype": "float16",
             **(llm_opts or {}),
         }
         llm = LLM(model=model_name, tokenizer=model_name, **llm_opts)
