
Commit 136d853

[V1] Wrapper which plumbs request-level logits processors into vLLM batch-level logits processing (vllm-project#23656)
Signed-off-by: Andrew Feldman <[email protected]>
1 parent e32a0e8 commit 136d853
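
Not part of the commit itself, but for orientation: a minimal sketch of the adapter pattern the new examples below exercise. A request-level logits processor is a callable that edits a single request's logits row at one decode step; subclassing `AdapterLogitsProcessor` lifts it into the batch-level interface that vLLM applies across the persistent batch each step. The no-op processor here is hypothetical.

from typing import Optional

import torch

from vllm import SamplingParams
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)


class NoopPerReq:
    """Hypothetical request-level processor: receives one request's logits."""

    def __call__(self, output_ids: list[int], logits: torch.Tensor) -> torch.Tensor:
        return logits  # a real processor would modify `logits` here


class NoopBatchLevel(AdapterLogitsProcessor):
    """Adapter subclass; vLLM applies it to the whole batch per decode step."""

    def is_argmax_invariant(self) -> bool:
        return True  # a no-op never changes the argmax token

    def new_req_logits_processor(
        self, params: SamplingParams
    ) -> Optional[RequestLogitsProcessor]:
        # Returning None instead would skip this request entirely.
        return NoopPerReq()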

File tree

6 files changed: +524, -5 lines

File renamed without changes.
Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates wrapping a request-level logits processor to be
compatible with vLLM's batch-level logits processing.

For demo purposes, a dummy logits processor is employed which, if
`target_token` is passed as a keyword argument to `SamplingParams.extra_args`,
will mask out all tokens except `target_token`. This logits processor can be
applied to a vector of logits associated with a single decode step for a single
request. The logits processor cannot be applied to a request which does not
pass in a `target_token` custom argument.

The request-level dummy logits processor is wrapped to create a batch-level
logits processor, which can apply the logits processor to output logits from
all requests in the persistent batch in a given decode step. For requests which
do not provide a `target_token` argument, the corresponding row of `logits`
will not be modified.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect the `target_token` to be decoded in each step, yielding an output
similar to that shown below:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------
"""

from typing import Any, Optional

import torch

from vllm import LLM, SamplingParams
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)

logger = init_logger(__name__)


class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Mask out every logit except the one for `target_token`.
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of wrapping a fake request-level logits processor to create a
    batch-level logits processor."""

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """Return a new request-level logits processor, customized to the
        `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor, the request must have
        a "target_token" custom argument with an integer value.

        Args:
            params: per-request sampling params

        Returns:
            `Callable` request logits processor, or None
        """
        target_token: Optional[Any] = params.extra_args and params.extra_args.get(
            "target_token"
        )
        if target_token is None:
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[WrappedPerReqLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
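
A quick standalone sanity check for the request-level processor above (hypothetical, not part of the commit): applied to a random logits vector, it should leave only `target_token` unmasked.

import torch

# Assumes DummyPerReqLogitsProcessor from the example above is in scope.
logits = torch.randn(16)
processed = DummyPerReqLogitsProcessor(target_token=3)(output_ids=[], logits=logits)

# Only index 3 survives; every other logit is masked to -inf.
assert processed.argmax().item() == 3
assert torch.isinf(processed[:3]).all() and torch.isinf(processed[4:]).all()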
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""This example demonstrates a special case of wrapping a request-level logits
processor, namely the case where it is necessary to utilize engine config or
environment info passed to the constructor. The subclass must override the
wrapper base class `__init__()` method to access the engine config, the device
identifier, or the flag which indicates whether pinned memory is available.

For demo purposes, a request-level dummy logits processor is employed which
causes the same token (`target_token`) to be decoded in each step. The
request-level dummy logits processor is wrapped to create a batch-level logits
processor, which can apply the logits processor to output logits from all
requests in the persistent batch in a given decode step.

The wrapped dummy logits processor below models a scenario where we must
disable the logits processor on non-"cuda" platforms. The wrapper base class
`__init__()` is overridden in order to check this condition and set a flag.

A batch is constructed with `temperature=0.0` and 50% of requests specifying
`target_token`, and for these requests - and *only* these requests - we
expect that on a "cuda" device the output will look something like:

Generated Outputs:
------------------------------------------------------------
Prompt: 'Hello, my name is'
Output: " ' ' ' ' ' ' ' ' ' ' ' ' ' ' ' '"
------------------------------------------------------------
Prompt: 'The president of the United States is'
Output: " not a racist. He is a racist.\nHe's a racist because he"
------------------------------------------------------------
Prompt: 'The capital of France is'
Output: ' also also also also also also also also also also also also also
also also also'
------------------------------------------------------------
Prompt: 'The future of AI is'
Output: ' in the hands of the people.\n\nThe future of AI is in the'
------------------------------------------------------------

which indicates that the logits processor is running. However, on a non-"cuda"
device, the first and third requests would not repeat the same token.
"""

from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.sample.logits_processor import (
    AdapterLogitsProcessor,
    RequestLogitsProcessor,
)

logger = init_logger(__name__)


class DummyPerReqLogitsProcessor:
    """The request-level logits processor masks out all logits except the
    token id identified by `target_token`."""

    def __init__(self, target_token: int) -> None:
        """Specify `target_token`."""
        self.target_token = target_token

    def __call__(
        self,
        output_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        # Mask out every logit except the one for `target_token`.
        val_to_keep = logits[self.target_token].item()
        logits[:] = float("-inf")
        logits[self.target_token] = val_to_keep
        return logits


class WrappedPerReqLogitsProcessor(AdapterLogitsProcessor):
    """Example of overriding the wrapper class `__init__()` in order to
    utilize info about the device type."""

    def __init__(
        self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool
    ):
        super().__init__(vllm_config, device, is_pin_memory)
        # Record whether we are on a "cuda"-type device.
        self.is_cuda = device.type == "cuda"

    def is_argmax_invariant(self) -> bool:
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> Optional[RequestLogitsProcessor]:
        """Return a new request-level logits processor, customized to the
        `target_token` value associated with a particular request.

        Returns None if the logits processor should not be applied to the
        particular request. To use the logits processor, the request must have
        a "target_token" custom argument with an integer value, and the device
        must be "cuda"-type.

        Args:
            params: per-request sampling params

        Returns:
            `Callable` request logits processor, or None
        """
        if (
            not self.is_cuda
            or (
                target_token := params.extra_args
                and params.extra_args.get("target_token")
            )
            is None
        ):
            return None
        if not isinstance(target_token, int):
            logger.warning(
                "target_token value %s is not int; not applying logits"
                " processor to request.",
                target_token,
            )
            return None
        return DummyPerReqLogitsProcessor(target_token)


# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a mixture of requests which do and don't utilize the dummy logitproc.
sampling_params_list = [
    SamplingParams(temperature=0.0, extra_args={"target_token": 128}),
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
    SamplingParams(temperature=0.0),
]


def main():
    # Create an LLM.
    llm = LLM(
        model="facebook/opt-125m",
        logits_processors=[WrappedPerReqLogitsProcessor],
    )
    # Generate texts from the prompts.
    # The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params_list)
    # Print the outputs.
    print("\nGenerated Outputs:\n" + "-" * 60)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}")
        print(f"Output: {generated_text!r}")
        print("-" * 60)


if __name__ == "__main__":
    main()
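
One subtlety in the extraction logic above: `params.extra_args and params.extra_args.get("target_token")` short-circuits to a falsy value when `extra_args` is unset, so requests that never opt in are skipped without raising. A small standalone illustration (hypothetical, not part of the commit):

from vllm import SamplingParams

for params in (
    SamplingParams(temperature=0.0),
    SamplingParams(temperature=0.0, extra_args={"target_token": 67}),
):
    # Mirrors the walrus-expression extraction in new_req_logits_processor().
    target_token = params.extra_args and params.extra_args.get("target_token")
    print(target_token)  # prints None, then 67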

tests/v1/logits_processors/test_custom_offline.py

Lines changed: 33 additions & 0 deletions
@@ -15,6 +15,7 @@
                                               POOLING_MODEL_NAME, TEMP_GREEDY,
                                               CustomLogitprocSource,
                                               DummyLogitsProcessor,
+                                              WrappedPerReqLogitsProcessor,
                                               dummy_module)
 from tests.v1.logits_processors.utils import entry_points as fake_entry_points
 from tests.v1.logits_processors.utils import prompts
@@ -161,6 +162,38 @@ def test_custom_logitsprocs(monkeypatch,
     _run_test(kwargs, logitproc_loaded=True)
 
 
+@create_new_process_for_each_test()
+def test_custom_logitsprocs_req(monkeypatch):
+    """Test passing a request-level logits processor to the offline Python
+    interface.
+
+    Wrap a request-level logits processor to create a batch-level logits
+    processor that has a well-defined behavior (mask out all tokens except one
+    `target_token`).
+
+    Construct an `LLM` instance which loads the wrapped logits processor. Pass
+    the custom logitproc as a class object.
+
+    Construct a reference `LLM` instance with no custom logitproc.
+
+    Pass in a batch of requests, 50% of which pass a `target_token` value
+    in through `SamplingParams.extra_args`, 50% of which do not.
+
+    Validate that:
+    * Requests which do not activate the custom logitproc yield the same
+      results for both `LLM` instances
+    * Requests which activate the custom logitproc only output `target_token`
+
+    Args:
+        monkeypatch: for setting env vars
+    """
+
+    # Test that logitproc info is passed to workers
+    monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1")
+    random.seed(40)
+    _run_test({"logits_processors": [WrappedPerReqLogitsProcessor]},
+              logitproc_loaded=True)
+
+
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("logitproc_source", [
     CustomLogitprocSource.LOGITPROC_SOURCE_ENTRYPOINT,
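
To exercise just the new test from a vLLM development checkout, an invocation along these lines should work (hypothetical, not part of the commit; assumes the test dependencies are installed):

import pytest

# Select only the new request-level logits processor test by keyword.
raise SystemExit(pytest.main([
    "tests/v1/logits_processors/test_custom_offline.py",
    "-k", "test_custom_logitsprocs_req",
]))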
