generated from amazon-archives/__template_Apache-2.0
-
Notifications
You must be signed in to change notification settings - Fork 87
Expand file tree
/
Copy pathinput_parser.py
More file actions
293 lines (248 loc) · 12.2 KB
/
input_parser.py
File metadata and controls
293 lines (248 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#!/usr/bin/env python
#
# Copyright 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import logging
from dataclasses import dataclass, field
from typing import List, Dict, Optional, Callable
from djl_python import Input
from djl_python.chat_completions.chat_utils import is_chat_completions_request, parse_chat_completions_request
from djl_python.encode_decode import decode
from djl_python.properties_manager.properties import is_rolling_batch_enabled
from djl_python.request import Request
from djl_python.request_io import TextInput, RequestInput
from djl_python.three_p.three_p_utils import parse_3p_request
# Input property naming the LoRA adapter for a request (presumably set from the request header of the
# same name — confirm). When present, _fetch_adapters_from_input gives it precedence over the payload's
# "adapters" key and the Input's "adapter" content entry.
SAGEMAKER_ADAPTER_IDENTIFIER_HEADER = "X-Amzn-SageMaker-Adapter-Identifier"
def input_formatter(function):
    """
    Mark ``function`` as a custom input formatter.

    Users annotate their own formatter with ``@input_formatter``; the function
    itself is returned unchanged apart from a marker attribute.

    :param function: the user-defined formatter to tag.
    :return: the same function object, now carrying ``is_input_formatter = True``.
    """
    # The marker attribute is what is used to locate the decorated function.
    setattr(function, "is_input_formatter", True)
    return function
@dataclass
class ParsedInput:
    """
    Result of parsing one batch of Input objects.

    Every instance gets its own fresh containers (no sharing between instances).
    """
    # batch index -> error message, for items whose parsing raised
    errors: dict = field(default_factory=dict)
    # successfully parsed requests, in the order they were parsed
    requests: List[Request] = field(default_factory=list)
    # the batch of Input items the requests/errors refer to
    batch: List = field(default_factory=list)
def get_batch_start_id(batch, **kwargs):
    """
    Return the index of the first item in ``batch`` that still has to be parsed.

    Without rolling batch, the Python process only receives new requests, so
    parsing starts at 0. With rolling batch (``is_rolling_batch`` truthy), the
    leading ``len(rolling_batch.active_requests)`` items are requests already
    kept in the rolling batch cache. Parsing starts right after them. If no new
    requests arrived, ``len(batch)`` is returned so nothing gets parsed.

    :param batch: sequence of Input items.
    :param kwargs: reads ``is_rolling_batch`` and, when it is set, ``rolling_batch``.
    :return: start index in ``[0, len(batch)]``.
    """
    if not kwargs.get("is_rolling_batch"):
        return 0
    active_count = len(kwargs.get("rolling_batch").active_requests)
    # Never point past the end of the batch.
    return min(active_count, len(batch))
def parse_input_with_formatter(inputs: Input, **kwargs) -> ParsedInput:
    """
    Preprocessing function that extracts information from Input objects.

    Each not-yet-parsed item of the batch goes through the configured input
    formatter: ``configs.input_formatter`` if set, otherwise ``format_input``.
    The result is wrapped in a ``Request``. An item whose parsing raises is
    recorded in ``errors`` (keyed by its batch index) instead of aborting the
    whole batch.

    :param inputs: (Input) a batch of inputs, each corresponding to a new request
    :param kwargs: must contain ``configs``. ``is_rolling_batch`` is set here
        from ``configs.rolling_batch``. All kwargs (e.g. ``tokenizer``,
        ``rolling_batch``) are forwarded to the input formatter and to
        ``add_server_maintained_params``.
    :return parsed_input: object of data class that contains all parsed input details
    """
    errors = {}
    requests = []
    batch = inputs.get_batches()
    configs = kwargs.get("configs")
    kwargs["is_rolling_batch"] = is_rolling_batch_enabled(
        configs.rolling_batch)
    req_id_counter = get_req_id_counter(kwargs)
    start_batch_id = get_batch_start_id(batch, **kwargs)
    input_formatter_function = configs.input_formatter or format_input

    for i in range(start_batch_id, len(batch)):
        input_item = batch[i]
        client_request_id = input_item.get_property("requestId")
        try:
            # input formatter can be user written as well. We look for model.py and search for the decorator.
            request_input = input_formatter_function(input_item, **kwargs)

            # populate additional information in request_input
            request_input.request_id = (req_id_counter.next_id()
                                        if req_id_counter else i)
            request_input.client_request_id = client_request_id
            request_input.tokenizer = kwargs.get("tokenizer")
            request_input.tgi_compat = configs.tgi_compat
            # We add server maintained parameters
            add_server_maintained_params(request_input, input_item, **kwargs)
            requests.append(Request(request_input=request_input))
            logging.info("[RequestId=%s] parsed and scheduled for inference",
                         client_request_id)
        except Exception as e:  # pylint: disable=broad-except
            err_msg = "Input Parsing failed. Ensure that the request payload is valid. "
            # str(e) for KeyError only yields the name of the key, which isn't useful as a response to the client
            if isinstance(e, KeyError):
                err_msg += f"Invalid Request Property: {e}"
            else:
                err_msg += str(e)
            errors[i] = err_msg
            # Previously the closing "]" after the request id was missing here.
            logging.warning("[RequestId=%s] %s",
                            client_request_id,
                            err_msg,
                            exc_info=True)

    return ParsedInput(errors=errors, requests=requests, batch=batch)
def get_req_id_counter(kwargs):
    """
    Return the rolling batch's request-id counter, or None.

    :param kwargs: the keyword-argument dict itself (not unpacked). It reads
        ``is_rolling_batch`` and, when that is truthy, ``rolling_batch``.
    :return: ``kwargs["rolling_batch"].req_id_counter`` when rolling batch is
        enabled, otherwise None.
    """
    if not kwargs.get("is_rolling_batch"):
        return None
    return kwargs.get("rolling_batch").req_id_counter
def format_input(input_item: Input, **kwargs) -> RequestInput:
    """
    Default input formatter: build a TextInput from a raw Input item.

    The item is decoded according to its ``Content-Type`` property. The decoded
    payload is then used to fill in the prompt/parameters
    (``parse_text_inputs_params``) and the LoRA adapters (``parse_adapters``).

    :param input_item: one Input from the batch.
    :param kwargs: forwarded unchanged to both parsing helpers.
    :return: the populated TextInput.
    """
    # TODO: Decide whether it is a text input based on content-type
    text_input = TextInput()
    payload = decode(input_item, input_item.get_property("Content-Type"))
    # Order matters: adapters are matched against the parsed input_text.
    for populate in (parse_text_inputs_params, parse_adapters):
        populate(text_input, input_item, payload, **kwargs)
    return text_input
def parse_text_inputs_params(request_input: TextInput, input_item: Input,
                             input_map: Dict, **kwargs):
    """
    Split a decoded payload into prompt and parameters and store them on ``request_input``.

    The payload is interpreted, first match wins, as:
      1. a chat-completions request (vLLM-specific parser when the rolling batch
         reports ``use_vllm_chat_completions()``);
      2. a Bedrock-compatible 3P request when ``configs.bedrock_compat`` is truthy;
      3. an LMI rolling-batch request when ``is_rolling_batch`` is truthy;
      4. otherwise a plain ``{"inputs": ..., "parameters": ...}`` payload.

    Keys may be popped from ``input_map``: "inputs"/"parameters" in case 4, and
    "cached_prompt" in every case.

    :param request_input: receives ``input_text`` and ``parameters``.
    :param input_item: raw Input. Its ``X-Amzn-SageMaker-Forwarded-Api``
        property is forwarded to the 3P parser.
    :param input_map: decoded request payload.
    :param kwargs: reads ``tokenizer``, ``image_placeholder_token``, ``configs``,
        ``is_mistral_tokenizer``, ``is_rolling_batch`` and ``rolling_batch``.
    """
    forwarded_api = input_item.get_property("X-Amzn-SageMaker-Forwarded-Api")
    tok = kwargs.get("tokenizer")
    placeholder = kwargs.get("image_placeholder_token")
    cfg = kwargs.get("configs")
    mistral = kwargs.get("is_mistral_tokenizer", False)
    rolling = kwargs.get("is_rolling_batch", False)
    bedrock = cfg.bedrock_compat if cfg is not None else False

    if is_chat_completions_request(input_map):
        rb = kwargs.get("rolling_batch")
        use_vllm_chat = rb is not None and rb.use_vllm_chat_completions()
        if use_vllm_chat:
            # Imported lazily: we only get here in a vLLM context, which keeps
            # vLLM imports out of TRT-LLM deployments.
            from djl_python.chat_completions.vllm_chat_utils import parse_chat_completions_request_vllm
            text, gen_params = parse_chat_completions_request_vllm(
                input_map, rb, tok)
        else:
            text, gen_params = parse_chat_completions_request(
                input_map,
                kwargs.get("is_rolling_batch"),
                tok,
                image_token=placeholder,
                configs=cfg,
                is_mistral_tokenizer=mistral,
            )
    elif bedrock:
        text, gen_params = parse_3p_request(input_map,
                                            kwargs.get("is_rolling_batch"),
                                            tok, forwarded_api)
    elif rolling:
        text, gen_params = parse_lmi_default_request_rolling_batch(input_map)
    else:
        # Without an "inputs" key the whole payload dict is used as the input.
        text = input_map.pop("inputs", input_map)
        gen_params = input_map.pop("parameters", {})

    request_input.input_text = text
    request_input.parameters = gen_params

    # TODO: Instead of modifying user parameters, maintain this in server_parameters.
    # Kept in user parameters for backward compatibility.
    if "cached_prompt" in input_map:
        request_input.parameters["cached_prompt"] = input_map.pop(
            "cached_prompt")
def add_server_maintained_params(request_input: RequestInput,
                                 input_item: Input, **kwargs):
    """
    Populate ``request_input.server_parameters`` with server-side settings.

    ``server_parameters`` starts as a shallow copy of ``request_input.parameters``,
    so the user's parameter dict itself is not modified. Then:
      * ``seed`` is taken from ``input_item`` (as a string) when the request did
        not specify one and the item carries a "seed" entry;
      * ``output_formatter`` defaults to ``configs.output_formatter`` when the
        request did not specify one;
      * for the "json" and "sse" output formatters, ``request_input.tgi_compat``
        is set from ``configs.tgi_compat``.

    :param request_input: parsed request; ``parameters`` must already be set.
    :param input_item: the Input object the request was parsed from.
    :param kwargs: ``configs`` is read when the output formatter needs a
        default or is "json"/"sse".
    """
    configs = kwargs.get("configs")
    request_input.server_parameters = request_input.parameters.copy()
    server_params = request_input.server_parameters

    # A seed given in the request wins; otherwise use one supplied with the input, if any.
    if "seed" not in server_params and input_item.contains_key("seed"):
        server_params["seed"] = input_item.get_as_string(key="seed")

    if "output_formatter" not in server_params:
        server_params["output_formatter"] = configs.output_formatter

    if server_params["output_formatter"] in ("json", "sse"):
        request_input.tgi_compat = configs.tgi_compat
def parse_adapters(request_input: TextInput, input_item: Input,
                   input_map: Dict, **kwargs):
    """
    Resolve the LoRA adapters requested for this input and attach them to ``request_input``.

    Does nothing unless ``configs.enable_lora`` is truthy. When adapters are
    found, their count must match the number of inputs: a list ``input_text``
    counts each element, anything else counts as one input. A single adapter
    is stored unwrapped; several are stored as a list.

    :param request_input: parsed request; its ``input_text`` must already be set.
    :param input_item: raw Input (adapters may come from its content or properties).
    :param input_map: decoded payload ("adapters" may be popped from it).
    :param kwargs: reads ``configs`` and ``adapter_registry``.
    :raises ValueError: if the number of adapters differs from the number of
        inputs, or (from ``_fetch_adapters_from_input``) an adapter is not registered.
    """
    configs = kwargs.get("configs")
    if not getattr(configs, "enable_lora", False):
        return

    adapter_registry = kwargs.get("adapter_registry")
    input_text = request_input.input_text
    input_len = len(input_text) if isinstance(input_text, list) else 1
    adapters_data = _fetch_adapters_from_input(input_map, input_item,
                                               adapter_registry)
    if not adapters_data:
        return

    if input_len != len(adapters_data):
        raise ValueError(
            "Number of adapters is not equal to the number of inputs")
    request_input.adapters = (adapters_data[0]
                              if len(adapters_data) == 1 else adapters_data)
def _fetch_adapters_from_input(input_map: dict, input_item: Input,
                               adapter_registry):
    """
    Collect the adapter name(s) requested for one input and look them up in the registry.

    Sources are checked in order, and a later source replaces an earlier one:
      1. the "adapters" key of the decoded payload (always popped from ``input_map``);
      2. the "adapter" entry of the Input content (possible in the workflow approach);
      3. the ``X-Amzn-SageMaker-Adapter-Identifier`` Input property; the optional
         ``X-Amzn-SageMaker-Adapter-Alias`` property is then used in messages.

    :param input_map: decoded request payload.
    :param input_item: raw Input object.
    :param adapter_registry: mapping of registered adapter name -> adapter details.
    :return: one ``adapter_registry.get(name)`` result per requested name. Falsy
        names are not validated and map to ``adapter_registry.get(name)``.
    :raises ValueError: if a non-empty adapter name is not in ``adapter_registry``.
    """
    requested = input_map.pop("adapters", []) if "adapters" in input_map else []

    if input_item.contains_key("adapter"):
        requested = input_item.get_as_string("adapter")

    alias = None
    if SAGEMAKER_ADAPTER_IDENTIFIER_HEADER in input_item.get_properties():
        requested = input_item.get_property(SAGEMAKER_ADAPTER_IDENTIFIER_HEADER)
        alias = input_item.get_property("X-Amzn-SageMaker-Adapter-Alias")
        logging.debug(f"Using adapter {alias or requested}")

    names = requested if isinstance(requested, list) else [requested]
    resolved = []
    for name in names:
        if name and name not in adapter_registry:
            raise ValueError(f"Adapter {alias or name} is not registered")
        resolved.append(adapter_registry.get(name))
    return resolved
def parse_lmi_default_request_rolling_batch(payload):
    """
    Validate an LMI-schema payload for rolling batch and split it into prompt and parameters.

    Accepted shape: ``{"inputs": <non-empty str>, "parameters": <dict>, "stream": <value>}``,
    where "parameters" and "stream" are optional.

    :param payload: decoded request body.
    :return: ``(inputs, parameters)``. ``parameters`` is the payload's own dict
        (or a new empty dict), with "stream" set in place to the top-level
        "stream" value, defaulting to False.
    :raises ValueError: if the payload is not a dict, "inputs" is missing/None,
        not a string or empty, or "parameters" is not a dict.
    """
    # A non-dict payload and a missing "inputs" field produce the same error.
    prompt = payload.get("inputs") if isinstance(payload, dict) else None
    if prompt is None:
        raise ValueError(
            f"Invalid request payload. Request payload should be a json object specifying the 'inputs' field. Received payload {payload}"
        )
    if not isinstance(prompt, str):
        raise ValueError(
            f"Invalid request payload. The 'inputs' field must be a string. Received type {type(prompt)}"
        )
    if not prompt:
        raise ValueError(
            f"Invalid request payload. The 'inputs' field does not contain any content. Received payload {payload}"
        )

    gen_params = payload.get("parameters", {})
    if not isinstance(gen_params, dict):
        raise ValueError(
            f"Invalid request payload. 'parameters' must be provided as an object of key-value pairs. Received payload {payload}"
        )
    gen_params["stream"] = payload.get("stream", False)
    return prompt, gen_params