generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 87
Expand file tree
/
Copy pathvllm_rb_properties.py
More file actions
249 lines (225 loc) · 10.7 KB
/
vllm_rb_properties.py
File metadata and controls
249 lines (225 loc) · 10.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
#!/usr/bin/env python
#
# Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
import logging
from typing import Optional, Any, Dict, Tuple, Literal, Union
from pydantic import field_validator, model_validator, ConfigDict, Field
from vllm import EngineArgs, AsyncEngineArgs
from vllm.utils.argparse_utils import FlexibleArgumentParser
from djl_python.properties_manager.properties import Properties
# Every accepted spelling of a dtype (long and short form) mapped onto the
# canonical name used for vLLM's ``dtype`` engine arg. Key order is preserved
# deliberately: the keys are listed to users in the invalid-dtype error.
DTYPE_MAPPER = {
    alias: canonical
    for canonical, aliases in (
        ("float32", ("float32", "fp32")),
        ("float16", ("float16", "fp16")),
        ("bfloat16", ("bfloat16", "bf16")),
        ("auto", ("auto", )),
    )
    for alias in aliases
}
def construct_vllm_args_list(vllm_engine_args: dict):
    """Convert a dict of vLLM engine args into a CLI-style token list.

    Token rules, applied per key in insertion order:

    * Boolean-like values (``True``/``False``, or any string equal to
      ``"true"``/``"false"`` ignoring case) become a bare ``--key`` flag when
      true and are omitted entirely when false.
    * Non-empty lists become ``--key item1 item2 ...``; empty lists are
      omitted.
    * Anything else becomes ``--key str(value)`` (so ``None`` is emitted as
      the literal string ``"None"``).

    :param vllm_engine_args: mapping of vLLM engine arg name to value.
    :return: list of command-line tokens for an argparse-style parser.
    """
    # Modified from https://github.com/vllm-project/vllm/blob/94666612a938380cb643c1555ef9aa68b7ab1e53/vllm/utils/argparse_utils.py#L441
    args_list = []
    for key, value in vllm_engine_args.items():
        flag = "--" + key
        bool_text = str(value).lower()
        # str(True) == "True", so this one check covers real bools as well as
        # boolean strings coming from config files. A separate
        # isinstance(value, bool) branch after it would be unreachable.
        if bool_text in ('true', 'false'):
            if bool_text == 'true':
                args_list.append(flag)
        elif isinstance(value, list):
            if value:
                args_list.append(flag)
                args_list.extend(str(item) for item in value)
        else:
            args_list.extend((flag, str(value)))
    return args_list
class VllmRbProperties(Properties):
    """DJL serving properties for the vLLM rolling-batch backend.

    Declared fields are DJL-side configs, some of which also accept their
    vLLM name via a pydantic alias. Any undeclared key is retained as a
    pydantic extra (``extra='allow'``) and forwarded to vLLM by
    ``get_engine_args`` when its name is annotated on ``EngineArgs``.
    """
    engine: Optional[str] = None
    # The following configs have different names in DJL compared to vLLM, we only accept DJL name currently
    tensor_parallel_degree: int = 1
    pipeline_parallel_degree: int = 1
    # The following configs have different names in DJL compared to vLLM, either is accepted
    quantize: Optional[str] = Field(alias="quantization",
                                    default=EngineArgs.quantization)
    max_rolling_batch_prefill_tokens: Optional[int] = Field(
        alias="max_num_batched_tokens",
        default=EngineArgs.max_num_batched_tokens)
    cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb",
                                          default=EngineArgs.cpu_offload_gb)
    # The following configs have different defaults, or additional processing in DJL compared to vLLM
    dtype: str = "auto"
    max_loras: int = 4
    task: str = 'auto'
    # The following configs have broken processing in vllm via the FlexibleArgumentParser
    long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
    # Tool calling properties
    enable_auto_tool_choice: bool = False
    tool_call_parser: Optional[str] = None
    # Reasoning properties
    enable_reasoning: bool = False
    reasoning_parser: Optional[str] = None
    # Neuron vLLM properties
    device: str = 'auto'
    preloaded_model: Optional[Any] = None
    override_neuron_config: Optional[Dict] = None
    # Non engine arg properties
    chat_template: Optional[str] = None
    chat_template_content_format: Literal["auto", "string", "openai"] = "auto"
    # This allows generic vllm engine args to be passed in and set with vllm
    model_config = ConfigDict(extra='allow', populate_by_name=True)

    @field_validator('engine')
    def validate_engine(cls, engine):
        """Reject any engine other than ``"Python"``."""
        if engine != "Python":
            raise AssertionError(
                "Need python engine to start vLLM RollingBatcher")
        return engine

    @field_validator('task')
    def validate_task(cls, task):
        """Map the legacy ``text-generation`` task name to ``generate``."""
        # TODO: conflicts between HF and VLLM tasks, need to separate these.
        # For backwards compatibility, map text-generation to generate.
        if task == 'text-generation':
            return 'generate'
        return task

    @field_validator('dtype')
    def validate_dtype(cls, val):
        """Normalize a dtype alias (e.g. ``fp16``) to its canonical name.

        :raises ValueError: if ``val`` is not a key of ``DTYPE_MAPPER``.
        """
        if val not in DTYPE_MAPPER:
            raise ValueError(
                f"Invalid dtype={val} provided. Must be one of {DTYPE_MAPPER.keys()}"
            )
        return DTYPE_MAPPER[val]

    @model_validator(mode='after')
    def validate_pipeline_parallel(self):
        """Reject any pipeline_parallel_degree other than 1."""
        if self.pipeline_parallel_degree != 1:
            raise ValueError(
                "Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
            )
        return self

    @model_validator(mode='after')
    def validate_tool_call_parser(self):
        """When auto tool choice is on, require a parser registered with vLLM."""
        if self.enable_auto_tool_choice:
            # Imported lazily so vLLM's parser registry is only loaded when needed.
            from vllm.entrypoints.openai.tool_parsers import ToolParserManager
            valid_tool_parsers = ToolParserManager.tool_parsers.keys()
            if self.tool_call_parser not in valid_tool_parsers:
                raise ValueError(
                    f"Invalid tool call parser: {self.tool_call_parser} "
                    f"(chose from {{ {','.join(valid_tool_parsers)} }})")
        return self

    @model_validator(mode='after')
    def validate_reasoning_parser(self):
        """When reasoning is on, require a parser registered with vLLM."""
        if self.enable_reasoning:
            # Imported lazily so vLLM's parser registry is only loaded when needed.
            from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
            valid_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys(
            )
            if self.reasoning_parser not in valid_reasoning_parsers:
                raise ValueError(
                    f"Invalid reasoning parser: {self.reasoning_parser} "
                    f"(chose from {{ {','.join(valid_reasoning_parsers)} }})")
        return self

    @field_validator('override_neuron_config', mode="before")
    def validate_override_neuron_config(cls, val):
        """Accept a dict, or a string holding a Python or JSON dict literal.

        :raises ValueError: if ``val`` is neither a dict nor a string that
            parses to one.
        """
        if isinstance(val, dict):
            return val
        if not isinstance(val, str):
            raise ValueError("Invalid json format for override_neuron_config")
        try:
            neuron_config = ast.literal_eval(val)
        except (ValueError, SyntaxError, TypeError):
            # literal_eval rejects JSON spellings such as true/false/null,
            # and raises SyntaxError (which pydantic does not turn into a
            # validation error) on malformed input; fall back to strict JSON.
            import json
            try:
                neuron_config = json.loads(val)
            except ValueError as err:
                raise ValueError(
                    "Invalid json format for override_neuron_config") from err
        if not isinstance(neuron_config, dict):
            raise ValueError("Invalid json format for override_neuron_config")
        return neuron_config

    # TODO: processing of this field is broken in vllm via from_cli_args
    # we should upstream a fix for this to vllm
    @field_validator('long_lora_scaling_factors', mode='before')
    def validate_long_lora_scaling_factors(cls, val):
        """Coerce a scalar, list, tuple, or their string form into a tuple.

        ``None`` is passed through unchanged (the field is Optional).

        :raises ValueError: if ``val`` cannot be converted to a tuple of floats.
        """
        error_msg = "long_lora_scaling_factors must be convertible to a tuple of floats."
        if isinstance(val, str):
            try:
                val = ast.literal_eval(val)
            except (ValueError, SyntaxError, TypeError) as err:
                # SyntaxError is not converted by pydantic, so re-raise as ValueError.
                raise ValueError(error_msg) from err
        if val is None or isinstance(val, tuple):
            return val
        if isinstance(val, list):
            try:
                return tuple(float(v) for v in val)
            except (ValueError, TypeError) as err:
                raise ValueError(error_msg) from err
        if isinstance(val, float):
            return (val, )
        if isinstance(val, int):
            return (float(val), )
        raise ValueError(error_msg)

    def handle_lmi_vllm_config_conflicts(self, additional_vllm_engine_args):
        """Reject vLLM-named configs whose values conflict with the DJL ones.

        :param additional_vllm_engine_args: passthrough vLLM engine args, as
            returned by ``get_additional_vllm_engine_args``.
        :raises ValueError: if a DJL config and its vLLM-named counterpart are
            both set to different values.
        """

        def validate_potential_lmi_vllm_config_conflict(
                lmi_config_name, vllm_config_name):
            lmi_config_val = getattr(self, lmi_config_name)
            vllm_config_val = additional_vllm_engine_args.get(vllm_config_name)
            if vllm_config_val is None or lmi_config_val is None:
                return
            # Passthrough values are untyped pydantic extras and may be
            # strings (e.g. "4") while the DJL field is typed (e.g. 4), so
            # only report a conflict when the string forms differ too.
            if (vllm_config_val != lmi_config_val
                    and str(vllm_config_val) != str(lmi_config_val)):
                raise ValueError(
                    f"Both the DJL {lmi_config_name}={lmi_config_val} and vLLM {vllm_config_name}={vllm_config_val} configs have been set with conflicting values. "
                    f"We currently only accept the DJL config {lmi_config_name}, please remove the vllm {vllm_config_name} configuration."
                )

        validate_potential_lmi_vllm_config_conflict("tensor_parallel_degree",
                                                    "tensor_parallel_size")
        validate_potential_lmi_vllm_config_conflict("pipeline_parallel_degree",
                                                    "pipeline_parallel_size")
        validate_potential_lmi_vllm_config_conflict("max_rolling_batch_size",
                                                    "max_num_seqs")

    def generate_vllm_engine_arg_dict(self,
                                      passthrough_vllm_engine_args) -> dict:
        """Build the vLLM engine-arg dict from DJL configs plus passthroughs.

        :param passthrough_vllm_engine_args: extra vLLM args; these are
            applied last and therefore override the DJL-derived entries.
        :return: mapping of vLLM engine arg name to value.
        """
        vllm_engine_args = {
            'model': self.model_id_or_path,
            'tensor_parallel_size': self.tensor_parallel_degree,
            'pipeline_parallel_size': self.pipeline_parallel_degree,
            'max_num_seqs': self.max_rolling_batch_size,
            'dtype': DTYPE_MAPPER[self.dtype],
            'revision': self.revision,
            'max_loras': self.max_loras,
            'enable_lora': self.enable_lora,
            'trust_remote_code': self.trust_remote_code,
            'cpu_offload_gb': self.cpu_offload_gb_per_gpu,
            'quantization': self.quantize,
        }
        if self.max_rolling_batch_prefill_tokens is not None:
            vllm_engine_args[
                'max_num_batched_tokens'] = self.max_rolling_batch_prefill_tokens
        vllm_engine_args.update(passthrough_vllm_engine_args)
        return vllm_engine_args

    def get_engine_args(self,
                        async_engine=False
                        ) -> Union[EngineArgs, AsyncEngineArgs]:
        """Construct vLLM ``EngineArgs`` (or ``AsyncEngineArgs``).

        The args go through vLLM's own CLI parser so they receive the same
        parsing and defaults as command-line usage.

        :param async_engine: build ``AsyncEngineArgs`` instead of ``EngineArgs``.
        :raises ValueError: on conflicting DJL/vLLM config values.
        """
        additional_vllm_engine_args = self.get_additional_vllm_engine_args()
        self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
        vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
            additional_vllm_engine_args)
        logging.debug(
            "Construction vLLM engine args from the following DJL configs: %s",
            vllm_engine_arg_dict)
        arg_cls = AsyncEngineArgs if async_engine else EngineArgs
        parser = arg_cls.add_cli_args(FlexibleArgumentParser())
        args_list = construct_vllm_args_list(vllm_engine_arg_dict)
        args = parser.parse_args(args=args_list)
        engine_args = arg_cls.from_cli_args(args)
        # we have to do this separately because vllm converts it into a string
        engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
        # These neuron configs are not implemented in the vllm arg parser
        if self.device == 'neuron':
            setattr(engine_args, 'preloaded_model', self.preloaded_model)
            setattr(engine_args, 'override_neuron_config',
                    self.override_neuron_config)
        return engine_args

    def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
        """Return the pydantic extras whose names are annotated on ``EngineArgs``.

        NOTE(review): only ``EngineArgs.__annotations__`` is consulted, so
        fields declared solely on ``AsyncEngineArgs`` are not forwarded —
        confirm this is intended.
        """
        return {
            k: v
            for k, v in self.__pydantic_extra__.items()
            if k in EngineArgs.__annotations__
        }