
Commit 77aa4df

Revert "add best_of and use_beam_search for completions interface" (#2401)
1 parent 13f40b3 commit 77aa4df

4 files changed: +26 additions, -79 deletions

fastchat/protocol/api_protocol.py

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ class CompletionResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseChoice]
-    usage: Union[UsageInfo, List[UsageInfo]]
+    usage: UsageInfo


 class CompletionResponseStreamChoice(BaseModel):

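For orientation, here is a minimal sketch of the reverted response model. Only the fields visible in this hunk are included, and the shape of UsageInfo is assumed from the token-count keys that appear elsewhere in this commit; this is not the full FastChat definition.

    # Partial illustrative sketch, not the complete FastChat model.
    import time
    from typing import List
    from pydantic import BaseModel, Field

    class UsageInfo(BaseModel):  # assumed shape (see the usage keys later in this commit)
        prompt_tokens: int = 0
        total_tokens: int = 0
        completion_tokens: int = 0

    class CompletionResponseChoice(BaseModel):  # simplified for the sketch
        index: int
        text: str

    class CompletionResponse(BaseModel):
        created: int = Field(default_factory=lambda: int(time.time()))
        model: str
        choices: List[CompletionResponseChoice]
        usage: UsageInfo  # back to a single object, no List[UsageInfo] variant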
fastchat/protocol/openai_api_protocol.py

Lines changed: 1 addition & 3 deletions
@@ -151,13 +151,11 @@ class CompletionRequest(BaseModel):
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
     user: Optional[str] = None
-    use_beam_search: Optional[bool] = False
-    best_of: Optional[int] = None


 class CompletionResponseChoice(BaseModel):
     index: int
-    text: Union[str, List[str]]
+    text: str
     logprobs: Optional[int] = None
     finish_reason: Optional[Literal["stop", "length"]] = None

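As a usage illustration only (the endpoint URL and model name below are assumptions, not part of the diff): after this revert a /v1/completions request carries just the standard OpenAI-style fields, and each returned choice's text is a plain string again.

    import requests

    # Hypothetical local FastChat OpenAI API server; adjust host/port and model name.
    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "model": "vicuna-7b-v1.5",   # example name, not taken from the diff
            "prompt": "Once upon a time",
            "max_tokens": 32,
            "n": 1,
            # "best_of" and "use_beam_search" are no longer CompletionRequest fields
        },
    )
    data = resp.json()
    print(data["choices"][0]["text"])  # plain str after the revert
    print(data["usage"])               # a single usage object, not a list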
fastchat/serve/openai_api_server.py

Lines changed: 3 additions & 26 deletions
@@ -241,9 +241,6 @@ async def get_gen_params(
     max_tokens: Optional[int],
     echo: Optional[bool],
     stop: Optional[Union[str, List[str]]],
-    best_of: Optional[int] = None,
-    n: Optional[int] = 1,
-    use_beam_search: Optional[bool] = None,
 ) -> Dict[str, Any]:
     conv = await get_conv(model_name, worker_addr)
     conv = Conversation(
@@ -290,11 +287,6 @@ async def get_gen_params(
         "stop_token_ids": conv.stop_token_ids,
     }

-    if best_of is not None:
-        gen_params.update({"n": n, "best_of": best_of})
-    if use_beam_search is not None:
-        gen_params.update({"use_beam_search": use_beam_search})
-
     new_stop = set()
     _add_to_set(stop, new_stop)
     _add_to_set(conv.stop_str, new_stop)
@@ -502,18 +494,12 @@ async def create_completion(request: CompletionRequest):
         max_tokens=request.max_tokens,
         echo=request.echo,
         stop=request.stop,
-        best_of=request.best_of,
-        n=request.n,
-        use_beam_search=request.use_beam_search,
     )
     for i in range(request.n):
         content = asyncio.create_task(
             generate_completion(gen_params, worker_addr)
         )
         text_completions.append(content)
-        # when use with best_of, only need send one request
-        if request.best_of:
-            break

     try:
         all_tasks = await asyncio.gather(*text_completions)
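With the best_of short-circuit removed, n completions again fan out as n independent worker requests that are awaited together. A minimal sketch of that pattern follows; generate_completion here is a stand-in coroutine with a simplified signature, not the real FastChat helper.

    import asyncio

    async def generate_completion(i: int) -> dict:
        # Stand-in for the HTTP round trip to the model worker.
        await asyncio.sleep(0)
        return {"text": f"completion {i}", "usage": {}}

    async def fan_out(n: int) -> list:
        # One task per requested completion, gathered together, mirroring
        # the for-loop plus asyncio.gather in create_completion above.
        tasks = [asyncio.create_task(generate_completion(i)) for i in range(n)]
        return await asyncio.gather(*tasks)

    print(asyncio.run(fan_out(3)))  # three independent requests, no best_of batching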
@@ -533,18 +519,9 @@ async def create_completion(request: CompletionRequest):
                 finish_reason=content.get("finish_reason", "stop"),
             )
         )
-        idx = 0
-        while True:
-            info = content["usage"]
-            if isinstance(info, list):
-                info = info[idx]
-
-            task_usage = UsageInfo.parse_obj(info)
-
-            for usage_key, usage_value in task_usage.dict().items():
-                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
-            idx += 1
-            break
+        task_usage = UsageInfo.parse_obj(content["usage"])
+        for usage_key, usage_value in task_usage.dict().items():
+            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

     return CompletionResponse(
         model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)

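The reverted aggregation simply sums one UsageInfo per gathered task, field by field. Below is a self-contained sketch of that pattern; the UsageInfo shape is assumed as above and the numbers are illustrative, not taken from the diff.

    from pydantic import BaseModel

    class UsageInfo(BaseModel):  # assumed shape
        prompt_tokens: int = 0
        total_tokens: int = 0
        completion_tokens: int = 0

    # Stand-ins for the gathered worker responses, one per choice.
    all_tasks = [
        {"usage": {"prompt_tokens": 5, "completion_tokens": 20, "total_tokens": 25}},
        {"usage": {"prompt_tokens": 5, "completion_tokens": 18, "total_tokens": 23}},
    ]

    usage = UsageInfo()
    for content in all_tasks:
        task_usage = UsageInfo.parse_obj(content["usage"])
        for usage_key, usage_value in task_usage.dict().items():
            setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    print(usage)  # prompt_tokens=10 total_tokens=48 completion_tokens=38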
fastchat/serve/vllm_worker.py

Lines changed: 21 additions & 49 deletions
@@ -18,7 +18,6 @@
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

-from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
 from fastchat.serve.model_worker import (
     BaseModelWorker,
     logger,
@@ -75,9 +74,6 @@ async def generate_stream(self, params):
         if self.tokenizer.eos_token_id is not None:
             stop_token_ids.append(self.tokenizer.eos_token_id)
         echo = params.get("echo", True)
-        use_beam_search = params.get("use_beam_search", False)
-        best_of = params.get("best_of", None)
-        n = params.get("n", 1)

         # Handle stop_str
         stop = set()
@@ -94,51 +90,27 @@ async def generate_stream(self, params):
         top_p = max(top_p, 1e-5)
         if temperature <= 1e-5:
             top_p = 1.0
-        try:
-            sampling_params = SamplingParams(
-                n=n,
-                temperature=temperature,
-                top_p=top_p,
-                use_beam_search=use_beam_search,
-                stop=list(stop),
-                max_tokens=max_new_tokens,
-                best_of=best_of,
-            )
-
-            results_generator = engine.generate(context, sampling_params, request_id)
-
-            async for request_output in results_generator:
-                prompt = request_output.prompt
-                prompt_tokens = len(request_output.prompt_token_ids)
-                output_usage = []
-                for out in request_output.outputs:
-                    completion_tokens = len(out.token_ids)
-                    total_tokens = prompt_tokens + completion_tokens
-                    output_usage.append(
-                        {
-                            "prompt_tokens": prompt_tokens,
-                            "completion_tokens": completion_tokens,
-                            "total_tokens": total_tokens,
-                        }
-                    )
-
-                if echo:
-                    text_outputs = [
-                        prompt + output.text for output in request_output.outputs
-                    ]
-                else:
-                    text_outputs = [output.text for output in request_output.outputs]
-
-                if sampling_params.best_of is None:
-                    text_outputs = [" ".join(text_outputs)]
-                ret = {"text": text_outputs, "error_code": 0, "usage": output_usage}
-                yield (json.dumps(ret) + "\0").encode()
-        except (ValueError, RuntimeError) as e:
-            ret = {
-                "text": f"{e}",
-                "error_code": ErrorCode.PARAM_OUT_OF_RANGE,
-                "usage": {},
-            }
+        sampling_params = SamplingParams(
+            n=1,
+            temperature=temperature,
+            top_p=top_p,
+            use_beam_search=False,
+            stop=list(stop),
+            max_tokens=max_new_tokens,
+        )
+        results_generator = engine.generate(context, sampling_params, request_id)
+
+        async for request_output in results_generator:
+            prompt = request_output.prompt
+            if echo:
+                text_outputs = [
+                    prompt + output.text for output in request_output.outputs
+                ]
+            else:
+                text_outputs = [output.text for output in request_output.outputs]
+            text_outputs = " ".join(text_outputs)
+            # Note: usage is not supported yet
+            ret = {"text": text_outputs, "error_code": 0, "usage": {}}
             yield (json.dumps(ret) + "\0").encode()

     async def generate(self, params):

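On the worker side, the reverted streaming payload carries a single joined string and an empty usage dict. A small sketch of how a consumer of the worker's NUL-delimited JSON frames sees this change; the payload text below is illustrative, only the frame shape comes from the code above.

    import json

    # One streamed frame in the shape produced by the reverted generate_stream.
    chunk = b'{"text": "Once upon a time there was", "error_code": 0, "usage": {}}\x00'

    msg = json.loads(chunk[:-1])         # strip the trailing "\0" delimiter
    assert isinstance(msg["text"], str)  # a single string again, not a list of candidates
    assert msg["usage"] == {}            # the vLLM worker no longer reports usage after the revert
    print(msg["text"])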