 from fastapi.responses import JSONResponse, Response, StreamingResponse
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import colossalai
 from colossalai.inference.config import InferenceConfig
 from colossalai.inference.server.chat_service import ChatServing
 from colossalai.inference.server.completion_service import CompletionServing
 from colossalai.inference.server.utils import id_generator
+from colossalai.inference.utils import find_available_ports
 
 from colossalai.inference.core.async_engine import AsyncInferenceEngine, InferenceEngine  # noqa
 
@@ -54,8 +56,9 @@ async def generate(request: Request) -> Response:
     """
     request_dict = await request.json()
     prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", "false").lower()
-
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     request_id = id_generator()
     generation_config = get_generation_config(request_dict)
     results = engine.generate(request_id, prompt, generation_config=generation_config)
@@ -66,7 +69,7 @@ def stream_results():
             ret = {"text": request_output[len(prompt) :]}
             yield (json.dumps(ret) + "\0").encode("utf-8")
 
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(stream_results())
 
     # Non-streaming case
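Note on the change above: a JSON request body may carry "stream" either as a real boolean (true/false) or as a string ("true"), and calling .lower() on a bool raises AttributeError, so the isinstance guard accepts both forms. A minimal client sketch, assuming the handler above is mounted at /generate on the default --host/--port (127.0.0.1:8000) and that the third-party "requests" package is installed:

    import json

    import requests

    for stream_flag in (True, "true"):  # JSON boolean and string form are now both accepted
        resp = requests.post(
            "http://127.0.0.1:8000/generate",  # assumption: route path and default host/port
            json={"prompt": "Hello, my name is", "stream": stream_flag},
            stream=True,
        )
        # Streamed chunks are NUL-terminated JSON objects, matching stream_results() above.
        for chunk in resp.iter_lines(delimiter=b"\0"):
            if chunk:
                print(json.loads(chunk)["text"])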
@@ -86,12 +89,14 @@ def stream_results():
 @app.post("/completion")
 async def create_completion(request: Request):
     request_dict = await request.json()
-    stream = request_dict.pop("stream", "false").lower()
+    stream = request_dict.pop("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     result = await completion_serving.create_completion(request, generation_config)
 
     ret = {"request_id": result.request_id, "text": result.output}
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=json.dumps(ret) + "\0", media_type="text/event-stream")
     else:
         return JSONResponse(content=ret)
@@ -101,10 +106,12 @@ async def create_completion(request: Request):
 async def create_chat(request: Request):
     request_dict = await request.json()
 
-    stream = request_dict.get("stream", "false").lower()
+    stream = request_dict.get("stream", "false")
+    if isinstance(stream, str):
+        stream = stream.lower()
     generation_config = get_generation_config(request_dict)
     message = await chat_serving.create_chat(request, generation_config)
-    if stream == "true":
+    if stream == "true" or stream == True:
         return StreamingResponse(content=message, media_type="text/event-stream")
     else:
         ret = {"role": message.role, "text": message.content}
@@ -115,27 +122,29 @@ def get_generation_config(request):
     generation_config = async_engine.engine.generation_config
     for arg in request:
         if hasattr(generation_config, arg):
-            generation_config[arg] = request[arg]
+            setattr(generation_config, arg, request[arg])
     return generation_config
 
 
 def add_engine_config(parser):
-    parser.add_argument("--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use")
-
     parser.add_argument(
-        "--max-model-len",
-        type=int,
-        default=None,
-        help="model context length. If unspecified, " "will be automatically derived from the model.",
+        "-m", "--model", type=str, default="llama2-7b", help="name or path of the huggingface model to use"
     )
-    # Parallel arguments
-    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1, help="number of tensor parallel replicas")
+    # Parallel arguments not supported now
 
     # KV cache arguments
     parser.add_argument("--block-size", type=int, default=16, choices=[8, 16, 32], help="token block size")
 
     parser.add_argument("--max_batch_size", type=int, default=8, help="maximum number of batch size")
 
+    parser.add_argument("-i", "--max_input_len", type=int, default=128, help="max input length")
+
+    parser.add_argument("-o", "--max_output_len", type=int, default=128, help="max output length")
+
+    parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
+
+    parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
+
     # generation arguments
     parser.add_argument(
         "--prompt_template",
@@ -150,7 +159,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description="Colossal-Inference API server.")
 
     parser.add_argument("--host", type=str, default="127.0.0.1")
-    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--port", type=int, default=8000, help="port of FastAPI server.")
     parser.add_argument("--ssl-keyfile", type=str, default=None)
     parser.add_argument("--ssl-certfile", type=str, default=None)
     parser.add_argument(
@@ -164,6 +173,7 @@ def parse_args():
         "specified, the model name will be the same as "
         "the huggingface name.",
     )
+
     parser.add_argument(
         "--chat-template",
         type=str,
@@ -184,13 +194,21 @@ def parse_args():
 if __name__ == "__main__":
     args = parse_args()
     inference_config = InferenceConfig.from_dict(vars(args))
-    model = AutoModelForCausalLM.from_pretrained(args.model)
     tokenizer = AutoTokenizer.from_pretrained(args.model)
+    colossalai_backend_port = find_available_ports(1)[0]
+    colossalai.launch(
+        rank=0,
+        world_size=1,
+        host=args.host,
+        port=colossalai_backend_port,
+        backend="nccl",
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model)
     async_engine = AsyncInferenceEngine(
-        start_engine_loop=True, model=model, tokenizer=tokenizer, inference_config=inference_config
+        start_engine_loop=True, model_or_path=model, tokenizer=tokenizer, inference_config=inference_config
     )
     engine = async_engine.engine
-    completion_serving = CompletionServing(async_engine, served_model=model.__class__.__name__)
+    completion_serving = CompletionServing(async_engine, model.__class__.__name__)
     chat_serving = ChatServing(
         async_engine,
         served_model=model.__class__.__name__,
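The launch block added above initializes a one-process distributed context before the model is loaded; the backend rendezvous port is chosen automatically with find_available_ports and is independent of the FastAPI --port (default 8000). A standalone sketch of the same bootstrap, with assumptions noted in the comments:

    # Assumptions: a CUDA-capable host (the backend is "nccl") and
    # find_available_ports(n) returning n free port numbers, as imported above.
    import colossalai
    from colossalai.inference.utils import find_available_ports

    backend_port = find_available_ports(1)[0]  # separate from the FastAPI --port
    colossalai.launch(
        rank=0,        # single process only; tensor parallelism is not wired up in this server
        world_size=1,
        host="127.0.0.1",
        port=backend_port,
        backend="nccl",
    )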