diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py
index 3dfbbcf64..3835994e5 100755
--- a/lightllm/server/api_http.py
+++ b/lightllm/server/api_http.py
@@ -49,10 +49,12 @@ from lightllm.utils.envs_utils import get_unique_server_name
 from dataclasses import dataclass
-from .api_openai import chat_completions_impl
+from .api_openai import chat_completions_impl, completions_impl
 from .api_models import (
     ChatCompletionRequest,
     ChatCompletionResponse,
+    CompletionRequest,
+    CompletionResponse,
 )
 from .build_prompt import build_prompt, init_tokenizer
@@ -223,6 +225,12 @@ async def chat_completions(request: ChatCompletionRequest, raw_request: Request)
     return resp


+@app.post("/v1/completions", response_model=CompletionResponse)
+async def completions(request: CompletionRequest, raw_request: Request) -> Response:
+    resp = await completions_impl(request, raw_request)
+    return resp
+
+
 @app.get("/tokens")
 @app.post("/tokens")
 async def tokens(request: Request):
diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py
index aab44748c..960185395 100644
--- a/lightllm/server/api_models.py
+++ b/lightllm/server/api_models.py
@@ -48,6 +48,32 @@ class ToolChoice(BaseModel):
     type: Literal["function"] = Field(default="function", examples=["function"])


+class CompletionRequest(BaseModel):
+    model: str
+    # prompt: string or tokens
+    prompt: Union[str, List[str], List[int], List[List[int]]]
+    suffix: Optional[str] = None
+    max_tokens: Optional[int] = 16
+    temperature: Optional[float] = 1.0
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    stream: Optional[bool] = False
+    logprobs: Optional[int] = None
+    echo: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    best_of: Optional[int] = 1
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+
+    # Additional parameters supported by LightLLM
+    do_sample: Optional[bool] = False
+    top_k: Optional[int] = -1
+    repetition_penalty: Optional[float] = 1.0
+    ignore_eos: Optional[bool] = False
+
+
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[Message]
@@ -148,3 +174,49 @@ class ChatCompletionStreamResponse(BaseModel):
     @field_validator("id", mode="before")
     def ensure_id_is_str(cls, v):
         return str(v)
+
+
+class CompletionLogprobs(BaseModel):
+    tokens: List[str] = []
+    token_logprobs: List[Optional[float]] = []
+    top_logprobs: List[Optional[Dict[str, float]]] = []
+    text_offset: List[int] = []
+
+
+class CompletionChoice(BaseModel):
+    text: str
+    index: int
+    logprobs: Optional["CompletionLogprobs"] = None
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionChoice]
+    usage: UsageInfo
+
+    @field_validator("id", mode="before")
+    def ensure_id_is_str(cls, v):
+        return str(v)
+
+
+class CompletionStreamChoice(BaseModel):
+    text: str
+    index: int
+    logprobs: Optional[Dict] = None
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class CompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"cmpl-{uuid.uuid4().hex}")
+    object: str = "text_completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[CompletionStreamChoice]
+
+    @field_validator("id", mode="before")
+    def ensure_id_is_str(cls, v):
+        return str(v)
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index d7a9338c0..e7648401c 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -7,6 +7,7 @@ import os
 from io import BytesIO
 import pickle
+import uuid
 from .function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
 from .build_prompt import build_prompt, init_tokenizer
@@ -14,10 +15,9 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 import ujson as json
 from http import HTTPStatus
-import uuid
 from PIL import Image
 import multiprocessing as mp
-from typing import AsyncGenerator, Union
+from typing import AsyncGenerator, Union, List, Dict
 from typing import Callable
 from lightllm.server import TokenLoad
 from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect
@@ -36,6 +36,12 @@ from .api_models import (
     ChatCompletionRequest,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionChoice,
+    CompletionLogprobs,
+    CompletionStreamResponse,
+    CompletionStreamChoice,
     FunctionResponse,
     ToolCall,
     UsageInfo,
@@ -178,7 +184,9 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
             if finish_reason == "stop":
                 finish_reason = "function_call"
             try:
-                parser = FunctionCallParser(tools, g_objs.args.tool_call_parser)
+                # Provide a default value for tool_call_parser
+                tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
+                parser = FunctionCallParser(tools, tool_parser)
                 full_normal_text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = [
                     ToolCall(
@@ -207,7 +215,7 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
         return resp

     if sampling_params.n != 1:
-        raise Exception("stream api only support n = 1")
+        return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only supports n = 1")

     parser_dict = {}

@@ -224,9 +232,11 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
            finish_reason = finish_status.get_finish_reason()

            if index not in parser_dict:
+                # Provide a default value for tool_call_parser
+                tool_parser = getattr(g_objs.args, "tool_call_parser", None) or "llama3"
                parser_dict[index] = FunctionCallParser(
                    tools=request.tools,
-                    tool_call_parser=g_objs.args.tool_call_parser,
+                    tool_call_parser=tool_parser,
                )
            parser = parser_dict[index]

@@ -238,7 +248,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]:
            choice_data = ChatCompletionStreamResponseChoice(
                index=0,
                delta=DeltaMessage(content=normal_text),
-                finish_reason=finish_reason if finish_reason else "",
+                finish_reason=finish_reason if finish_reason else None,
            )
            chunk = ChatCompletionStreamResponse(
                id=group_request_id,
@@ -304,3 +314,333 @@ async def stream_results() -> AsyncGenerator[bytes, None]:

     background_tasks = BackgroundTasks()
     return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+
+
+async def completions_impl(request: CompletionRequest, raw_request: Request) -> Response:
+    from .api_http import g_objs
+
+    if request.logit_bias is not None:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            "The logit_bias parameter is not currently supported",
+        )
+
+    created_time = int(time.time())
+
+    # Parse and normalize prompts
+    prompts = []
+    if isinstance(request.prompt, list):
+        if len(request.prompt) == 0:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "Prompt cannot be empty",
+            )
+
+        # Check if it's a list of integers (token IDs)
+        if isinstance(request.prompt[0], int):
+            prompts.append(request.prompt)
+        elif isinstance(request.prompt[0], list):
+            for token_list in request.prompt:
+                prompts.append(token_list)
+        else:
+            # List of strings
+            prompts = request.prompt
+    else:
+        # Single string
+        prompts = [request.prompt]
+
+    # Handle suffix for completion mode
+    if request.suffix:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            "The suffix parameter is not currently supported",
+        )
+
+    # Prepare sampling parameters - same as g_generate_stream_func pattern
+    sampling_params_dict = {
+        "do_sample": request.do_sample,
+        "presence_penalty": request.presence_penalty,
+        "frequency_penalty": request.frequency_penalty,
+        "repetition_penalty": request.repetition_penalty,
+        "temperature": request.temperature,
+        "top_p": request.top_p,
+        "top_k": request.top_k,
+        "ignore_eos": request.ignore_eos,
+        "max_new_tokens": request.max_tokens,
+        "stop_sequences": request.stop,
+        "n": request.n,
+        "best_of": request.best_of,
+        "add_special_tokens": False,
+    }
+
+    sampling_params = SamplingParams()
+    sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
+    sampling_params.verify()
+
+    # v1/completions does not support multimodal inputs, so we use an empty MultimodalParams
+    multimodal_params = MultimodalParams()
+
+    return await _process_prompts_completion(
+        prompts, sampling_params, sampling_params_dict, multimodal_params, raw_request, request, created_time
+    )
+
+
+async def _process_prompts_completion(
+    prompts: Union[List[str], List[List[int]]],
+    sampling_params: SamplingParams,
+    sampling_params_dict: Dict,
+    multimodal_params: MultimodalParams,
+    raw_request: Request,
+    request: CompletionRequest,
+    created_time: int,
+) -> Response:
+    from .api_http import g_objs
+    import asyncio
+
+    if request.stream:
+        if len(prompts) > 1:
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "Streaming is not supported for batch requests",
+            )
+
+        if sampling_params.n != 1:
+            return create_error_response(HTTPStatus.BAD_REQUEST, "stream api only supports n = 1")
+
+        return await _handle_streaming_completion(
+            prompts[0], sampling_params, multimodal_params, raw_request, request, created_time
+        )
+
+    async def process_single_prompt(prompt: Union[str, List[int]], prompt_index: int):
+        if len(prompts) > 1:
+            individual_sampling_params = SamplingParams()
+            individual_sampling_params.init(tokenizer=g_objs.httpserver_manager.tokenizer, **sampling_params_dict)
+            individual_sampling_params.verify()
+        else:
+            individual_sampling_params = sampling_params
+
+        # Convert token array to string for _collect_generation_results
+        prompt_str = prompt
+        if isinstance(prompt, list):
+            prompt_str = g_objs.httpserver_manager.tokenizer.decode(prompt, skip_special_tokens=False)
+
+        generator = g_objs.httpserver_manager.generate(
+            prompt, individual_sampling_params, multimodal_params, request=raw_request
+        )
+
+        return await _collect_generation_results(generator, request, prompt_str, prompt_index)
+
+    tasks = [asyncio.create_task(process_single_prompt(prompt, i)) for i, prompt in enumerate(prompts)]
+
+    results = await asyncio.gather(*tasks)
+    return _build_completion_response(results, request, created_time, len(prompts) > 1)
+
+
+async def _handle_streaming_completion(
+    prompt: Union[str, List[int]],
+    sampling_params: SamplingParams,
+    multimodal_params: MultimodalParams,
+    raw_request: Request,
+    request: CompletionRequest,
+    created_time: int,
+) -> Response:
+    from .api_http import g_objs
+
+    results_generator = g_objs.httpserver_manager.generate(
+        prompt, sampling_params, multimodal_params, request=raw_request
+    )
+
+    async def stream_results() -> AsyncGenerator[bytes, None]:
+        from .req_id_generator import convert_sub_id_to_group_id
+
+        async for sub_req_id, request_output, metadata, finish_status in results_generator:
+            group_request_id = convert_sub_id_to_group_id(sub_req_id)
+
+            current_finish_reason = None
+            if finish_status.is_finished():
+                current_finish_reason = finish_status.get_finish_reason()
+
+            output_text = request_output
+            if request.echo and metadata.get("is_first_token", False):
+                prompt_str = prompt
+                if isinstance(prompt, list):
+                    prompt_str = g_objs.httpserver_manager.tokenizer.decode(prompt, skip_special_tokens=False)
+                output_text = prompt_str + output_text
+
+            stream_choice = CompletionStreamChoice(
+                index=0,
+                text=output_text,
+                finish_reason=current_finish_reason,
+                logprobs=None if request.logprobs is None else {},
+            )
+            stream_resp = CompletionStreamResponse(
+                id=group_request_id,
+                created=created_time,
+                model=request.model,
+                choices=[stream_choice],
+            )
+            yield ("data: " + json.dumps(stream_resp.dict(), ensure_ascii=False) + "\n\n").encode("utf-8")
+
+        yield "data: [DONE]\n\n".encode("utf-8")
+
+    background_tasks = BackgroundTasks()
+    return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
+
+
+async def _collect_generation_results(generator, request: CompletionRequest, prompt: str, prompt_index: int):
+    final_output = []
+    count_output_tokens = 0
+    finish_reason = None
+    prompt_tokens = 0
+    token_infos = [] if request.logprobs is not None else None
+    prompt_logprobs = None
+    prompt_token_ids = None
+    is_first_metadata = True
+
+    async for sub_req_id, request_output, metadata, finish_status in generator:
+        if is_first_metadata:
+            prompt_logprobs = metadata.get("prompt_logprobs", None)
+            prompt_token_ids = metadata.get("prompt_token_ids", None)
+            is_first_metadata = False
+
+        count_output_tokens += 1
+        final_output.append(request_output)
+
+        if request.logprobs is not None and token_infos is not None:
+            token_info = {
+                "text": request_output,
+                "logprob": metadata.get("logprob", None),
+                "id": metadata.get("id", None),
+            }
+            token_infos.append(token_info)
+
+        if finish_status.is_finished():
+            finish_reason = finish_status.get_finish_reason()
+            prompt_tokens = metadata["prompt_tokens"]
+
+    return {
+        "index": prompt_index,
+        "text": "".join(final_output),
+        "finish_reason": finish_reason,
+        "prompt_tokens": prompt_tokens,
+        "completion_tokens": count_output_tokens,
+        "token_infos": token_infos,
+        "prompt_logprobs": prompt_logprobs,
+        "prompt_token_ids": prompt_token_ids,
+        "prompt_text": prompt,
+    }
+
+
+def _build_completion_response(results: List[Dict], request: CompletionRequest, created_time: int, is_batch: bool):
+    from .api_http import g_objs
+
+    choices = []
+    total_prompt_tokens = 0
+    total_completion_tokens = 0
+
+    for result in results:
+        text = result["text"]
+        if request.echo:
+            text = result["prompt_text"] + text
+
+        logprobs_data = _build_logprobs_data(result, request, g_objs.httpserver_manager.tokenizer)
+
+        choice = CompletionChoice(
+            index=result["index"],
+            text=text,
+            finish_reason=result["finish_reason"],
+            logprobs=CompletionLogprobs(**logprobs_data) if logprobs_data else None,
+        )
+        choices.append(choice)
+
+        total_prompt_tokens += result["prompt_tokens"]
+        total_completion_tokens += result["completion_tokens"]
+
+    usage = UsageInfo(
+        prompt_tokens=total_prompt_tokens,
+        completion_tokens=total_completion_tokens,
+        total_tokens=total_prompt_tokens + total_completion_tokens,
+    )
+
+    if is_batch:
+        group_request_id = f"cmpl-batch-{uuid.uuid4().hex[:8]}"
+    else:
+        group_request_id = f"cmpl-{uuid.uuid4().hex[:8]}"
+
+    return CompletionResponse(
+        id=group_request_id, created=created_time, model=request.model, choices=choices, usage=usage
+    )
+
+
+def _build_logprobs_data(result: Dict, request: CompletionRequest, tokenizer) -> Dict:
+    if request.logprobs is None:
+        return None
+
+    all_tokens = []
+    all_token_logprobs = []
+    all_text_offsets = []
+    offset = 0
+
+    def add_tokens_to_logprobs(token_ids=None, token_infos=None, logprob_map=None):
+        nonlocal offset
+
+        def add_single_token(token_text: str, logprob: float):
+            nonlocal offset
+            all_tokens.append(token_text)
+            all_token_logprobs.append(logprob)
+            all_text_offsets.append(offset)
+            offset += len(token_text)
+
+        if token_ids is not None:
+            for token_id in token_ids:
+                token_text = tokenizer.decode([token_id], skip_special_tokens=False)
+                logprob = logprob_map.get(token_id, None) if logprob_map else None
+                add_single_token(token_text, logprob)
+        elif token_infos is not None:
+            for token_info in token_infos:
+                add_single_token(token_info["text"], token_info["logprob"])
+
+    # Handle prompt tokens in echo mode
+    if request.echo and result.get("prompt_logprobs") is not None:
+        prompt_logprobs = result["prompt_logprobs"]
+        prompt_token_ids = result.get("prompt_token_ids")
+
+        # Build a token_id -> logprob map
+        logprob_map = {}
+        for current_token_id, logprobs_dict in prompt_logprobs:
+            for next_token_id, logprob in logprobs_dict.items():
+                logprob_map[int(next_token_id)] = logprob
+
+        # Process all prompt tokens
+        if prompt_token_ids is not None:
+            add_tokens_to_logprobs(token_ids=prompt_token_ids, logprob_map=logprob_map)
+
+    elif request.echo:
+        # echo=True but no prompt logprobs are available
+        prompt_token_ids = result.get("prompt_token_ids")
+        if prompt_token_ids is not None:
+            add_tokens_to_logprobs(token_ids=prompt_token_ids)
+        else:
+            # Fallback: re-tokenize the prompt
+            prompt_tokens = tokenizer.encode(result["prompt_text"], add_special_tokens=False)
+            add_tokens_to_logprobs(token_ids=prompt_tokens)
+
+    # Append the generated tokens and their logprobs
+    if result.get("token_infos"):
+        add_tokens_to_logprobs(token_infos=result["token_infos"])
+
+    top_logprobs_list = []
+    for i, (token, logprob) in enumerate(zip(all_tokens, all_token_logprobs)):
+        if logprob is not None:
+            # TODO: a full implementation needs top-k logprobs from the backend;
+            # currently the backend only returns the logprob of the sampled token.
+            top_logprobs_list.append({token: logprob})
+        else:
+            top_logprobs_list.append(None)
+
+    return {
+        "tokens": all_tokens,
+        "token_logprobs": all_token_logprobs,
+        "top_logprobs": top_logprobs_list,
+        "text_offset": all_text_offsets,
+    }
diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index 4e1971ef9..f2ebadad1 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -229,6 +229,8 @@ def get_all_prompt_metadata(self):
         """
         return_all_prompt_logprobs mode use to return all logprobs cacul ppl
         """
+        if hasattr(self, "_cache_prompt_metadata"):
+            return self._cache_prompt_metadata
         metadata = {}
         cur_ids = self.shm_prompt_ids.arr[0 : self.input_len]
         all_prompts = []
@@ -238,6 +240,7 @@ def get_all_prompt_metadata(self):
         metadata["prompt_logprobs"] = all_prompts
         metadata["prompt_token_ids"] = [int(e) for e in cur_ids]
+        self._cache_prompt_metadata = metadata
         return metadata
diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py
index 1462ca84e..451402180 100644
--- a/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py
+++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py
@@ -1,6 +1,6 @@
 import torch
 from .impl import ContinuesBatchBackend
-from typing import List, Tuple
+from typing import List, Tuple, Callable, Optional
 from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
 from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams, g_infer_context
 from lightllm.server.router.model_infer.mode_backend.pre import prepare_prefill_inputs
@@ -14,67 +14,60 @@ def __init__(self) -> None:
     def prefill(self, run_reqs: List[Tuple]):
         # 在 return all_prompt_logprobs 的模式下,不能启用 dynamic prompt cache
         assert self.radix_cache is None
-        req_ids = self._init_reqs(run_reqs, init_req_obj=True)
+        self._init_reqs(run_reqs, init_req_obj=False)
+        return

-        req_objs = self._trans_req_ids_to_req_objs(req_ids)
+    def normal_prefill_reqs(
+        self,
+        prefill_reqs: List[InferReq],
+        uninit_reqs: List[InferReq],
+        ok_finished_reqs: List[InferReq],
+        mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None,
+        extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+    ):
         model_input, run_reqs = prepare_prefill_inputs(
-            req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+            prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
         )
         model_output = self.model.forward(model_input)
         prompt_all_logits = model_output.logits
+
+        self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
+
         input_ids = model_input.input_ids
         b_ready_cache_len = model_input.b_ready_cache_len
         b_seq_len = model_input.b_seq_len
         last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
         logits = prompt_all_logits[last_index, :]
-        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
-        next_token_ids = next_token_ids.detach().cpu().numpy()
-        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
-
         b_q_seq_len = b_seq_len - b_ready_cache_len
         b_start_loc = torch.cumsum(b_q_seq_len, dim=0, dtype=torch.long) - b_q_seq_len
         b_start_loc = b_start_loc.cpu().numpy()
         b_q_seq_len = b_q_seq_len.cpu().numpy()
-        finished_req_ids = []
-        for req_obj, next_token_id, next_token_logprob, start_loc, q_seq_len in zip(
-            run_reqs, next_token_ids, next_token_logprobs, b_start_loc, b_q_seq_len
-        ):
-            # prefill and decode is same
+        for req_obj, start_loc, q_seq_len in zip(run_reqs, b_start_loc, b_q_seq_len):
             req_obj: InferReq = req_obj
-            req_obj.cur_kv_len = req_obj.get_cur_total_len()
-
-            req_obj.set_next_gen_token_id(next_token_id, next_token_logprob)
-            req_obj.cur_output_len += 1
-
-            # 填充 logprobs 信息
             cur_ids: torch.Tensor = input_ids[start_loc : start_loc + q_seq_len]
             cur_logits = prompt_all_logits[start_loc : start_loc + q_seq_len]
             cur_logprobs = torch.log_softmax(cur_logits, dim=-1, dtype=torch.float)[0:-1, :]
             cur_logprobs = torch.gather(cur_logprobs, dim=1, index=cur_ids[1:].view(-1, 1)).detach().cpu().numpy()
-            for i in range(req_obj.shm_req.input_len - 1):
-                req_obj.shm_req.shm_logprobs.arr[i + 1] = cur_logprobs[i]
-
-            req_obj.update_finish_status(self.eos_id)
-
-            if req_obj.finish_status.is_finished() or req_obj.shm_req.router_aborted:
-                finished_req_ids.append(req_obj.shm_req.request_id)
-
-            if self.is_master_in_dp:
-                # shm_cur_kv_len shm_cur_output_len 是 router 调度进程需要读的信息
-                # finish_token_index finish_status candetoken_out_len 是
-                # detokenization 进程需要的信息,注意这些变量的写入顺序避免异步协同问题。
-                req_obj.shm_req.shm_cur_kv_len = req_obj.cur_kv_len
-                req_obj.shm_req.shm_cur_output_len = req_obj.cur_output_len
+            if req_obj.shm_req.input_len > 1:
+                req_obj.shm_req.shm_logprobs.arr[1 : req_obj.shm_req.input_len] = cur_logprobs.flatten()

-            if req_obj.finish_status.is_finished():
-                req_obj.shm_req.finish_token_index = req_obj.get_cur_total_len() - 1
-                req_obj.shm_req.finish_status = req_obj.finish_status
+        if mask_func is not None:
+            mask_func(run_reqs, logits)

-            req_obj.shm_req.candetoken_out_len = req_obj.cur_output_len
+        next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
+        next_token_ids = next_token_ids.detach().cpu().numpy()
+        next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()

-        g_infer_context.filter(finished_request_ids=finished_req_ids)
+        self._post_handle(
+            run_reqs,
+            next_token_ids,
+            next_token_logprobs,
+            is_chuncked_mode=False,
+            do_filter_finished_reqs=False,
+            extra_post_req_handle_func=extra_post_req_handle_func,
+        )
         return
diff --git a/test/test_api/test_openai_api.py b/test/test_api/test_openai_api.py
index 6d98dadbe..428181d74 100644
--- a/test/test_api/test_openai_api.py
+++ b/test/test_api/test_openai_api.py
@@ -67,6 +67,53 @@ def stream_chat(self, message: str, **kwargs):
         else:
             raise Exception(f"API调用失败: {response.status_code} - {response.text}")

+    def completions(self, prompt: str, **kwargs) -> Dict[str, Any]:
+        """Text completion"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 100),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def stream_completions(self, prompt: str, **kwargs):
+        """Streaming text completion"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "stream": True,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 100),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data, stream=True)
+
+        if response.status_code == 200:
+            for line in response.iter_lines():
+                if line:
+                    line = line.decode("utf-8")
+                    if line.startswith("data: "):
+                        data_str = line[6:]
+                        if data_str == "[DONE]":
+                            break
+                        try:
+                            chunk = json.loads(data_str)
+                            if chunk["choices"][0].get("text"):
+                                yield chunk["choices"][0]["text"]
+                        except json.JSONDecodeError:
+                            continue
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
     def function_call(self, message: str, tools: List[Dict], tool_choice: str = "auto", **kwargs) -> Dict[str, Any]:
         """Function calling"""
         data = {
@@ -132,6 +179,201 @@ def stream_function_call(self, message: str, tools: List[Dict], tool_choice: str
         else:
             raise Exception(f"API调用失败: {response.status_code} - {response.text}")

+    def completions_with_tokens(self, token_ids: List[int], **kwargs) -> Dict[str, Any]:
+        """Text completion using an array of token IDs"""
+        data = {
+            "model": self.model_name,
+            "prompt": token_ids,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 100),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_multiple_prompts(self, prompts: List[str], **kwargs) -> Dict[str, Any]:
+        """Text completion with multiple prompts (batched)"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompts,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 100),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_logprobs(self, prompt: str, logprobs: int = 5, **kwargs) -> Dict[str, Any]:
+        """Text completion with logprobs"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "logprobs": logprobs,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 50),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_echo(self, prompt: str, echo: bool = True, **kwargs) -> Dict[str, Any]:
+        """Text completion with the echo parameter"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "echo": echo,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 30),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_echo_and_logprobs(
+        self, prompt: str, echo: bool = True, logprobs: int = 5, **kwargs
+    ) -> Dict[str, Any]:
+        """Text completion with both echo and logprobs (exercises the echo + logprobs path)"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "echo": echo,
+            "logprobs": logprobs,
+            "temperature": kwargs.get("temperature", 0.0),  # use temperature 0 for consistent results
+            "max_tokens": kwargs.get("max_tokens", 20),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_logprobs_structure_test(self, prompt: str, **kwargs) -> Dict[str, Any]:
+        """Check the completeness of the logprobs data structure"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "logprobs": 3,
+            "echo": True,
+            "temperature": 0.0,
+            "max_tokens": 10,
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_n(self, prompt: str, n: int = 2, **kwargs) -> Dict[str, Any]:
+        """Generate multiple candidates via the n parameter"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "n": n,
+            "best_of": n,  # LightLLM requires n == best_of
+            "temperature": kwargs.get("temperature", 0.8),
+            "max_tokens": kwargs.get("max_tokens", 10),  # make sure max_tokens is at least 1
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_stop(self, prompt: str, stop, **kwargs) -> Dict[str, Any]:
+        """Text completion with the stop parameter"""
+        data = {
+            "model": self.model_name,
+            "prompt": prompt,
+            "stop": stop,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 50),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+    def completions_with_multiple_token_arrays(self, token_arrays: List[List[int]], **kwargs) -> Dict[str, Any]:
+        """Batch completion with multiple token arrays"""
+        data = {
+            "model": self.model_name,
+            "prompt": token_arrays,
+            "temperature": kwargs.get("temperature", 0.7),
+            "max_tokens": kwargs.get("max_tokens", 30),
+            **kwargs,
+        }
+
+        response = requests.post(f"{self.base_url}/v1/completions", headers=self.headers, json=data)
+
+        if response.status_code == 200:
+            return response.json()
+        else:
+            raise Exception(f"API call failed: {response.status_code} - {response.text}")
+
+
+def test_completions():
+    """Test the text completion API"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing text completion ===")
+        result = client.completions("The capital of France is", max_tokens=50)
+        print("Prompt: The capital of France is")
+        print("Completion:", result["choices"][0]["text"])
+        print(f"Usage: {result['usage']}")
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_stream_completions():
+    """Test the streaming text completion API"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing streaming text completion ===")
+        print("Prompt: Once upon a time")
+        print("Completion: ", end="", flush=True)
+
+        for chunk in client.stream_completions("Once upon a time", max_tokens=100):
+            print(chunk, end="", flush=True)
+        print("\n")
+    except Exception as e:
+        print(f"Error: {e}")
+
+
 def test_simple_chat():
     client = LightLLMClient()
@@ -266,12 +508,150 @@ def test_stream_function_call():
         print(f"错误: {e}")


+def test_token_completions():
+    """Test the completion API with an array of token IDs"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing token-array completion ===")
+        # Example token IDs (placeholder values; a real test should use the proper tokenizer)
+        token_ids = [2701, 525, 5248, 5754, 4755]  # example tokens
+        result = client.completions_with_tokens(token_ids, max_tokens=50)
+        print(f"Token IDs: {token_ids}")
+        print("Completion:", result["choices"][0]["text"])
+        print(f"Usage: {result['usage']}")
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_multiple_prompts():
+    """Test the completion API with multiple prompts (true batching)"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing batched completion ===")
+        prompts = ["Hello, how are you?", "What is the weather like?", "Tell me a joke"]
+        result = client.completions_with_multiple_prompts(prompts, max_tokens=30)
+        print(f"Sent {len(prompts)} prompts for batch processing:")
+
+        for i, choice in enumerate(result["choices"]):
+            print(f"  {i+1}. Prompt: {prompts[choice['index']]}")
+            print(f"     Completion: {choice['text'].strip()}")
+            print(f"     Finish reason: {choice['finish_reason']}")
+
+        print(f"Total usage: {result['usage']}")
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_logprobs():
+    """Test the logprobs feature"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing logprobs ===")
+        result = client.completions_with_logprobs("The capital of France is", logprobs=5, max_tokens=20)
+        print("Prompt: The capital of France is")
+        print("Completion:", result["choices"][0]["text"])
+
+        # Check the logprobs structure
+        logprobs = result["choices"][0]["logprobs"]
+        if logprobs:
+            print("Logprobs structure:")
+            print(f"  tokens: {logprobs.get('tokens', [])[:5]}...")  # show only the first 5
+            print(f"  token_logprobs: {logprobs.get('token_logprobs', [])[:5]}...")
+            print(f"  text_offset: {logprobs.get('text_offset', [])[:5]}...")
+            print(f"  top_logprobs: {logprobs.get('top_logprobs', [])[:2]}...")  # show only the first 2
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_echo():
+    """Test the echo parameter"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing the echo parameter ===")
+
+        # Test echo=True
+        result = client.completions_with_echo("Hello world", echo=True, max_tokens=20)
+        print("Prompt: Hello world (echo=True)")
+        print("Completion:", repr(result["choices"][0]["text"]))
+        print()
+
+        # Test echo=False
+        result = client.completions_with_echo("Hello world", echo=False, max_tokens=20)
+        print("Prompt: Hello world (echo=False)")
+        print("Completion:", repr(result["choices"][0]["text"]))
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_stop_parameter():
+    """Test the stop parameter"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing the stop parameter ===")
+
+        # Test a single stop string
+        result = client.completions_with_stop("Count: 1, 2, 3, 4", stop="12", max_tokens=50)
+        print("Prompt: Count: 1, 2, 3, 4 (stop='12')")
+        print("Completion:", repr(result["choices"][0]["text"]))
+        print("Finish reason:", result["choices"][0]["finish_reason"])
+        print()
+
+        # Test multiple stop strings
+        result = client.completions_with_stop("The colors are red, blue, green", stop=["red", "blue"], max_tokens=50)
+        print("Prompt: The colors are red, blue, green (stop=['red', 'blue'])")
+        print("Completion:", repr(result["choices"][0]["text"]))
+        print("Finish reason:", result["choices"][0]["finish_reason"])
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
+def test_multiple_token_arrays():
+    """Test batch processing of multiple token arrays"""
+    client = LightLLMClient()
+
+    try:
+        print("=== Testing batch processing of multiple token arrays ===")
+        token_arrays = [[2701, 525, 5248], [4755, 8394, 1234], [9876, 5432, 1098]]
+
+        result = client.completions_with_multiple_token_arrays(token_arrays, max_tokens=20)
+        print(f"Sent {len(token_arrays)} token arrays for batch processing:")
+
+        for i, choice in enumerate(result["choices"]):
+            print(f"  {i+1}. Token array: {token_arrays[choice['index']]}")
+            print(f"     Completion: {choice['text'].strip()}")
+            print(f"     Finish reason: {choice['finish_reason']}")
+        print()
+    except Exception as e:
+        print(f"Error: {e}")
+
+
 def main():
+    # Basic functionality tests
+    test_completions()
+    test_stream_completions()
     test_simple_chat()
     test_stream_chat()
     test_function_call()
     test_stream_function_call()
+
+    # Advanced functionality tests
+    test_token_completions()
+    test_multiple_prompts()
+    test_multiple_token_arrays()
+    test_logprobs()
+    test_echo()
+    test_stop_parameter()
+


 if __name__ == "__main__":
     main()
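
Usage note: with this patch applied and a LightLLM server running, the new /v1/completions endpoint can also be exercised through the official openai Python client. This is a minimal sketch; the base URL, API key, and model name below are illustrative assumptions, not values defined by this patch.

    from openai import OpenAI

    # Point the standard OpenAI client at a local LightLLM server (assumed address).
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    resp = client.completions.create(
        model="default",  # placeholder model name; use whatever name the server expects
        prompt="The capital of France is",
        max_tokens=32,
        logprobs=1,
        echo=True,
    )
    print(resp.choices[0].text)
    print(resp.usage)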