
Commit d416a39

Address reviews
Signed-off-by: Yuan Tong <[email protected]>
1 parent: c7d6f4f · commit: d416a39

File tree

4 files changed: +58 −61 lines changed


tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 5 additions & 12 deletions
@@ -486,7 +486,7 @@ def __init__(
             is_first_draft: bool = False,
             use_chunked_generation_logits: bool = True,
             logits_chunk_size: int = 8,
-            logprobs_mode: LogprobMode | None = None,
+            logprobs_mode: LogprobMode = LogprobMode.RAW,
             **kwargs):

         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
@@ -568,7 +568,8 @@ def __init__(
         # currently, keep py_stop_words_list as python list, rather than tensor.
         self.py_stop_words_list = stop_words_list

-        self.py_logprobs_mode = LogprobMode.RAW if logprobs_mode is None else logprobs_mode
+        self.py_logprobs_mode = LogprobMode(
+            logprobs_mode)  # handle passed a raw string

         self.py_result = PyResult(
             prompt_len=self.py_prompt_len,
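The LogprobMode(logprobs_mode) call leans on the standard Enum value lookup: passing either an existing member or its raw string value returns the member, and an unknown value raises ValueError, which is why the explicit validate_logprobs_mode() check removed below becomes redundant. A minimal sketch of that behaviour, assuming LogprobMode is a plain Enum whose member values here ("raw", "processed") are illustrative:

from enum import Enum


class LogprobMode(Enum):
    # Hypothetical stand-in for the real enum; member values are assumed.
    RAW = "raw"
    PROCESSED = "processed"


# Value lookup normalizes a raw string coming from the executor request.
assert LogprobMode("raw") is LogprobMode.RAW
# Passing an existing member through the constructor is a no-op.
assert LogprobMode(LogprobMode.PROCESSED) is LogprobMode.PROCESSED
# Unknown values fail loudly, replacing the removed manual check.
try:
    LogprobMode("bogus")
except ValueError as exc:
    print(exc)  # 'bogus' is not a valid LogprobMode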
@@ -597,15 +598,6 @@ def set_exclude_last_generation_logits(
         self.py_result.set_exclude_last_generation_logits(
             exclude_last_generation_logits)

-    def validate_logprobs_mode(self):
-        if self.py_logprobs_mode not in [
-                LogprobMode.RAW, LogprobMode.PROCESSED
-        ]:
-            raise ValueError(
-                f"Invalid logprobs_mode: {self.py_logprobs_mode} "
-                f"Expected one of {LogprobMode.RAW.value}, {LogprobMode.PROCESSED.value}"
-            )
-
     @property
     def cached_tokens(self) -> int:
         return self._cached_tokens
@@ -839,7 +831,8 @@ def executor_request_to_llm_request(
         py_multimodal_data=getattr(executor_request, "py_multimodal_data",
                                    None),
         kv_cache_retention_config=executor_request.kv_cache_retention_config,
-        logprobs_mode=getattr(executor_request, "py_logprobs_mode", None),
+        logprobs_mode=getattr(executor_request, "py_logprobs_mode",
+                              LogprobMode.RAW),
     )
     if child_req_ids:
         for child_id in child_req_ids:

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 2 deletions
@@ -1853,8 +1853,6 @@ def _validate_request(self, request: LlmRequest):
                 f"Request beam width {sampling_config.beam_width} "
                 f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
             )
-        # Validate logprobs mode
-        request.validate_logprobs_mode()

         # Check token ID ranges
         if isinstance(self.model_engine.model, DecoderModelForCausalLM):

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 45 additions & 39 deletions
@@ -15,7 +15,7 @@
 import enum
 import sys
 from abc import ABC, abstractmethod
-from collections import defaultdict
+from collections import defaultdict, namedtuple
 from collections.abc import Iterable
 from concurrent import futures
 from dataclasses import dataclass
@@ -99,11 +99,14 @@ class LogProbsState:

 @dataclass(kw_only=True)
 class LogProbsStateList:
-    sampled_vals: list[list[list[float]]]
-    sampled_indices: list[list[list[int]]]
-    sampled_rank: list[list[list[int]]]
-    topk_vals: list[list[list[float]]]
-    topk_indices: list[list[list[int]]]
+    FloatState = list[list[list[float]]]
+    IntState = list[list[list[float]]]
+
+    sampled_vals: FloatState
+    sampled_indices: IntState
+    sampled_rank: IntState
+    topk_vals: FloatState
+    topk_indices: IntState

     @staticmethod
     def from_logprobs_state(logprobs_state: LogProbsState) -> "LogProbsStateList":
@@ -241,7 +244,7 @@ class SampleStateWithMMResult:
     data: MultimodalResult


-@dataclass(kw_only=True, frozen=True)
+@dataclass(kw_only=True, frozen=True, slots=True)
 class RequestGroupKey(Generic[GenericStrategyKeyType]):
     strategy_key: GenericStrategyKeyType
     needs_probs: bool
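With frozen=True the dataclass already generates __eq__ and __hash__, so RequestGroupKey can be used directly as a dictionary key in the grouping code below; slots=True additionally drops the per-instance __dict__ to keep the many short-lived key objects small. A minimal sketch of why that combination works as a dict key (the GroupKey name and fields here are illustrative):

from dataclasses import dataclass


@dataclass(kw_only=True, frozen=True, slots=True)
class GroupKey:
    strategy_key: str
    needs_probs: bool


buckets: dict[GroupKey, list[int]] = {}
k1 = GroupKey(strategy_key="top_k", needs_probs=True)
k2 = GroupKey(strategy_key="top_k", needs_probs=True)
# Equal field values hash and compare equal, so both land in the same bucket.
buckets.setdefault(k1, []).append(0)
buckets.setdefault(k2, []).append(1)
assert buckets[k1] == [0, 1]
# slots=True removes the per-instance __dict__.
assert not hasattr(k1, "__dict__")

Note that dataclass(slots=True) requires Python 3.10 or newer.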
@@ -420,10 +423,20 @@ def _group_requests_by_strategy_key(
     vocab_size: int,
 ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]:
     # NB: Client code relies on request indices in returned torch.Tensor being sorted.
-    group_dict: dict[
-        tuple[GenericStrategyKeyType, bool],
-        tuple[list[int], list[Strategy], list[int], list[bool], list[bool]],
-    ] = defaultdict(lambda: ([], [], [], [], []))
+    RequestGroupValueBuilder = namedtuple(
+        "RequestGroupValueBuilder",
+        [
+            "indices",
+            "strategies",
+            "speculation_needs_probs_list",
+            "need_processed_logprobs_list",
+            "need_raw_logprobs_list",
+        ],
+    )
+
+    group_dict: dict[RequestGroupKey, RequestGroupValueBuilder] = defaultdict(
+        lambda: RequestGroupValueBuilder([], [], [], [], [])
+    )

     for req_index, req in enumerate(requests):
         strategy = _request_strategy(req, vocab_size=vocab_size)
@@ -438,37 +451,30 @@ def _group_requests_by_strategy_key(
         need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs
         needs_probs = speculation_needs_probs or need_processed_logprobs
         strategy_key = strategy_to_key(strategy, needs_probs)
-        group_dict_entry = group_dict[(strategy_key, needs_probs)]
-        group_dict_entry[0].append(req_index)
-        group_dict_entry[1].append(strategy)
+        group_dict_entry = group_dict[
+            RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs)
+        ]
+        group_dict_entry.indices.append(req_index)
+        group_dict_entry.strategies.append(strategy)
         if speculation_needs_probs:
-            group_dict_entry[2].append(req_index)
-        group_dict_entry[3].append(need_processed_logprobs)
-        group_dict_entry[4].append(need_raw_logprobs)
+            group_dict_entry.speculation_needs_probs_list.append(req_index)
+        group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs)
+        group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs)
     return {
-        RequestGroupKey(
-            strategy_key=group_key[0],
-            needs_probs=group_key[1],
-        ): RequestGroupValue(
-            indices=torch.tensor(indices, pin_memory=pin_memory, dtype=torch.int32),
-            strategies=strategies,
+        group_key: RequestGroupValue(
+            indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32),
+            strategies=group_value.strategies,
             speculation_needs_probs_indices=torch.tensor(
-                speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
+                group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32
             ),
             need_processed_logprobs=torch.tensor(
-                need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
+                group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
             ),
             need_raw_logprobs=torch.tensor(
-                need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
+                group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool
             ),
         )
-        for group_key, (
-            indices,
-            strategies,
-            speculation_needs_probs_list,
-            need_processed_logprobs_list,
-            need_raw_logprobs_list,
-        ) in group_dict.items()
+        for group_key, group_value in group_dict.items()
     }

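The rewritten grouping keeps the same two-phase shape: accumulate per-group Python lists during the loop, then convert each group to tensors in one comprehension. What changed is that the accumulator is a namedtuple with named fields (instead of an anonymous five-tuple indexed by position) and the dict is keyed by RequestGroupKey directly (instead of a (strategy_key, needs_probs) tuple). A stripped-down sketch of the pattern with only two fields; the toy requests and the Builder name are illustrative assumptions:

from collections import defaultdict, namedtuple

import torch

# Builder with named fields instead of an anonymous tuple of lists.
Builder = namedtuple("Builder", ["indices", "strategies"])

# Toy (strategy, request_index) pairs standing in for real LlmRequests.
requests = [("greedy", 0), ("top_k", 1), ("greedy", 2)]

group_dict: dict[str, Builder] = defaultdict(lambda: Builder([], []))
for strategy, req_index in requests:
    entry = group_dict[strategy]  # defaultdict builds an empty Builder on first access
    entry.indices.append(req_index)  # named access replaces entry[0]
    entry.strategies.append(strategy)  # ... and entry[1]

# Second phase: freeze the accumulated lists into tensors, keeping the group keys.
grouped = {
    key: torch.tensor(value.indices, dtype=torch.int32)
    for key, value in group_dict.items()
}
assert grouped["greedy"].tolist() == [0, 2]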

@@ -967,8 +973,8 @@ def _create_store(self) -> Store:
         )
         sampled_log_prob_ranks = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
         # These are 0 sized tensors, if topk-logprobs are not used
-        topk_indices = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.int32)
-        topk_vals = torch.empty(self.topk_logprobs_shape, device="cuda", dtype=torch.float32)
+        topk_indices = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.int32)
+        topk_vals = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.float32)

         # Only used for beam search
         cache_indirection: torch.Tensor | None = None
@@ -1034,7 +1040,7 @@ def __init__(self, args: Args):
             self.max_seq_len + (0 if args.disable_overlap_scheduler else 1),
         )
         self.LOGPROBS_SHAPE = (self.max_num_sequences, self.max_beam_width, self.max_tokens)
-        self.topk_logprobs_shape = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
+        self.TOPK_LOGPROBS_SHAPE = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs)
         # AutoDeploy build creates the sampler in inference mode,
         # which would disallow in-place mutating of new_tokens.
         # So, we temporarily exit inference mode.
@@ -2037,13 +2043,13 @@ def _prepare_log_probs(self, requests: list[LlmRequest]) -> None:
         )
         if self.max_topk_logprobs < self.batch_max_topk_logprobs:
             self.max_topk_logprobs = self.batch_max_topk_logprobs
-            self.topk_logprobs_shape = (
+            self.TOPK_LOGPROBS_SHAPE = (
                 self.max_num_sequences,
                 self.max_tokens,
                 self.max_topk_logprobs,
             )
-            self.store.topk_vals.resize_(self.topk_logprobs_shape)
-            self.store.topk_indices.resize_(self.topk_logprobs_shape)
+            self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE)
+            self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE)

     @override
     @torch.inference_mode()
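Renaming topk_logprobs_shape to TOPK_LOGPROBS_SHAPE matches the existing LOGPROBS_SHAPE spelling; functionally the buffers still start zero-sized along the top-k dimension and are grown in place with Tensor.resize_ the first time a batch requests more top-k logprobs than currently provisioned. A minimal sketch of that grow-on-demand pattern (sizes and names are illustrative, and the buffer lives on CPU here for simplicity):

import torch

max_num_sequences, max_tokens, max_topk_logprobs = 4, 16, 0
# Zero-sized along the last dim until top-k logprobs are actually requested.
topk_vals = torch.empty((max_num_sequences, max_tokens, max_topk_logprobs))

batch_max_topk_logprobs = 5
if max_topk_logprobs < batch_max_topk_logprobs:
    max_topk_logprobs = batch_max_topk_logprobs
    # resize_ reallocates in place when needed; old contents are not preserved,
    # which is acceptable for a scratch buffer that is refilled every step.
    topk_vals.resize_((max_num_sequences, max_tokens, max_topk_logprobs))

assert topk_vals.shape == (4, 16, 5)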

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 8 additions & 8 deletions
@@ -472,10 +472,10 @@ def sample(
    strategy: Strategy,
    logits: torch.Tensor,
    *,
-    generator: Optional[torch.Generator] = None,
+    generator: torch.Generator | None = None,
    group_metadata: StrategyMetadata | None = None,
    return_probs: bool = True,
-) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]:
+) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
    match strategy:
        case ("top_k", top_k, temperature):
            tokens, softmax = top_k_sampling_batch(
@@ -547,11 +547,11 @@ def sample_grouped_strategies(
        strategies: list[Strategy],
        logits: torch.Tensor,
        *,
-        group_logit_indices: Optional[torch.Tensor] = None,
-        generator: Optional[torch.Generator] = None,
+        group_logit_indices: torch.Tensor | None = None,
+        generator: torch.Generator | None = None,
        return_probs: bool,
        group_metadata: StrategyMetadata | None = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], float | torch.Tensor | None]:
+    ) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]:
        raise NotImplementedError


@@ -581,11 +581,11 @@ def sample_grouped_strategies(
        strategies: list[Strategy],
        logits: torch.Tensor,
        *,
-        group_logit_indices: Optional[torch.Tensor] = None,
-        generator: Optional[torch.Generator] = None,
+        group_logit_indices: torch.Tensor | None = None,
+        generator: torch.Generator | None = None,
        return_probs: bool,
        group_metadata: StrategyMetadata | None = None,
-    ) -> tuple[torch.Tensor, Optional[torch.Tensor], float | None]:
+    ) -> tuple[torch.Tensor, torch.Tensor | None, float | None]:
        if group_key[0] == "beam_search":
            beam_width_in = group_key[1]
        else:
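These signature updates are a typing modernization rather than a behaviour change: PEP 604's X | None spelling replaces typing.Optional[X], matching the group_metadata parameters that already used the union syntax. Both forms describe the same type; a short illustration (Python 3.10+, function names here are illustrative):

from typing import Optional

import torch


def old_style(generator: Optional[torch.Generator] = None) -> Optional[torch.Tensor]:
    return None


def new_style(generator: torch.Generator | None = None) -> torch.Tensor | None:
    return None


# The two spellings compare equal at runtime and are interchangeable to type checkers.
assert Optional[torch.Generator] == (torch.Generator | None)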
