Skip to content

Commit 348a547

Browse files
Update logprobs_mode definition
Signed-off-by: Yuan Tong <13075180+tongyuantongyu@users.noreply.github.com>
1 parent a6a1356 commit 348a547

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

tensorrt_llm/sampling_params.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass, field, fields
5-
from typing import List, Literal, NamedTuple, Optional, Tuple, Union
5+
from typing import List, NamedTuple, Optional, Tuple, Union
66

77
import torch
88
from pydantic import BaseModel
@@ -49,7 +49,14 @@ class LogprobParams(NamedTuple):
4949

5050
class LogprobMode(StrEnum):
5151
RAW = "raw"
52+
"""
53+
Return the raw log probabilities, i.e., the log probabilities calculated directly from the model output logits.
54+
"""
5255
PROCESSED = "processed"
56+
"""
57+
Return the processed log probabilities, i.e., the log probabilities after applying sampling parameters,
58+
such as temperature, top-k, top-p, etc.
59+
"""
5360

5461

5562
class LogitsProcessor(ABC):
@@ -180,7 +187,7 @@ class SamplingParams:
180187
181188
logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability.
182189
When set to K>0, return top-K log probabilities + the sampled token's log probability (last entry) if it's not in the Top-K. Defaults to None.
183-
logprobs_mode (Literal["raw", "processed"]): The mode of log probabilities to return. Valid modes are "raw" and "processed". Defaults to "raw".
190+
logprobs_mode (LogprobMode, optional): The mode of log probabilities to return. Defaults to RAW.
184191
prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
185192
return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
186193
return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
@@ -227,7 +234,7 @@ class SamplingParams:
227234
n: int = 1
228235
best_of: Optional[int] = None
229236
use_beam_search: bool = False
230-
logprobs_mode: Literal["raw", "processed"] = "raw"
237+
logprobs_mode: LogprobMode = LogprobMode.RAW
231238

232239
# Keep the below fields in sync with tllme.SamplingConfig or maintain the mapping table.
233240
top_k: Optional[int] = None
@@ -330,11 +337,7 @@ def _validate(self):
330337
f"under the greedy decoding."
331338
)
332339

333-
if self.logprobs_mode not in [LogprobMode.RAW, LogprobMode.PROCESSED]:
334-
raise ValueError(
335-
f"logprobs_mode must be one of {LogprobMode.RAW.value}, {LogprobMode.PROCESSED.value}. "
336-
f"Got {self.logprobs_mode} instead."
337-
)
340+
self.logprobs_mode = LogprobMode(self.logprobs_mode)
338341

339342
if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1:
340343
raise ValueError(

0 commit comments

Comments
 (0)