-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Expand file tree
/
Copy pathvllm_text_completion.py
More file actions
208 lines (185 loc) · 7.42 KB
/
vllm_text_completion.py
File metadata and controls
208 lines (185 loc) · 7.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""A sample pipeline using the RunInference API to run an LLM with vLLM.

Takes a set of prompts (or, in chat mode, lists of previous messages) and
produces responses using a model of choice.

Requires a GPU runtime with vllm, openai, and apache-beam installed to run
correctly.
"""
import argparse
import logging
from collections.abc import Iterable
from typing import Optional
import apache_beam as beam
from apache_beam.ml.inference.base import PredictionResult
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.runner import PipelineResult
# vLLM V1 tuning defaults chosen to avoid CUDA OOM on ~16GB GPUs (for example
# an NVIDIA T4). Unless max_num_seqs is reduced, the engine warms up its
# sampler with many dummy sequences, and the stock gpu_memory_utilization can
# leave no free VRAM for that warm-up step.
_DEFAULT_VLLM_MAX_NUM_SEQS = 32
_DEFAULT_VLLM_GPU_MEMORY_UTILIZATION = 0.72

# Prompts sent to the completions model handler (used unless chat mode is on).
COMPLETION_EXAMPLES = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
    "John cena is",
]

# Conversations sent to the chat model handler in chat mode. Each inner list
# is one conversation, written as (role, content) turns and turned into
# OpenAIChatMessage objects in order.
CHAT_EXAMPLES = [
    [OpenAIChatMessage(role=role, content=content) for role, content in turns]
    for turns in (
        [
            ('user', 'What is an example of a type of penguin?'),
            ('assistant', 'Emperor penguin is a type of penguin.'),
            ('user', 'Tell me about them'),
        ],
        [
            ('user', 'What colors are in the rainbow?'),
            (
                'assistant',
                'Red, orange, yellow, green, blue, indigo, and violet.'),
            ('user', 'Do other colors ever appear?'),
        ],
        [
            ('user', 'Who is the president of the United States?'),
        ],
        [
            ('user', 'What state is Fargo in?'),
            ('assistant', 'It is in North Dakota.'),
            ('user', 'How many people live there?'),
            (
                'assistant',
                'Approximately 130,000 people live in Fargo, North Dakota.'),
            ('user', 'What is Fargo known for?'),
        ],
        [
            ('user', 'How many fish are in the ocean?'),
        ],
    )
]
def _parse_bool(value):
  """Converts a command-line flag value into a bool for argparse.

  ``argparse`` with ``type=bool`` treats every non-empty string (including
  ``'False'``) as True, so the value is parsed explicitly instead.

  Args:
    value: The raw flag value, e.g. ``'true'``, ``'False'``, ``'1'``, ``'no'``.
      A value that is already a bool is returned unchanged.

  Returns:
    The parsed boolean.

  Raises:
    argparse.ArgumentTypeError: If the value is not a recognized boolean.
  """
  if isinstance(value, bool):
    return value
  normalized = value.strip().lower()
  if normalized in ('true', 't', 'yes', 'y', '1'):
    return True
  if normalized in ('false', 'f', 'no', 'n', '0'):
    return False
  raise argparse.ArgumentTypeError(
      'Expected a boolean value (true/false), got %r' % value)


def parse_known_args(argv):
  """Parses args for the workflow.

  Args:
    argv: Command line arguments (``None`` means ``sys.argv[1:]``).

  Returns:
    A ``(known_args, pipeline_args)`` tuple: the namespace of the flags defined
    here, and the list of remaining arguments, which the caller passes on to
    Beam's ``PipelineOptions``.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model',
      dest='model',
      type=str,
      required=False,
      default='facebook/opt-125m',
      help='LLM to use for task')
  parser.add_argument(
      '--output',
      dest='output',
      type=str,
      required=True,
      help='Path to save output predictions.')
  # A strict bool parser (not type=bool) so that "--chat False" really means
  # False. nargs='?' with const=True also lets a bare "--chat" enable it, while
  # "--chat True" keeps working as before.
  parser.add_argument(
      '--chat',
      dest='chat',
      type=_parse_bool,
      nargs='?',
      const=True,
      required=False,
      default=False,
      help=(
          'Whether to use chat model handler and examples. Accepts an '
          'explicit true/false value (e.g. --chat=True, --chat False); a '
          'bare --chat enables it.'))
  parser.add_argument(
      '--chat_template',
      dest='chat_template',
      type=str,
      required=False,
      default=None,
      help='Chat template to use for chat example.')
  parser.add_argument(
      '--vllm_max_num_seqs',
      dest='vllm_max_num_seqs',
      type=int,
      default=_DEFAULT_VLLM_MAX_NUM_SEQS,
      help=(
          'Passed to the vLLM OpenAI server as --max-num-seqs. '
          'Lower values use less GPU memory during startup and inference; '
          'required for many ~16GB GPUs (see --vllm_gpu_memory_utilization).'))
  parser.add_argument(
      '--vllm_gpu_memory_utilization',
      dest='vllm_gpu_memory_utilization',
      type=float,
      default=_DEFAULT_VLLM_GPU_MEMORY_UTILIZATION,
      help=(
          'Passed to the vLLM OpenAI server as --gpu-memory-utilization '
          '(fraction of total GPU memory for KV cache). Lower this if the '
          'engine fails to start with CUDA out of memory.'))
  known_args, pipeline_args = parser.parse_known_args(argv)

  # Reject nonsensical server tuning values here, instead of letting the vLLM
  # server fail later on a (possibly remote) worker.
  if known_args.vllm_max_num_seqs <= 0:
    parser.error('--vllm_max_num_seqs must be a positive integer.')
  if not 0.0 < known_args.vllm_gpu_memory_utilization <= 1.0:
    parser.error('--vllm_gpu_memory_utilization must be in the range (0, 1].')
  return known_args, pipeline_args
def build_vllm_server_kwargs(known_args) -> dict[str, str]:
  """Maps the parsed vLLM tuning flags to ``vllm_server_kwargs`` entries.

  The result is suitable for
  ``VLLMCompletionsModelHandler(..., vllm_server_kwargs=...)``.

  Args:
    known_args: Parsed namespace that provides ``vllm_max_num_seqs`` and
      ``vllm_gpu_memory_utilization``.

  Returns:
    A new dict with the keys ``'max-num-seqs'`` and
    ``'gpu-memory-utilization'``, in that order, each mapped to the
    corresponding flag value converted with ``str()``.
  """
  flag_sources = (
      ('max-num-seqs', known_args.vllm_max_num_seqs),
      ('gpu-memory-utilization', known_args.vllm_gpu_memory_utilization),
  )
  return {flag_name: str(flag_value) for flag_name, flag_value in flag_sources}
class PostProcessor(beam.DoFn):
  """Formats each ``PredictionResult`` as ``"<example>: <inference>"``."""
  def process(self, element: PredictionResult) -> Iterable[str]:
    # Stringify the input example first, then the model output, and join
    # them with a colon separator.
    rendered_parts = (str(element.example), str(element.inference))
    yield ': '.join(rendered_parts)
def run(
    argv=None,
    save_main_session=True,
    test_pipeline=None,
    vllm_server_kwargs: Optional[dict[str, str]] = None) -> PipelineResult:
  """Builds and runs the vLLM RunInference example pipeline.

  Args:
    argv: Command line arguments defined for this example.
    save_main_session: Used for internal testing.
    test_pipeline: Used for internal testing. When None, a new
      ``beam.Pipeline`` is created from the arguments in argv that this
      example does not define itself.
    vllm_server_kwargs: Optional override for vLLM server options. When None,
      options are taken from argv (``--vllm_max_num_seqs``,
      ``--vllm_gpu_memory_utilization``). When set, argv tuning flags for the
      server are ignored in favor of this dict (e.g. for programmatic use).
      The dict is copied before it is handed to the model handler.

  Returns:
    The ``PipelineResult`` after ``wait_until_finish()`` has returned.
  """
  known_args, pipeline_args = parse_known_args(argv)
  pipeline_options = PipelineOptions(pipeline_args)
  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session

  # Copy once so that neither handler type shares the caller's mapping.
  # Previously only the chat path copied it.
  effective_vllm_kwargs = dict(
      vllm_server_kwargs if vllm_server_kwargs is not None else
      build_vllm_server_kwargs(known_args))

  # Construct only the handler that is actually used, together with its
  # matching example inputs.
  if known_args.chat:
    model_handler = VLLMChatModelHandler(
        model_name=known_args.model,
        chat_template_path=known_args.chat_template,
        vllm_server_kwargs=effective_vllm_kwargs)
    input_examples = CHAT_EXAMPLES
  else:
    model_handler = VLLMCompletionsModelHandler(
        model_name=known_args.model, vllm_server_kwargs=effective_vllm_kwargs)
    input_examples = COMPLETION_EXAMPLES

  pipeline = test_pipeline
  if pipeline is None:
    pipeline = beam.Pipeline(options=pipeline_options)

  examples = pipeline | "Create examples" >> beam.Create(input_examples)
  predictions = examples | "RunInference" >> RunInference(model_handler)
  process_output = predictions | "Process Predictions" >> beam.ParDo(
      PostProcessor())
  _ = process_output | "WriteOutput" >> beam.io.WriteToText(
      known_args.output, shard_name_template='', append_trailing_newlines=True)

  result = pipeline.run()
  # Block until the pipeline finishes, so callers get the final result.
  result.wait_until_finish()
  return result
if __name__ == '__main__':
  # Raise the root logger to INFO so progress messages are shown when this
  # example is run as a script, then launch the pipeline.
  logging.getLogger().setLevel(level=logging.INFO)
  run()