Commit f47f2fb

[Inference] Fix API server, test and example (#5712)
* fix api server
* fix generation config
* fix api server
* fix comments
* fix infer hanging bug
* resolve comments, change backend to free port
1 parent 74c4792 commit f47f2fb

File tree

colossalai/inference/core/async_engine.py
colossalai/inference/core/engine.py
colossalai/inference/server/api_server.py
colossalai/inference/server/completion_service.py
examples/inference/client/run_locust.sh

5 files changed: +73 -32 lines changed

colossalai/inference/core/async_engine.py

Lines changed: 29 additions & 6 deletions
@@ -4,6 +4,7 @@
 from typing import AsyncIterator, Dict, Iterable, List, Optional, Set, Tuple, Type

 from colossalai.inference.core.engine import InferenceEngine
+from colossalai.inference.sampler import search_tokens

 # CLI logger
 logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -168,26 +169,44 @@ async def async_step(self) -> List[str]:
         generated results.
         """
         batch = self.request_handler.schedule()
+        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)
+
         loop = asyncio.get_running_loop()

+        if input_meta_data.use_cuda_graph:
+            model_executable = self.graph_runners[input_meta_data.batch_size]
+        else:
+            model_executable = self.model
+
         # Use run_in_executor to asyncally run the sync method model.forward().
         logits = await loop.run_in_executor(
             None,
-            self.model,
-            batch,
+            model_executable,
+            input_token_ids,
+            output_tensor,
+            input_meta_data,
             self.k_cache,
             self.v_cache,
         )

         if self.inference_config.pad_input:
             logits = logits[:, -1, :]
-        self.request_handler.search_tokens(self.generation_config, logits)
+        next_tokens = search_tokens(
+            self.generation_config, logits, input_meta_data.is_prompts, batch_token_ids=input_meta_data.batch_token_ids
+        )

+        self.request_handler.append_next_tokens(next_tokens)
         finished_sequences = self.request_handler.update()
+
         for sequence in finished_sequences:
             sequence.output = self.tokenizer.decode(sequence.output_token_id)

-        return finished_sequences, self.request_handler.total_requests_in_batch_bucket() > 0
+        return finished_sequences, not self.request_handler.running_list.is_empty()
+
+    def add_single_request(self, request_id: int, prompt: str, prompt_token_ids, generation_config=None):
+        prompts = [prompt]
+        gen_config_dict = generation_config.to_dict() if generation_config is not None else {}
+        self.add_request(request_ids=request_id, prompts=prompts, prompts_token_ids=prompt_token_ids, **gen_config_dict)


 class AsyncInferenceEngine:
@@ -240,7 +259,6 @@ async def step(self):
         for new_request in new_requests:
             self.engine.add_single_request(**new_request)
         newly_finished_seqs, has_running_requests = await self.engine.async_step()
-
         for seq in newly_finished_seqs:
             self._request_tracer.process_finished_request(seq)

@@ -273,6 +291,7 @@ async def add_request(
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> RequstStream:
         """
         Add a request to the background tracker(waiting queue), start the background loop if needed.
@@ -286,6 +305,7 @@ async def add_request(
             request_id,
             prompt=prompt,
             prompt_token_ids=prompt_token_ids,
+            generation_config=generation_config,
         )
         return stream

@@ -294,13 +314,16 @@ async def generate(
         request_id: int,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]] = None,
+        generation_config=None,
     ) -> AsyncIterator[str]:
         """
         Generate output from a request. It receives the request from http server, adds it into the
         waitting queue of Async Engine and streams the output sequence.
         """
         try:
-            stream = await self.add_request(request_id, prompt, prompt_token_ids=prompt_token_ids)
+            stream = await self.add_request(
+                request_id, prompt, prompt_token_ids=prompt_token_ids, generation_config=generation_config
+            )
             return await stream.get_result()

         except (Exception, asyncio.CancelledError) as e:
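
With this change, generate() threads a per-request generation_config down through add_request() and add_single_request(), where to_dict() is unpacked into the engine's kwargs. A minimal caller sketch, not part of the commit: it assumes an already-constructed AsyncInferenceEngine and a config object exposing to_dict() (e.g. transformers.GenerationConfig); the helper name run_one_prompt and the sampling values are illustrative.

from transformers import GenerationConfig

from colossalai.inference.core.async_engine import AsyncInferenceEngine


async def run_one_prompt(async_engine: AsyncInferenceEngine, prompt: str):
    # Per-request overrides travel generate() -> add_request() -> add_single_request().
    gen_config = GenerationConfig(max_new_tokens=64, do_sample=False)
    return await async_engine.generate(request_id=0, prompt=prompt, generation_config=gen_config)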

colossalai/inference/core/engine.py

Lines changed: 1 addition & 2 deletions
@@ -154,7 +154,6 @@ def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy
             else:
                 model_type = "nopadding_" + self.model_config.model_type
             model_policy = model_policy_map[model_type]()
-
         pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
         tp_group = pg_mesh.get_group_along_axis(TP_AXIS)

@@ -589,7 +588,7 @@ def format_prompt(self, prompts: Union[List[str], str]) -> Union[List[str], str]
     def add_request(
         self,
         request_ids: Union[List[int], int] = None,
-        prompts: List[str] = None,
+        prompts: Union[List[str], str] = None,
         prompts_token_ids: Union[List[int], torch.Tensor, np.ndarray] = None,
         **kwargs,
     ) -> None:
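
The widened annotation documents that add_request now accepts a single prompt string as well as a list. A hedged sketch, assuming an InferenceEngine instance named engine already exists; the prompts are illustrative.

# Single prompt string, per the new Union[List[str], str] annotation.
engine.add_request(request_ids=0, prompts="What is the capital of France?")
# Equivalent list form with explicit ids.
engine.add_request(request_ids=[1], prompts=["What is the capital of France?"])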

colossalai/inference/server/api_server.py

Lines changed: 38 additions & 20 deletions
@@ -20,10 +20,12 @@
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer

+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports

 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa

@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
-   stream = request_dict.pop("stream", "false").lower()
-
+   stream = request_dict.pop("stream", "false")
+   if isinstance(stream, str):
+       stream = stream.lower()
    request_id = id_generator()
    generation_config = get_generation_config(request_dict)
    results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ def stream_results():
            ret = {"text": request_output[len(prompt) :]}
            yield (json.dumps(ret) + "\0").encode("utf-8")

-   if stream == "true":
+   if stream == "true" or stream == True:
        return StreamingResponse(stream_results())

    # Non-streaming case
@@ -86,12 +89,14 @@ def stream_results():
 @app.post("/completion")
 async def create_completion(request: Request):
    request_dict = await request.json()
-   stream = request_dict.pop("stream", "false").lower()
+   stream = request_dict.pop("stream", "false")
+   if isinstance(stream, str):
+       stream = stream.lower()
    generation_config = get_generation_config(request_dict)
    result = await completion_serving.create_completion(request, generation_config)

    ret = {"request_id": result.request_id, "text": result.output}
-   if stream == "true":
+   if stream == "true" or stream == True:
        return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
    else:
        return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
    request_dict = await request.json()

-   stream = request_dict.get("stream", "false").lower()
+   stream = request_dict.get("stream", "false")
+   if isinstance(stream, str):
+       stream = stream.lower()
    generation_config = get_generation_config(request_dict)
    message = await chat_serving.create_chat(request, generation_config)
-   if stream == "true":
+   if stream == "true" or stream == True:
        return StreamingResponse(content=message, media_type="text/event-stream")
    else:
        ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
    generation_config = async_engine.engine.generation_config
    for arg in request:
        if hasattr(generation_config, arg):
-           generation_config[arg] = request[arg]
+           setattr(generation_config, arg, request[arg])
    return generation_config


 def add_engine_config(parser):
-   parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
-
    parser.add_argument(
-       "--max-model-len",
-       type=int,
-       default=None,
-       help="model context length. If unspecified, " "will be automatically derived from the model.",
+       "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
    )
-   # Parallel arguments
-   parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+   # Parallel arguments not supported now

    # KV cache arguments
    parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")

    parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")

+   parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
+   parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
+   parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
+   parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
    # generation arguments
    parser.add_argument(
        "--prompt_template",
@@ -150,7 +159,7 @@ def parse_args():
    parser = argparse.ArgumentParser(description="Colossal-Inference API server.")

    parser.add_argument("--host", type=str, default="127.0.0.1")
-   parser.add_argument("--port", type=int, default=8000)
+   parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
        "specified, the model name will be the same as "
        "the huggingface name.",
    )
+
    parser.add_argument(
        "--chat-template",
        type=str,
@@ -184,13 +194,21 @@ def parse_args():
 if __name__ == "__main__":
    args = parse_args()
    inference_config = InferenceConfig.from_dict(vars(args))
-   model = AutoModelForCausalLM.from_pretrained(args.model)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
+   colossalai_backend_port = find_available_ports(1)[0]
+   colossalai.launch(
+       rank=0,
+       world_size=1,
+       host=args.host,
+       port=colossalai_backend_port,
+       backend="nccl",
+   )
+   model = AutoModelForCausalLM.from_pretrained(args.model)
    async_engine = AsyncInferenceEngine(
-       start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+       start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
    )
    engine = async_engine.engine
-   completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+   completion_serving = CompletionServing(async_engine, model.__class__.__name__)
    chat_serving = ChatServing(
        async_engine,
        served_model=model.__class__.__name__,
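
For context on the stream handling above, a hedged client-side sketch: the /generate handler reads "prompt" and "stream" from the JSON body, and get_generation_config() copies any other field that matches a generation_config attribute via setattr. The URL and port assume the parse_args() defaults; max_new_tokens is an assumption about the underlying generation_config.

import requests

payload = {
    "prompt": "Introduce some landmarks in Beijing",
    "stream": False,        # a JSON boolean now works; the strings "true"/"false" are still accepted
    "max_new_tokens": 64,   # assumed generation_config attribute, applied via setattr
}
resp = requests.post("http://127.0.0.1:8000/generate", json=payload)
print(resp.text)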

colossalai/inference/server/completion_service.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ async def create_completion(self, request, generation_config):

        # it is not a intuitive way
        self.engine.engine.generation_config = generation_config
-       result_generator = self.engine.generate(request_id, prompt=prompt)
+       result_generator = self.engine.generate(request_id, prompt=prompt, generation_config=generation_config)

        if await request.is_disconnected():
            # Abort the request if the client disconnects.

examples/inference/client/run_locust.sh

Lines changed: 4 additions & 3 deletions
@@ -6,8 +6,9 @@
 model_path=${1:-"lmsys/vicuna-7b-v1.3"}
 chat_template="{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}"
 echo "Model Path: $model_path"
+echo "Chat Tempelate" "${chat_template}"
 echo "Starting server..."
-python -m colossalai.inference.server.api_server --model $model_path --chat-template $chat_template &
+python -m colossalai.inference.server.api_server --model $model_path --chat-template "${chat_template}" &
 SERVER_PID=$!

 # waiting time
@@ -17,9 +18,9 @@ sleep 60
 echo "Starting Locust..."
 echo "The test will automatically begin, you can turn to http://0.0.0.0:8089 for more information."
 echo "Test completion api first"
-locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-generation --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 echo "Test chat api"
-locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 100 --stop-timeout 10
+locust -f locustfile.py -t 300 --tags online-chat --host http://127.0.0.1:8000 --autostart --users 300 --stop-timeout 10
 # kill Server
 echo "Stopping server..."
 kill $SERVER_PID
