Skip to content

Commit e56abe0

Browse files
committed
fix(vllm-serve): allow null logprobs in responses
1 parent 4480cbd commit e56abe0

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

tests/test_vllm_client_server.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,22 @@ def test_chat(self):
179179
for seq in completion_ids:
180180
assert all(isinstance(tok, int) for tok in seq)
181181

182+
def test_generate_with_logprobs_none(self):
    """Passing logprobs=None to generate() must yield null logprob fields."""
    outputs = self.client.generate(["Hello, AI!"], logprobs=None)

    # Token-id payloads are always present as lists...
    assert isinstance(outputs["prompt_ids"], list)
    assert isinstance(outputs["completion_ids"], list)
    # ...while logprob data is None when the caller did not request it.
    assert outputs["logprobs"] is None
    assert outputs["logprob_token_ids"] is None
189+
190+
def test_chat_with_logprobs_none(self):
    """Passing logprobs=None to chat() must yield null logprob fields."""
    outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

    # Token-id payloads are always present as lists...
    assert isinstance(outputs["prompt_ids"], list)
    assert isinstance(outputs["completion_ids"], list)
    # ...while logprob data is None when the caller did not request it.
    assert outputs["logprobs"] is None
    assert outputs["logprob_token_ids"] is None
197+
182198
def test_generate_with_params(self):
183199
prompts = ["Hello, AI!", "Tell me a joke"]
184200
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[

trl/scripts/vllm_serve.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,8 @@ class GenerateRequest(BaseModel):
506506
class GenerateResponse(BaseModel):
    """Response payload for the `/generate/` endpoint.

    Token-id fields are always populated; the logprob fields are entirely
    None when the request did not ask for logprobs.
    """

    # One inner list of token ids per prompt / completion sequence.
    prompt_ids: list[list[int]]
    completion_ids: list[list[int]]
    # Per-token logprob data; None when logprobs were not requested.
    # Individual entries may also be None within a returned sequence.
    logprobs: list[list[list[float | None]]] | None
    logprob_token_ids: list[list[list[int]]] | None
511511

512512
@app.post("/generate/", response_model=GenerateResponse)
513513
async def generate(request: GenerateRequest):
@@ -672,8 +672,8 @@ class ChatRequest(BaseModel):
672672
class ChatResponse(BaseModel):
    """Response payload for the `/chat/` endpoint.

    Token-id fields are always populated; the logprob fields are entirely
    None when the request did not ask for logprobs.
    """

    # One inner list of token ids per prompt / completion sequence.
    prompt_ids: list[list[int]]
    completion_ids: list[list[int]]
    # Per-token logprob data; None when logprobs were not requested.
    # Individual entries may also be None within a returned sequence.
    logprobs: list[list[list[float | None]]] | None
    logprob_token_ids: list[list[list[int]]] | None
677677

678678
@app.post("/chat/", response_model=ChatResponse)
679679
async def chat(request: ChatRequest):

0 commit comments

Comments (0)