|
2 | 2 | import os |
3 | 3 | from abc import ABC, abstractmethod |
4 | 4 | from dataclasses import dataclass, field, fields |
5 | | -from typing import List, Literal, NamedTuple, Optional, Tuple, Union |
| 5 | +from typing import List, NamedTuple, Optional, Tuple, Union |
6 | 6 |
|
7 | 7 | import torch |
8 | 8 | from pydantic import BaseModel |
@@ -49,7 +49,14 @@ class LogprobParams(NamedTuple): |
49 | 49 |
|
50 | 50 | class LogprobMode(StrEnum): |
51 | 51 | RAW = "raw" |
| 52 | + """ |
| 53 | + Return the raw log probabilities, i.e., the log probabilities calculated directly from the model output logits. |
| 54 | + """ |
52 | 55 | PROCESSED = "processed" |
| 56 | + """ |
| 57 | + Return the processed log probabilities, i.e., the log probabilities after applying sampling parameters, |
| 58 | + such as temperature, top-k, top-p, etc. |
| 59 | + """ |
53 | 60 |
|
54 | 61 |
|
55 | 62 | class LogitsProcessor(ABC): |
@@ -180,7 +187,7 @@ class SamplingParams: |
180 | 187 |
|
181 | 188 | logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability. |
182 | 189 | When set to K>0, return top-K log probabilities + the sampled token's log probability (last entry) if it's not in the Top-K. Defaults to None. |
183 | | - logprobs_mode (Literal["raw", "processed"]): The mode of log probabilities to return. Valid modes are "raw" and "processed". Defaults to "raw". |
| 190 | + logprobs_mode (LogprobMode, optional): The mode of log probabilities to return. Defaults to RAW. |
184 | 191 | prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. |
185 | 192 | return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. |
186 | 193 | return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. |
@@ -227,7 +234,7 @@ class SamplingParams: |
227 | 234 | n: int = 1 |
228 | 235 | best_of: Optional[int] = None |
229 | 236 | use_beam_search: bool = False |
230 | | - logprobs_mode: Literal["raw", "processed"] = "raw" |
| 237 | + logprobs_mode: LogprobMode = LogprobMode.RAW |
231 | 238 |
|
232 | 239 | # Keep the below fields in sync with tllme.SamplingConfig or maintin the mapping table. |
233 | 240 | top_k: Optional[int] = None |
@@ -330,11 +337,7 @@ def _validate(self): |
330 | 337 | f"under the greedy decoding." |
331 | 338 | ) |
332 | 339 |
|
333 | | - if self.logprobs_mode not in [LogprobMode.RAW, LogprobMode.PROCESSED]: |
334 | | - raise ValueError( |
335 | | - f"logprobs_mode must be one of {LogprobMode.RAW.value}, {LogprobMode.PROCESSED.value}. " |
336 | | - f"Got {self.logprobs_mode} instead." |
337 | | - ) |
| 340 | + self.logprobs_mode = LogprobMode(self.logprobs_mode) |
338 | 341 |
|
339 | 342 | if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1: |
340 | 343 | raise ValueError( |
|
0 commit comments