Skip to content

Commit d3561e5

Browse files
ksuma2109, lxning, xyang16, smouaa
authored
Recent changes required for LMIv17rc7 (#2952)
Co-authored-by: lxning <23464292+lxning@users.noreply.github.com> Co-authored-by: Xin Yang <105740670+xyang16@users.noreply.github.com> Co-authored-by: sheng moua <127175097+smouaa@users.noreply.github.com>
1 parent 603d086 commit d3561e5

File tree

16 files changed

+1172
-148
lines changed

16 files changed

+1172
-148
lines changed

.github/workflows/docker_publish.yml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,25 @@ jobs:
7878
--fail \
7979
| jq '.token' | tr -d '"' )
8080
./start_instance.sh action_cpu $token djl-serving
81+
- name: Create new Graviton instance
82+
id: create_aarch64
83+
run: |
84+
cd /home/ubuntu/djl_benchmark_script/scripts
85+
token=$( curl -X POST -H "Authorization: token ${{ secrets.ACTION_RUNNER_PERSONAL_TOKEN }}" \
86+
https://api.github.com/repos/deepjavalibrary/djl-serving/actions/runners/registration-token \
87+
--fail \
88+
| jq '.token' | tr -d '"' )
89+
./start_instance.sh action_graviton $token djl-serving
8190
outputs:
8291
cpu_instance_id1: ${{ steps.create_cpu_1.outputs.action_cpu_instance_id }}
8392
cpu_instance_id2: ${{ steps.create_cpu_2.outputs.action_cpu_instance_id }}
8493
cpu_instance_id3: ${{ steps.create_cpu_3.outputs.action_cpu_instance_id }}
94+
aarch64_instance_id: ${{ steps.create_aarch64.outputs.action_graviton_instance_id }}
8595

8696
docker-sync:
8797
runs-on:
8898
- self-hosted
89-
- cpu
99+
- ${{ matrix.arch != 'aarch64' && 'cpu' || 'aarch64' }}
90100
- RUN_ID-${{ github.run_id }}
91101
- RUN_NUMBER-${{ github.run_number }}
92102
- SHA-${{ github.sha }}
@@ -154,3 +164,5 @@ jobs:
154164
./stop_instance.sh $instance_id
155165
instance_id=${{ needs.create-runners.outputs.cpu_instance_id3 }}
156166
./stop_instance.sh $instance_id
167+
instance_id=${{ needs.create-runners.outputs.aarch64_instance_id }}
168+
./stop_instance.sh $instance_id

.github/workflows/integration.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,9 +161,9 @@ jobs:
161161
- test: TestGpu_g6
162162
instance: g6
163163
failure-prefix: gpu
164-
- test: TestAarch64
165-
instance: aarch64
166-
failure-prefix: aarch64
164+
# - test: TestAarch64
165+
# instance: aarch64
166+
# failure-prefix: aarch64
167167
# - test: TestHfHandler_g6
168168
# instance: g6
169169
# failure-prefix: lmi
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
import json
2+
from typing import Callable, Union, Tuple, List
3+
from tensorrt_llm.serve.openai_protocol import (
4+
ErrorResponse,
5+
ChatCompletionRequest,
6+
ChatCompletionResponse,
7+
CompletionResponse,
8+
CompletionRequest,
9+
CompletionLogProbs,
10+
)
11+
from tensorrt_llm.llmapi.tokenizer import TokenizerBase
12+
from djl_python.async_utils import create_non_stream_output
13+
from djl_python.outputs import Output
14+
15+
16+
class ProcessedRequest:
    """Bundle of a translated TRT-LLM request plus the callables needed to run
    and format it.

    Holds the request object, the inference entry point, both output
    formatters, and the per-request formatting flags so the async handler can
    carry everything through the pipeline as one unit.
    """

    def __init__(
        self,
        trtllm_request: Union[CompletionRequest, ChatCompletionRequest],
        inference_invoker: Callable,
        non_stream_output_formatter: Callable,
        stream_output_formatter: Callable,
        accumulate_chunks: bool,
        include_prompt: bool,
    ):
        self.trtllm_request = trtllm_request
        self.inference_invoker = inference_invoker
        # Both formatters are retained even for streaming requests: a failure
        # before inference starts yields an ErrorResponse object rather than
        # an AsyncGenerator, and that error must go through the non-stream
        # formatting path.
        self.non_stream_output_formatter = non_stream_output_formatter
        self.stream_output_formatter = stream_output_formatter
        self.accumulate_chunks = accumulate_chunks
        self.include_prompt = include_prompt
        # No LoRA adapter is attached at construction time.
        self.lora_request = None
37+
38+
39+
def convert_lmi_schema_to_completion_request(
        payload: dict, ) -> Tuple[CompletionRequest, bool, bool]:
    """Translate an LMI/TGI-style payload into a TRT-LLM CompletionRequest.

    Mutates *payload* (and its nested "parameters" dict) by popping every
    field it translates; whatever is left in "parameters" is passed through
    to the request untouched.

    Returns a tuple of (request, include_details_in_response, include_prompt).
    """
    parameters = payload.get("parameters", {})

    completion_dict = {
        "prompt": payload.pop("inputs"),
        "model": payload.pop("model"),
        "stream": payload.pop("stream", False),
    }
    # LMI parameter name -> completions API name, with LMI defaults.
    for lmi_key, trt_key, default in (
        ("max_new_tokens", "max_tokens", 30),
        ("return_full_text", "echo", False),
        ("truncate", "truncate_prompt_tokens", None),
        ("top_n_tokens", "n", 1),
        ("ignore_eos_token", "ignore_eos", False),
    ):
        completion_dict[trt_key] = parameters.pop(lmi_key, default)

    # TRTLLM does not support logprobs in the completions API; if the caller
    # supplies them we rely on TRTLLM's own validation error.
    include_prompt = False
    if completion_dict["stream"]:
        completion_dict["stream_options"] = {
            "include_usage": True,
            "continuous_usage_stats": True
        }
        # For streaming, echo is emulated by prepending the prompt client-side.
        include_prompt = completion_dict.pop("echo", False)
    include_details_in_response = bool(parameters.pop("details", False))
    if parameters.pop("decoder_input_details", False):
        completion_dict["return_context_logits"] = 1

    do_sample = parameters.pop("do_sample", None)
    # None means "let the other sampling params decide"; an explicit falsy
    # value disables sampling by forcing greedy decoding.
    if not (do_sample is None or do_sample):
        parameters["temperature"] = 0.0

    completion_dict.update(parameters)

    return CompletionRequest(
        **completion_dict), include_details_in_response, include_prompt
76+
77+
78+
def convert_completion_response_to_lmi_schema(
        response: CompletionResponse,
        request: CompletionRequest = None,
        include_details: bool = False,
        tokenizer: TokenizerBase = None) -> Output:
    """Render a non-streaming TRT-LLM completion response as an LMI Output.

    Only the first choice is surfaced. When *include_details* is set, a
    TGI-style "details" section is attached.
    NOTE(review): with include_details=True a non-None *request* is required
    (request.seed is read) — confirm callers always pass it.
    """
    first_choice = response.choices[0]
    lmi_response = {"generated_text": first_choice.text}
    if include_details:
        lmi_response["details"] = {
            "finish_reason": first_choice.stop_reason,
            "generated_tokens": response.usage.completion_tokens,
            "seed": request.seed,
        }
    return create_non_stream_output(lmi_response)
95+
96+
97+
def convert_completion_chunk_response_to_lmi_schema(
    chunk: str,
    include_details: bool = False,
    history: List[str] = None,
    request: CompletionRequest = None,
    include_prompt: bool = False,
    tokenizer: TokenizerBase = None,
    **_,
) -> Tuple[str, bool, List[str]]:
    """Convert one TRT-LLM SSE completion chunk into a TGI-style chunk.

    Parameters:
        chunk: raw SSE line from TRT-LLM, of the form 'data: {...}'.
        include_details: attach a TGI "details" section on the final chunk.
        history: accumulator of token texts emitted so far; threaded through
            the return value so the caller can pass it back on the next chunk.
        request: the originating request (used for prompt echo and seed).
        include_prompt: prepend the prompt to generated_text on the final
            chunk (client-side echo emulation for streaming).
        tokenizer: unused here; accepted for formatter-signature parity.

    Returns:
        (serialized_chunk_or_empty_string, is_last, history)
    """
    # TRTLLM returns chunks in string format, and the conversion process to TGI
    # currently converts the string to an object, and then the object back to a
    # string. It's much easier to work with the object instead of manipulating
    # the string, but inefficient.
    if history is None:
        # Defensive default: the parameter default is None (to avoid a shared
        # mutable default), so initialize here rather than crash on .append().
        history = []
    trimmed_chunk = chunk[6:].strip()
    if trimmed_chunk == '[DONE]':
        data = ""
        return data, True, history

    trt_completion_chunk = json.loads(trimmed_chunk)
    if "error" in trt_completion_chunk:
        # Propagate the error payload verbatim and terminate the stream.
        return json.dumps(trt_completion_chunk,
                          ensure_ascii=False), True, history

    if len(trt_completion_chunk["choices"]) == 0:
        # penultimate chunk (usage-only, no choices): emit nothing.
        return "", False, history
    choice = trt_completion_chunk["choices"][0]
    index = choice["index"]
    token_text = choice["text"]
    history.append(token_text)
    finish_reason = choice["finish_reason"]
    stop_reason = choice["stop_reason"]
    usage = trt_completion_chunk["usage"]

    # TODO: TokenId and LogProb here
    token = {
        "id": None,
        "text": token_text,
        "logprob": None,
    }
    tgi_chunk = {
        "index": index,
        "token": token,
        "generated_text": None,
        "details": None,
    }
    generation_finished = finish_reason is not None or stop_reason is not None
    if generation_finished:
        generated_text = ''.join(history)
        if include_prompt:
            generated_text = request.prompt + generated_text
        tgi_chunk["generated_text"] = generated_text
        if include_details:
            details = {
                "finish_reason": finish_reason or stop_reason,
                "seed": request.seed,
                # +1 presumably accounts for the final token not yet counted
                # in this chunk's usage — TODO confirm against TRT-LLM usage
                # accounting.
                "generated_tokens": usage["completion_tokens"] + 1,
                "input_length": usage["prompt_tokens"],
            }
            tgi_chunk["details"] = details
    json_str = json.dumps(tgi_chunk, ensure_ascii=False)
    return json_str, False, history
158+
159+
160+
def lmi_with_details_non_stream_output_formatter(
    response: CompletionResponse,
    request: CompletionRequest = None,
    tokenizer: TokenizerBase = None,
) -> Output:
    """Non-stream LMI formatter that also emits the TGI "details" section."""
    return convert_completion_response_to_lmi_schema(
        response,
        include_details=True,
        request=request,
        tokenizer=tokenizer,
    )
169+
170+
171+
def lmi_non_stream_output_formatter(
    response: CompletionResponse,
    request: CompletionRequest = None,
    tokenizer: TokenizerBase = None,
) -> Output:
    """Non-stream LMI formatter without the TGI "details" section."""
    return convert_completion_response_to_lmi_schema(
        response,
        include_details=False,
        request=request,
        tokenizer=tokenizer,
    )
180+
181+
182+
def lmi_with_details_stream_output_formatter(
    chunk: str,
    **kwargs,
) -> Tuple[str, bool, List[str]]:
    """Streaming LMI formatter with the TGI "details" section enabled."""
    return convert_completion_chunk_response_to_lmi_schema(chunk,
                                                           include_details=True,
                                                           **kwargs)
188+
189+
190+
def lmi_stream_output_formatter(
    chunk: str,
    **kwargs,
) -> Tuple[str, bool, List[str]]:
    """Streaming LMI formatter without the TGI "details" section."""
    return convert_completion_chunk_response_to_lmi_schema(chunk, **kwargs)
195+
196+
197+
def trtllm_non_stream_output_formatter(
    response: Union[ErrorResponse, ChatCompletionResponse, CompletionResponse],
    **_,
) -> Output:
    """Serialize a TRT-LLM response — success or error — into an Output."""
    if not isinstance(response, ErrorResponse):
        # Happy path: pass the pydantic model through as its JSON form.
        return create_non_stream_output(response.model_dump_json())
    return create_non_stream_output("",
                                    error=response.message,
                                    code=response.code)
207+
208+
209+
def trtllm_stream_output_formatter(
    chunk: str,
    **_,
) -> Tuple[str, bool]:
    """Pass one TRT-LLM SSE chunk through, signalling the end of the stream.

    Chunks arrive in SSE format, 'data: {...}'; the 6-character prefix is
    stripped. A '[DONE]' sentinel yields an empty payload with last=True.
    """
    payload = chunk[6:].strip()
    if payload == '[DONE]':
        return "", True
    return payload, False

0 commit comments

Comments
 (0)