diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9dd903c6ce..1eee56efe53 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1022,7 +1022,6 @@ common-files: &common_files | tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py | tests/unittest/_torch/sampler/test_beam_search.py | tests/unittest/_torch/sampler/test_best_of_n.py | - tests/unittest/_torch/sampler/test_return_logits.py | tests/unittest/_torch/sampler/test_torch_multi_arange.py | tests/unittest/_torch/sampler/test_trtllm_sampler.py | tests/unittest/_torch/speculative/test_draft_target.py | diff --git a/pyproject.toml b/pyproject.toml index 031d41850dc..647dc9406bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1063,7 +1063,6 @@ exclude = [ "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py", "tests/unittest/_torch/sampler/test_beam_search.py", "tests/unittest/_torch/sampler/test_best_of_n.py", - "tests/unittest/_torch/sampler/test_return_logits.py", "tests/unittest/_torch/sampler/test_torch_multi_arange.py", "tests/unittest/_torch/sampler/test_trtllm_sampler.py", "tests/unittest/_torch/speculative/test_draft_target.py", diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 871fe4b9bf3..9145509d3b2 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -8,6 +8,7 @@ from tensorrt_llm._torch.shared_tensor import SharedTensorContainer from tensorrt_llm.bindings import executor as tllm_executor from tensorrt_llm.executor.result import TokenLogprobs +from tensorrt_llm.sampling_params import LogprobMode SamplingConfig = tensorrt_llm.bindings.SamplingConfig ''' @@ -485,6 +486,7 @@ def __init__( is_first_draft: bool = False, use_chunked_generation_logits: bool = True, logits_chunk_size: int = 8, + logprobs_mode: LogprobMode = LogprobMode.RAW, **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", @@ -566,6 +568,9 @@ def __init__( # currently, keep py_stop_words_list as python list, rather than tensor. 
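Note on the LogprobMode(logprobs_mode) normalization added in the next hunk: because LogprobMode (defined later in this diff in tensorrt_llm/sampling_params.py) is a StrEnum, its constructor accepts either an enum member or the raw string value. A minimal standalone sketch of that behavior, assuming strenum.StrEnum semantics:

from strenum import StrEnum

class LogprobMode(StrEnum):
    RAW = "raw"
    PROCESSED = "processed"

# The constructor accepts either the enum member or its underlying string,
# which is why LlmRequest.__init__ can normalize whatever the executor passed:
assert LogprobMode("raw") is LogprobMode.RAW
assert LogprobMode(LogprobMode.PROCESSED) is LogprobMode.PROCESSED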
self.py_stop_words_list = stop_words_list + self.py_logprobs_mode = LogprobMode( + logprobs_mode) # handle passed a raw string + self.py_result = PyResult( prompt_len=self.py_prompt_len, max_new_tokens=self.py_max_new_tokens, @@ -825,7 +830,10 @@ def executor_request_to_llm_request( arrival_time=getattr(executor_request, "py_arrival_time", None), py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), - kv_cache_retention_config=executor_request.kv_cache_retention_config) + kv_cache_retention_config=executor_request.kv_cache_retention_config, + logprobs_mode=getattr(executor_request, "py_logprobs_mode", + LogprobMode.RAW), + ) if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index c0bb0acf785..a155a5628c7 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -15,7 +15,7 @@ import enum import sys from abc import ABC, abstractmethod -from collections import defaultdict +from collections import defaultdict, namedtuple from collections.abc import Iterable from concurrent import futures from dataclasses import dataclass @@ -54,7 +54,7 @@ from tensorrt_llm.executor.result import Logprob from tensorrt_llm.llmapi.llm_args import KvCacheConfig from tensorrt_llm.mapping import Mapping -from tensorrt_llm.sampling_params import SamplingParams +from tensorrt_llm.sampling_params import LogprobMode, SamplingParams from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE from ..speculative.interface import get_force_num_accepted_tokens @@ -88,6 +88,37 @@ T = TypeVar("T") +@dataclass(kw_only=True) +class LogProbsState: + sampled_vals: torch.Tensor + sampled_indices: torch.Tensor + sampled_rank: torch.Tensor + topk_vals: torch.Tensor + topk_indices: torch.Tensor + + +@dataclass(kw_only=True) +class LogProbsStateList: + FloatState = list[list[list[float]]] + IntState = list[list[list[int]]] + + sampled_vals: FloatState + sampled_indices: IntState + sampled_rank: IntState + topk_vals: FloatState + topk_indices: IntState + + @staticmethod + def from_logprobs_state(logprobs_state: LogProbsState) -> "LogProbsStateList": + return LogProbsStateList( + sampled_vals=logprobs_state.sampled_vals.tolist(), + sampled_indices=logprobs_state.sampled_indices.tolist(), + topk_vals=logprobs_state.topk_vals.tolist(), + topk_indices=logprobs_state.topk_indices.tolist(), + sampled_rank=logprobs_state.sampled_rank.tolist(), + ) + + @dataclass(kw_only=True) class SampleStateTensors: new_tokens: torch.Tensor @@ -213,13 +244,13 @@ class SampleStateWithMMResult: data: MultimodalResult -@dataclass(kw_only=True, frozen=True) +@dataclass(kw_only=True, frozen=True, slots=True) class RequestGroupKey(Generic[GenericStrategyKeyType]): strategy_key: GenericStrategyKeyType - speculation_needs_probs: bool + needs_probs: bool def __iter__(self): - return iter((self.strategy_key, self.speculation_needs_probs)) + return iter((self.strategy_key, self.needs_probs)) def __len__(self): return 2 @@ -229,25 +260,45 @@ def __len__(self): class RequestGroupValue: indices: torch.Tensor strategies: list[Strategy] + speculation_needs_probs_indices: torch.Tensor + need_processed_logprobs: torch.Tensor + need_raw_logprobs: torch.Tensor def __iter__(self): - return iter((self.indices, self.strategies)) + return iter( + ( + self.indices, + self.strategies, + self.speculation_needs_probs_indices, + self.need_processed_logprobs, + self.need_raw_logprobs, + 
) + ) def __len__(self): - return 2 + return 5 @dataclass(kw_only=True, frozen=True) class RequestGroupValueWithMetadata(RequestGroupValue): - metadata: StrategyMetadata + metadata: StrategyMetadata | None @override def __iter__(self): - return iter((self.indices, self.strategies, self.metadata)) + return iter( + ( + self.indices, + self.strategies, + self.speculation_needs_probs_indices, + self.need_processed_logprobs, + self.need_raw_logprobs, + self.metadata, + ) + ) @override def __len__(self): - return 3 + return 6 class EarlyStopWithMMResult(Sampler): @@ -338,6 +389,10 @@ def _get_max_beam_width(request: LlmRequest) -> int: return max_beam_width +def _request_sampling_params_cachable(params: UtilsSamplingParams) -> bool: + return not params.use_beam_search + + def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams: sampling_config = request.sampling_config temperature = _unwrap_singleton(cast(Optional[list[float]], sampling_config.temperature)) @@ -358,8 +413,16 @@ def _request_get_sampling_params(request: LlmRequest) -> UtilsSamplingParams: def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy: + # We try to cache the resolved strategy on the request object, as it's not cheap enough to + # resolve it on every iteration. + if hasattr(request, "py_sampling_strategy"): + return request.py_sampling_strategy + params = _request_get_sampling_params(request) - return resolve_sampling_strategy(params, vocab_size=vocab_size) + sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) + if _request_sampling_params_cachable(params): + request.py_sampling_strategy = resolve_sampling_strategy(params, vocab_size=vocab_size) + return sampling_strategy def _group_requests_by_strategy_key( @@ -370,8 +433,19 @@ def _group_requests_by_strategy_key( vocab_size: int, ) -> dict[RequestGroupKey[GenericStrategyKeyType], RequestGroupValue]: # NB: Client code relies on request indices in returned torch.Tensor being sorted. - group_dict: dict[tuple[GenericStrategyKeyType, bool], tuple[list[int], list[Strategy]]] = ( - defaultdict(lambda: ([], [])) + RequestGroupValueBuilder = namedtuple( + "RequestGroupValueBuilder", + [ + "indices", + "strategies", + "speculation_needs_probs_list", + "need_processed_logprobs_list", + "need_raw_logprobs_list", + ], + ) + + group_dict: dict[RequestGroupKey, RequestGroupValueBuilder] = defaultdict( + lambda: RequestGroupValueBuilder([], [], [], [], []) ) for req_index, req in enumerate(requests): @@ -381,18 +455,36 @@ def _group_requests_by_strategy_key( # process_draft_tokens. 
TorchSampler._speculation_could_use_rejection_sampling(req, strategy) ) - strategy_key = strategy_to_key(strategy, speculation_needs_probs) - group_dict_entry = group_dict[(strategy_key, speculation_needs_probs)] - group_dict_entry[0].append(req_index) - group_dict_entry[1].append(strategy) + need_processed_logprobs = ( + req.py_logprobs_mode == LogprobMode.PROCESSED and req.return_log_probs + ) + need_raw_logprobs = req.py_logprobs_mode == LogprobMode.RAW and req.return_log_probs + needs_probs = speculation_needs_probs or need_processed_logprobs + strategy_key = strategy_to_key(strategy, needs_probs) + group_dict_entry = group_dict[ + RequestGroupKey(strategy_key=strategy_key, needs_probs=needs_probs) + ] + group_dict_entry.indices.append(req_index) + group_dict_entry.strategies.append(strategy) + if speculation_needs_probs: + group_dict_entry.speculation_needs_probs_list.append(req_index) + group_dict_entry.need_processed_logprobs_list.append(need_processed_logprobs) + group_dict_entry.need_raw_logprobs_list.append(need_raw_logprobs) return { - RequestGroupKey( - strategy_key=group_key[0], speculation_needs_probs=group_key[1] - ): RequestGroupValue( - indices=torch.tensor(indices, pin_memory=pin_memory, dtype=torch.int32), - strategies=strategies, + group_key: RequestGroupValue( + indices=torch.tensor(group_value.indices, pin_memory=pin_memory, dtype=torch.int32), + strategies=group_value.strategies, + speculation_needs_probs_indices=torch.tensor( + group_value.speculation_needs_probs_list, pin_memory=pin_memory, dtype=torch.int32 + ), + need_processed_logprobs=torch.tensor( + group_value.need_processed_logprobs_list, pin_memory=pin_memory, dtype=torch.bool + ), + need_raw_logprobs=torch.tensor( + group_value.need_raw_logprobs_list, pin_memory=pin_memory, dtype=torch.bool + ), ) - for group_key, (indices, strategies) in group_dict.items() + for group_key, group_value in group_dict.items() } @@ -417,6 +509,8 @@ class _BatchedSamplingResult: batch_req_indices: torch.Tensor # Next tokens for all requests: batch_next_tokens_cuda_int: torch.Tensor + # Logits for all requests used for logprobs: + batch_logits_for_logprobs_cuda: torch.Tensor | None = None # Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the @@ -694,6 +788,7 @@ class BeamHistory: @dataclass(kw_only=True) class SamplingRequestsMetadata: req_num_generated_tokens: torch.Tensor + req_num_generated_tokens_output: torch.Tensor req_num_beams: torch.Tensor req_num_steps: torch.Tensor req_offsets: torch.Tensor @@ -703,6 +798,7 @@ class SamplingRequestsMetadata: class SampleStateTensorsHostTorch(SampleStateTensors): finish_reasons: torch.Tensor first_finish_reasons: torch.Tensor + logprobs_state: LogProbsState | None = None def finish_reasons_list(self) -> FinishReasonsList: """`(num_seq_slots, num_steps)`""" @@ -819,6 +915,7 @@ def _record_sampler_event(self) -> SamplerEvent: class TorchSampler(Sampler, AsyncWorkerMixin): SampleState = SampleStateTorch + DEFAULT_MAX_TOPK_LOGPROBS = 20 @override def get_cache_indirection(self) -> torch.Tensor | None: @@ -828,7 +925,7 @@ def get_cache_indirection(self) -> torch.Tensor | None: def is_generation_model(self) -> bool: return True - @dataclass(frozen=True, kw_only=True) + @dataclass(kw_only=True) class Store: new_tokens: torch.Tensor """Shape: See cpp DecoderState.getAllNewTokens()""" @@ -844,9 +941,21 @@ class Store: cum_log_probs: torch.Tensor | None = None """Shape: batch_size, beam_width Usage: Stores the current cumulative logprob of each active beam for faster 
sampling""" - new_log_probs: torch.Tensor | None = None - """Shape: batch_size, beam_width - Usage: Stores the new logprobs for each beam""" + sampled_log_prob_indices: torch.Tensor | None = None + """Shape: batch_size, beam_width, max_tokens + Usage: Stores the token indices of the sampled logprobs""" + sampled_log_probs: torch.Tensor | None = None + """Shape: batch_size, beam_width, max_tokens + Usage: Stores the values of the sampled logprobs""" + sampled_log_prob_ranks: torch.Tensor | None = None + """Shape: batch_size, beam_width, max_tokens + Usage: Stores the ranks of the sampled logprobs""" + topk_indices: torch.Tensor | None = None + """Shape: batch_size, max_tokens, max_topk_logprobs + Usage: Stores the token indices of the topk logprobs""" + topk_vals: torch.Tensor | None = None + """Shape: batch_size, max_tokens, max_topk_logprobs + Usage: Stores the values of the topk logprobs""" first_finish_reasons: torch.Tensor | None = None """Shape: batch_size, beam_width Usage: Stores the first finish reason for each beam""" @@ -862,31 +971,54 @@ def __post_init__(self): assert self.new_tokens.shape == self.finish_reasons.shape def _create_store(self) -> Store: + # Tensors necessary for all sampling methods + new_tokens = int_tensor(self.NEW_TOKENS_SHAPE) + finish_reasons = int_tensor(self.NEW_TOKENS_SHAPE) + + # Only used for logprobs processing or beam search + sampled_log_probs = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.float32) + # Only used for logprobs processing + sampled_log_prob_indices = torch.empty( + self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32 + ) + sampled_log_prob_ranks = torch.empty(self.LOGPROBS_SHAPE, device="cuda", dtype=torch.int32) + # These are 0 sized tensors, if topk-logprobs are not used + topk_indices = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.int32) + topk_vals = torch.empty(self.TOPK_LOGPROBS_SHAPE, device="cuda", dtype=torch.float32) + + # Only used for beam search + cache_indirection: torch.Tensor | None = None + cache_indirection_buffer: torch.Tensor | None = None + cum_log_probs: torch.Tensor | None = None + predecessor_beams: torch.Tensor | None = None + original_tokens: torch.Tensor | None = None + first_finish_reasons: torch.Tensor | None = None if self._use_beam_search: - return self.Store( - new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), - finish_reasons=int_tensor(self.NEW_TOKENS_SHAPE), - cache_indirection=torch.zeros( - self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int - ), - cache_indirection_buffer=int_tensor(self.CACHE_INDIRECTION_SHAPE), - cum_log_probs=torch.zeros( - self.CACHE_INDIRECTION_SHAPE[:-1], device="cuda", dtype=torch.float32 - ), - new_log_probs=torch.zeros( - self.CACHE_INDIRECTION_SHAPE[:-1], device="cuda", dtype=torch.float32 - ), - predecessor_beams=int_tensor(self.CACHE_INDIRECTION_SHAPE[:-1]), - original_tokens=int_tensor(self.CACHE_INDIRECTION_SHAPE), - first_finish_reasons=int_tensor( - self.CACHE_INDIRECTION_SHAPE[:-1], - ), + cache_indirection = torch.empty( + self.CACHE_INDIRECTION_SHAPE, device="cuda", dtype=torch.int ) - else: - return self.Store( - new_tokens=int_tensor(self.NEW_TOKENS_SHAPE), - finish_reasons=int_tensor(self.NEW_TOKENS_SHAPE), + cache_indirection_buffer = int_tensor(self.CACHE_INDIRECTION_SHAPE) + cum_log_probs = torch.empty( + self.CACHE_INDIRECTION_SHAPE[:-1], device="cuda", dtype=torch.float32 ) + predecessor_beams = int_tensor(self.CACHE_INDIRECTION_SHAPE[:-1]) + original_tokens = int_tensor(self.CACHE_INDIRECTION_SHAPE) + 
first_finish_reasons = int_tensor(self.CACHE_INDIRECTION_SHAPE[:-1]) + return self.Store( + new_tokens=new_tokens, + finish_reasons=finish_reasons, + cache_indirection=cache_indirection, + cache_indirection_buffer=cache_indirection_buffer, + cum_log_probs=cum_log_probs, + sampled_log_prob_indices=sampled_log_prob_indices, + sampled_log_probs=sampled_log_probs, + sampled_log_prob_ranks=sampled_log_prob_ranks, + topk_indices=topk_indices, + topk_vals=topk_vals, + predecessor_beams=predecessor_beams, + original_tokens=original_tokens, + first_finish_reasons=first_finish_reasons, + ) @dataclass(frozen=True, kw_only=True) class Args: @@ -903,6 +1035,10 @@ def __init__(self, args: Args): self.max_seq_len = args.max_seq_len self.max_tokens = args.max_total_draft_tokens + 1 self.max_beam_width = args.max_beam_width + # The current maximum number of topk logprobs which can be stored in the sampler's store + self.max_topk_logprobs = self.DEFAULT_MAX_TOPK_LOGPROBS + # The maximum number of topk logprobs for the current batch of requests + self.batch_max_topk_logprobs = 0 if args.max_total_draft_tokens > 0 and args.max_beam_width > 1: raise ValueError("TorchSampler does not support beam search with speculative decoding") self.max_num_sequences = args.max_num_sequences @@ -913,6 +1049,8 @@ def __init__(self, args: Args): self.max_beam_width, self.max_seq_len + (0 if args.disable_overlap_scheduler else 1), ) + self.LOGPROBS_SHAPE = (self.max_num_sequences, self.max_beam_width, self.max_tokens) + self.TOPK_LOGPROBS_SHAPE = (self.max_num_sequences, self.max_tokens, self.max_topk_logprobs) # AutoDeploy build creates the sampler in inference mode, # which would disallow in-place mutating of new_tokens. # So, we temporarily exit inference mode. @@ -1113,61 +1251,132 @@ def _convert_logprobs_tensor_to_list( self, token_tensor: torch.Tensor, logprobs_tensor: torch.Tensor, + sampled_log_probs_indices: torch.Tensor | None, + sampled_log_probs_vals: torch.Tensor | None, + sampled_log_probs_rank: torch.Tensor | None, ) -> list[list[dict[int, Logprob]]]: """Convert the logprobs tensor to a list of lists of dictionaries of Logprob objects Logprobs storage expects logprobs as a list[list[dict[int, Logprob]]] object args: + token_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs logprobs_tensor: torch.Tensor. Shape: beam_width, num_tokens, num_logprobs + sampled_log_probs_indices: torch.Tensor | None. Shape: num_tokens + sampled_log_probs_vals: torch.Tensor | None. Shape: num_tokens + sampled_log_probs_rank: torch.Tensor | None. Shape: num_tokens output: - list[list[dict[int, Logprob]]]. Shape: beam_width, num_tokens, dict with num_logprobs keys + list[list[dict[int, Logprob]]]. Shape: (beam_width, num_tokens) """ assert token_tensor.dim() == 3 and logprobs_tensor.dim() == 3, ( f"Token and logprobs tensors must have 3 dimensions (beam_width, num_tokens, num_logprobs). 
\ Got shapes (token_tensor) {token_tensor.shape} and (logprobs_tensor) {logprobs_tensor.shape} instead" ) - return [ - [ - { + + token_log_probs: list[list[dict[int, Logprob]]] = [] + token_list = token_tensor.tolist() + logprobs_list = logprobs_tensor.tolist() + sampled_log_probs_indices_list: list[int] | None = None + sampled_log_probs_vals_list: list[float] | None = None + sampled_log_probs_rank_list: list[int] | None = None + if sampled_log_probs_indices is not None: + sampled_log_probs_indices_list = sampled_log_probs_indices.tolist() + assert sampled_log_probs_vals is not None, "sampled_log_probs_vals must be provided" + assert sampled_log_probs_rank is not None, "sampled_log_probs_rank must be provided" + sampled_log_probs_vals_list = sampled_log_probs_vals.tolist() + sampled_log_probs_rank_list = sampled_log_probs_rank.tolist() + for beam_idx in range(token_tensor.shape[0]): + beam_token_log_probs: list[dict[int, Logprob]] = [] + for step_idx, (topk_token, topk_logprob) in enumerate( + zip(token_list[beam_idx], logprobs_list[beam_idx]) + ): + logprobs = { + token: Logprob(logprob=logprob, rank=rank + 1) + for rank, (token, logprob) in enumerate(zip(topk_token, topk_logprob)) + } + if sampled_log_probs_indices is not None: + assert beam_idx == DEFAULT_BEAM_IDX, ( + "beam search does not need to explicitly handle sampled log probs" + ) + if sampled_log_probs_indices_list[step_idx] not in logprobs: + logprobs[sampled_log_probs_indices_list[step_idx]] = Logprob( + logprob=sampled_log_probs_vals_list[step_idx], + rank=max( + token_tensor.shape[2] + 1, sampled_log_probs_rank_list[step_idx] + ), + ) + beam_token_log_probs.append(logprobs) + token_log_probs.append(beam_token_log_probs) + + return token_log_probs + + def _store_logprobs_list_to_request( + self, + logprobs_state_list: LogProbsStateList, + req_seq_slot: int, + beam_width: int, + count: int, + num_topk_logprobs: int, + ) -> list[list[dict[int, Logprob]]]: + """Convert the LogProbsStateList object to a list of lists of dictionaries of Logprob objects + + Logprobs storage expects logprobs as a list[list[dict[int, Logprob]]] object + + args: + logprobs_state_list: LogProbsStateList. Contains the topk indices, topk values, + sampled indices, sampled values, and sampled ranks. + req_seq_slot: int. The sequence slot of the request. + beam_width: int. The beam width of the request. + count: int. The number of tokens to store. + num_topk_logprobs: int. The number of topk logprobs of each token. + output: + list[list[dict[int, Logprob]]]. 
Shape: (beam_width, count) + """ + + token_list = logprobs_state_list.topk_indices[req_seq_slot] + logprobs_list = logprobs_state_list.topk_vals[req_seq_slot] + sampled_log_probs_indices_list = logprobs_state_list.sampled_indices[req_seq_slot] + sampled_log_probs_vals_list = logprobs_state_list.sampled_vals[req_seq_slot] + sampled_log_probs_rank_list = logprobs_state_list.sampled_rank[req_seq_slot] + + token_log_probs: list[list[dict[int, Logprob]]] = [] + for beam_idx in range(beam_width): + beam_token_log_probs: list[dict[int, Logprob]] = [] + for step_idx, (topk_token, topk_logprob) in enumerate( + zip(token_list[:count], logprobs_list[:count]) + ): + logprobs = { token: Logprob(logprob=logprob, rank=rank + 1) for rank, (token, logprob) in enumerate( - zip(topk_token.tolist(), topk_logprob.tolist()) + zip(topk_token[:num_topk_logprobs], topk_logprob[:num_topk_logprobs]) ) } - for topk_token, topk_logprob in zip( - token_tensor[beam_idx], logprobs_tensor[beam_idx] - ) - ] - for beam_idx in range(token_tensor.shape[0]) - ] + if sampled_log_probs_indices_list[beam_idx][step_idx] not in logprobs: + logprobs[sampled_log_probs_indices_list[beam_idx][step_idx]] = Logprob( + logprob=sampled_log_probs_vals_list[beam_idx][step_idx], + rank=max( + len(token_list[step_idx]) + 1, + sampled_log_probs_rank_list[beam_idx][step_idx] + 1, + ), + ) + beam_token_log_probs.append(logprobs) + token_log_probs.append(beam_token_log_probs) + + return token_log_probs def handle_logprobs( self, request: LlmRequest, + logprobs_state_list: LogProbsStateList | None, *, count: int, ): if request.py_return_log_probs: beam_width = request.sampling_config.beam_width - if self._use_beam_search: - topk_log_probs_vals = self.store.new_log_probs[request.py_seq_slot].view( - beam_width, count, -1 - ) - topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view( - beam_width, count, -1 - ) - else: - assert beam_width == 1, "beam width must be 1 for non-beam search" - topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view( - beam_width, count, -1 - ) - topk_log_probs_indices = request.py_topk_logprobs_indices[ - : count * beam_width - ].view(beam_width, count, -1) - - token_log_probs = self._convert_logprobs_tensor_to_list( - topk_log_probs_indices, topk_log_probs_vals + assert request.py_num_logprobs is not None, "request.py_num_logprobs must be provided" + assert logprobs_state_list is not None, "logprobs_state_list must be provided" + token_log_probs = self._store_logprobs_list_to_request( + logprobs_state_list, request.py_seq_slot, beam_width, count, request.py_num_logprobs ) request.py_result.append_log_probs(token_log_probs) @@ -1344,7 +1553,8 @@ def _prepare_beam_search( raise ValueError("Beam search does not support multiple logprobs") self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0) self.store.cum_log_probs[request.py_seq_slot].fill_(0) - self.store.new_log_probs[request.py_seq_slot].fill_(0) + self.store.sampled_log_probs[request.py_seq_slot].fill_(0) + self.store.sampled_log_prob_ranks[request.py_seq_slot].fill_(0) self.store.predecessor_beams[request.py_seq_slot].fill_(0) self.store.first_finish_reasons[request.py_seq_slot].fill_( FinishReason.NOT_FINISHED.value @@ -1371,7 +1581,7 @@ def _process_draft_tokens_rejection_sampling( else _request_strategy(request, vocab_size=2**31) ) generator = self.get_generator(request.py_draft_logits.device) - _, draft_probs = sample( + _, draft_probs, _ = sample( draft_sampling_strategy, 
request.py_draft_logits, generator=generator, @@ -1490,12 +1700,14 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor, logprobs_indices_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs) """ num_generated_tokens = request.max_beam_num_tokens - request.py_prompt_len - assert request.py_num_logprobs == 1, "Beam search only supports one logprob per token" + assert request.py_num_logprobs == 0, ( + "Beam search only supports returning the sampled logprob per token" + ) logprobs_tensor = torch.empty( ( request.sampling_config.beam_width, num_generated_tokens, - request.py_num_logprobs, + request.py_num_logprobs + 1, ), device="cuda", dtype=torch.float32, @@ -1504,7 +1716,7 @@ def _get_logprobs_from_request(self, request: LlmRequest) -> tuple[torch.Tensor, ( request.sampling_config.beam_width, num_generated_tokens, - request.py_num_logprobs, + request.py_num_logprobs + 1, ), device="cuda", dtype=torch.int32, @@ -1555,7 +1767,7 @@ def _create_beam_history( current_logprobs = torch.cat( [ current_logprobs, - self.store.new_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1), + self.store.sampled_log_probs[request.py_seq_slot, :num_beams].view(-1, 1, 1), ], dim=1, ) @@ -1647,6 +1859,9 @@ def _finalize_beam( beam_idx : beam_idx + 1, : valid_tokens[beam_idx] ], beam_history.logprobs[beam_idx : beam_idx + 1, : valid_tokens[beam_idx]], + None, + None, + None, )[0] ) request.set_generated_tokens(gen_token_list) @@ -1676,7 +1891,7 @@ def _add_metadata_to_grouped_requests( cache_indirection=self.store.cache_indirection, cache_indirection_buffer=self.store.cache_indirection_buffer, cum_log_probs=self.store.cum_log_probs, - new_log_probs=self.store.new_log_probs, + new_log_probs=self.store.sampled_log_probs[..., DEFAULT_STEP_IDX], seq_slots=seq_slots[grouped_requests[key].indices].to( device="cuda", dtype=torch.int64, non_blocking=True ), # Should be on device for beam search, need long for index_copy_ @@ -1702,6 +1917,9 @@ def _add_metadata_to_grouped_requests( grouped_requests_with_metadata[key] = RequestGroupValueWithMetadata( indices=value.indices, strategies=value.strategies, + speculation_needs_probs_indices=value.speculation_needs_probs_indices, + need_processed_logprobs=value.need_processed_logprobs, + need_raw_logprobs=value.need_raw_logprobs, metadata=metadata, ) return grouped_requests_with_metadata @@ -1763,6 +1981,9 @@ def update_requests( new_tokens_list = new_tokens.tolist() beam_histories = state.beam_histories + logprobs_state_list: LogProbsStateList | None = None + if state.host.logprobs_state is not None: + logprobs_state_list = LogProbsStateList.from_logprobs_state(state.host.logprobs_state) for req_idx, req in enumerate(state.scheduled_requests.context_requests): if ( @@ -1778,7 +1999,7 @@ def update_requests( else: for beam_idx in range(req.sampling_config.beam_width): add_token(req, new_tokens_list, beam_idx=beam_idx) - self.handle_logprobs(req, count=1) + self.handle_logprobs(req, logprobs_state_list=logprobs_state_list, count=1) self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons) req.py_decoding_iter += 1 @@ -1798,7 +2019,7 @@ def update_requests( for beam_idx in range(req.sampling_config.beam_width): # Beam search does not support speculative decoding. 
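For readability, the per-token dictionary assembled by _store_logprobs_list_to_request / handle_logprobs above can be summarized by the following self-contained sketch (merge_topk_and_sampled is a hypothetical helper, not part of this diff):

from dataclasses import dataclass

@dataclass
class Logprob:
    logprob: float
    rank: int

def merge_topk_and_sampled(topk_tokens, topk_logprobs,
                           sampled_token, sampled_logprob, sampled_rank_0based):
    # Top-K entries are ranked 1..K by their position in the sorted top-k output.
    out = {tok: Logprob(lp, rank + 1)
           for rank, (tok, lp) in enumerate(zip(topk_tokens, topk_logprobs))}
    # The sampled token is appended only if it is not already in the top-K; its
    # rank is its true 1-based rank, clamped to be at least K + 1.
    if sampled_token not in out:
        out[sampled_token] = Logprob(sampled_logprob,
                                     max(len(topk_tokens) + 1, sampled_rank_0based + 1))
    return out

# e.g. top-2 logprobs plus a sampled token whose true rank is 5 (0-based rank 4):
print(merge_topk_and_sampled([7, 3], [-0.1, -1.2], 42, -3.4, 4))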
add_token(req, new_tokens_list, beam_idx=beam_idx) - self.handle_logprobs(req, count=1) + self.handle_logprobs(req, logprobs_state_list=logprobs_state_list, count=1) self._handle_finish_reasons(req, state.host.finish_reasons, finish_reasons) req.py_num_accepted_draft_tokens = 0 req.py_rewind_len = 0 @@ -1820,12 +2041,26 @@ def update_requests( req.py_num_accepted_draft_tokens = 0 req.py_rewind_len = 0 processed += num_accepted - self.handle_logprobs(req, count=processed) + self.handle_logprobs(req, logprobs_state_list=logprobs_state_list, count=processed) req.py_decoding_iter += 1 def _return_log_probs(self, requests: list[LlmRequest]) -> bool: return any(req.py_return_log_probs for req in requests) + def _prepare_log_probs(self, requests: list[LlmRequest]) -> None: + self.batch_max_topk_logprobs = max( + (req.py_num_logprobs or 0 for req in requests), default=0 + ) + if self.max_topk_logprobs < self.batch_max_topk_logprobs: + self.max_topk_logprobs = self.batch_max_topk_logprobs + self.TOPK_LOGPROBS_SHAPE = ( + self.max_num_sequences, + self.max_tokens, + self.max_topk_logprobs, + ) + self.store.topk_vals.resize_(self.TOPK_LOGPROBS_SHAPE) + self.store.topk_indices.resize_(self.TOPK_LOGPROBS_SHAPE) + @override @torch.inference_mode() @nvtx_range("sample_async") @@ -1883,6 +2118,9 @@ def sample_async( beam_histories = [None] * len(requests) if self._use_beam_search: assert seq_lens_host is not None, "seq_lens is required for beam search" + assert self.store.first_finish_reasons is not None, ( + "first_finish_reasons must be provided" + ) seq_lens = seq_lens_host.to(device="cuda", non_blocking=True) first_finish_reasons_host = self._copy_to_host(self.store.first_finish_reasons) self._update_original_tokens(seq_slots, seq_lens, new_tokens) @@ -1890,6 +2128,35 @@ def sample_async( requests, finish_reasons=first_finish_reasons, beam_histories=beam_histories ) + # copy logprobs to host + logprobs_state: LogProbsState | None = None + if self._return_log_probs(requests): + assert self.store.topk_vals is not None, "topk_vals must be provided" + assert self.store.topk_indices is not None, "topk_indices must be provided" + assert self.store.sampled_log_probs is not None, "sampled_log_probs must be provided" + assert self.store.sampled_log_prob_indices is not None, ( + "sampled_log_prob_indices must be provided" + ) + assert self.store.sampled_log_prob_ranks is not None, ( + "sampled_log_prob_ranks must be provided" + ) + host_topk_vals = self._copy_to_host( + self.store.topk_vals[..., : self.batch_max_topk_logprobs] + ) + host_topk_indices = self._copy_to_host( + self.store.topk_indices[..., : self.batch_max_topk_logprobs] + ) + host_sampled_vals = self._copy_to_host(self.store.sampled_log_probs) + host_sampled_indices = self._copy_to_host(self.store.sampled_log_prob_indices) + host_sampled_rank = self._copy_to_host(self.store.sampled_log_prob_ranks) + logprobs_state = LogProbsState( + topk_vals=host_topk_vals, + topk_indices=host_topk_indices, + sampled_vals=host_sampled_vals, + sampled_indices=host_sampled_indices, + sampled_rank=host_sampled_rank, + ) + sampler_event = self._record_sampler_event() return SampleStateTorch( scheduled_requests=scheduled_requests, @@ -1900,6 +2167,7 @@ def sample_async( first_finish_reasons=None if not self._use_beam_search else first_finish_reasons_host, + logprobs_state=logprobs_state, ), sampler_event=sampler_event, beam_histories=beam_histories, @@ -1972,8 +2240,8 @@ def _apply_embedding_bias( # Since read-caching is expected to help in typical cases, option (ii) 
is implemented here. # Track which logits require logit bias application - logits_bias_mask = torch.zeros((logits.size(0),), dtype=torch.bool, pin_memory=True) - + request_steps_list = request_steps.tolist() + logits_bias_masks = [False] * logits.size(0) _next_bias_index = 0 def provision_bias_index() -> int: @@ -1993,11 +2261,11 @@ def provision_bias_index() -> int: # Collect bias information req_bias = None - for i, (req, steps) in enumerate(zip(requests, request_steps)): - steps = int(steps.item()) + for i, (req, steps) in enumerate(zip(requests, request_steps_list)): req_bias = req._py_embedding_bias_1d if req_bias is not None: - logits_bias_mask[i : (i + steps)] = True + for j in range(i, i + steps): + logits_bias_masks[j] = True req_bias_index = bias_to_index[req_bias] bias_gather_indices.extend(repeat(req_bias_index, steps)) @@ -2008,7 +2276,9 @@ def provision_bias_index() -> int: bias_gather_indices_cuda = torch.tensor( bias_gather_indices, pin_memory=True, dtype=torch.int32 ).to(logits.device, non_blocking=True) - logits_bias_mask_cuda = logits_bias_mask.to(logits.device, non_blocking=True) + logits_bias_mask_cuda = torch.tensor( + logits_bias_masks, pin_memory=True, dtype=torch.bool + ).to(logits.device, non_blocking=True) biases_tensor = torch.empty((len(bias_to_index), *req_bias.shape), pin_memory=True) biases_tensor = torch.stack( tuple(bias_to_index.keys()), @@ -2024,63 +2294,6 @@ def provision_bias_index() -> int: # sharing). logits[logits_bias_mask_cuda] += biases_tensor_cuda - def _handle_log_probs( - self, - requests: list[LlmRequest], - logits_cuda: torch.Tensor, - *, - logits_cuda_indexer: _PackedStepIndexer, - req_num_generated_tokens: torch.Tensor, - ) -> None: - """Handle top-k logprobs. - - This is done outside the sampling loop, because the returned logprobs are specified to not reflect - temperature scaling, top-k/top-p masking, etc. 
- """ - if self._return_log_probs(requests): - assert logits_cuda.dim() == 2, "logits should be 2D" - - logprobs_req_indices = [ - req_id for req_id, req in enumerate(requests) if req.py_num_logprobs - ] - logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] - logprobs_logit_indices_cuda = logprobs_logit_indices.to( - device=logits_cuda.device, non_blocking=True - ) - logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), - dim=-1, - ) - topk_vals_cuda, topk_indices_cuda = torch.topk( - logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 - ) - # Use a single D2H copy to reduce overheads - topk_vals = self._copy_to_host(topk_vals_cuda) - topk_indices = self._copy_to_host(topk_indices_cuda) - current_offset = 0 - for req_id, steps in zip( - logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist() - ): - req = requests[req_id] - next_offset = current_offset + steps - # NB: Assigning views on memory which is being filled asynchronously - req.py_topk_logprobs_vals = topk_vals[ - current_offset:next_offset, : req.py_num_logprobs - ] - req.py_topk_logprobs_indices = topk_indices[ - current_offset:next_offset, : req.py_num_logprobs - ] - - # context requests do not have multiple input beams, but they need multiple output beams - if req.is_context_init_state: - req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand( - req.sampling_config.beam_width, -1 - ) - req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand( - req.sampling_config.beam_width, -1 - ) - current_offset = next_offset - @nvtx_range("sample_batched_by_strategy") @torch.inference_mode() def _sample_batched_by_strategy( @@ -2097,6 +2310,7 @@ def _sample_batched_by_strategy( seq_slots: torch.Tensor, seq_lens: Optional[torch.Tensor] = None, token_dtype: torch.dtype, + return_log_probs: bool, ) -> _BatchedSamplingResult: grouped_requests = _group_requests_by_strategy_key( requests, @@ -2130,11 +2344,21 @@ def _sample_batched_by_strategy( batch_next_tokens_cuda_int = torch.empty( (logits_cuda.size(0), self.max_beam_width), device=cuda_device, dtype=token_dtype ) + batch_logits_for_logprobs_cuda = ( + torch.empty( + (logits_cuda.size(0), logits_cuda.size(1)), device=cuda_device, dtype=torch.float32 + ) + if return_log_probs + else None + ) batch_req_idx_offset_start = 0 batch_next_tokens_offset_start = 0 - for (strategy_key, speculation_needs_probs), ( + for (strategy_key, needs_probs), ( group_req_indices, group_strategies, + group_speculation_needs_probs_indices, + group_need_processed_logprobs, + group_need_raw_logprobs, group_metadata, ) in grouped_requests_with_metadata.items(): # group_req_indices: Indices of 'requests' entries having the same sampling @@ -2144,6 +2368,35 @@ def _sample_batched_by_strategy( group_req_indices ) + need_processed_logprobs_indices = torch.nonzero(group_need_processed_logprobs) + need_raw_logprobs_indices = torch.nonzero(group_need_raw_logprobs) + any_request_needs_processed_logprobs = need_processed_logprobs_indices.size(0) > 0 + any_request_needs_raw_logprobs = need_raw_logprobs_indices.size(0) > 0 + any_request_needs_logprobs = ( + any_request_needs_processed_logprobs or any_request_needs_raw_logprobs + ) + + if any_request_needs_logprobs: + # indices for accessing logits within the current group + group_logit_indexer = _PackedStepIndexer( + num_steps=req_num_generated_tokens[group_req_indices], + max_steps=req_num_generated_tokens.max() * self.max_beam_width, + ) + 
logit_indices_for_processed_logprobs_cuda = ( + None + if not any_request_needs_processed_logprobs + else group_logit_indexer[need_processed_logprobs_indices].to( + logits_cuda.device, non_blocking=True + ) + ) + logit_indices_for_raw_logprobs_cuda = ( + None + if not any_request_needs_raw_logprobs + else group_logit_indexer[need_raw_logprobs_indices].to( + logits_cuda.device, non_blocking=True + ) + ) + group_logits_cuda_indices = logits_cuda_indexer[group_req_indices] # NB: Assuming that group_req_indices are sorted group_req_1st_index, group_req_last_index = group_req_indices[0], group_req_indices[-1] @@ -2158,26 +2411,47 @@ def _sample_batched_by_strategy( ) group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda] logit_indices_for_sampler = None + # group_logits_cuda already contains only logits for the group + group_logits_indices_for_processed_logprobs_cuda = ( + logit_indices_for_processed_logprobs_cuda + ) + group_logits_indices_for_raw_logprobs_cuda = logit_indices_for_raw_logprobs_cuda else: group_logits_cuda_indices_cuda = group_logits_cuda_indices.to( device=logits_cuda.device, non_blocking=True ) group_logits_cuda = logits_cuda logit_indices_for_sampler = group_logits_cuda_indices_cuda + # group_logits_cuda contains logits for the whole batch + # Therefore, we need indices corresponding to the whole batch + group_logits_indices_for_processed_logprobs_cuda = ( + None + if not any_request_needs_processed_logprobs + else logits_cuda_indexer[group_req_indices[group_need_processed_logprobs]].to( + logits_cuda.device, non_blocking=True + ) + ) + group_logits_indices_for_raw_logprobs_cuda = ( + None + if not any_request_needs_raw_logprobs + else logits_cuda_indexer[group_req_indices[group_need_raw_logprobs]].to( + logits_cuda.device, non_blocking=True + ) + ) group_strategies_per_step = [ # convert from per-request to per-step strat - for strat, steps in zip(group_strategies, req_num_steps[group_req_indices]) + for strat, steps in zip(group_strategies, req_num_steps[group_req_indices].tolist()) for _ in range(steps) ] - group_next_tokens_cuda, group_softmax_cuda = ( + group_next_tokens_cuda, group_softmax_cuda, group_temperature_cuda = ( self._grouped_sampler_cls.sample_grouped_strategies( strategy_key, group_strategies_per_step, group_logits_cuda, generator=generator_cuda, - return_probs=speculation_needs_probs, + return_probs=needs_probs, group_logit_indices=logit_indices_for_sampler, group_metadata=group_metadata, ) @@ -2192,12 +2466,47 @@ def _sample_batched_by_strategy( batch_next_tokens_offset_start:batch_next_tokens_offset_end ].copy_(group_next_tokens_cuda, non_blocking=True) + if any_request_needs_processed_logprobs: + assert group_logits_indices_for_processed_logprobs_cuda is not None + assert logit_indices_for_processed_logprobs_cuda is not None + assert group_softmax_cuda is not None + assert batch_logits_for_logprobs_cuda is not None + current_logits_cuda = group_logits_cuda[ + group_logits_indices_for_processed_logprobs_cuda + ] + current_softmax_cuda = group_softmax_cuda[logit_indices_for_processed_logprobs_cuda] + processed_logits_cuda = torch.where( + current_softmax_cuda > 0, current_logits_cuda, float("-inf") + ) + if group_temperature_cuda is not None: + if isinstance(group_temperature_cuda, torch.Tensor): + processed_logits_cuda /= group_temperature_cuda[ + logit_indices_for_processed_logprobs_cuda + ] + else: + processed_logits_cuda /= group_temperature_cuda + logit_indices_for_processed_logprobs_cuda += batch_next_tokens_offset_start + 
batch_logits_for_logprobs_cuda[logit_indices_for_processed_logprobs_cuda] = ( + processed_logits_cuda + ) + + if any_request_needs_raw_logprobs: + assert group_logits_indices_for_raw_logprobs_cuda is not None + assert logit_indices_for_raw_logprobs_cuda is not None + assert batch_logits_for_logprobs_cuda is not None + raw_logits_cuda = group_logits_cuda[group_logits_indices_for_raw_logprobs_cuda] + logit_indices_for_raw_logprobs_cuda += batch_next_tokens_offset_start + batch_logits_for_logprobs_cuda[logit_indices_for_raw_logprobs_cuda] = ( + raw_logits_cuda + ) + # Set LlmRequest.py_target_probs - if speculation_needs_probs: + if group_speculation_needs_probs_indices.size(0) > 0: assert group_softmax_cuda is not None current_offset = 0 for req_idx, steps in zip( - group_req_indices, req_num_generated_tokens[group_req_indices].tolist() + group_speculation_needs_probs_indices.tolist(), + req_num_steps[group_speculation_needs_probs_indices].tolist(), ): next_offset = current_offset + steps # using view avoids copy @@ -2220,6 +2529,7 @@ def _sample_batched_by_strategy( return _BatchedSamplingResult( batch_req_indices=batch_req_indices, batch_next_tokens_cuda_int=batch_next_tokens_cuda_int, + batch_logits_for_logprobs_cuda=batch_logits_for_logprobs_cuda, ) def _unbatch_sampling_results( @@ -2321,12 +2631,18 @@ def _select_generated_logits( # context requests do not have multiple beams yet, so beam width may differ in mixed batches req_num_beams_list = [ - req.sampling_config.beam_width if not req.is_context_init_state else 1 + req.get_beam_width_by_iter(False) if not req.is_context_init_state else 1 for req in requests ] req_num_beams = torch.tensor(req_num_beams_list, dtype=torch.int32, pin_memory=True) + # context requests do not have multiple beams yet, so beam width may differ after sampling + req_num_output_beams_list = [req.get_beam_width_by_iter(True) for req in requests] + req_num_beams_output = torch.tensor( + req_num_output_beams_list, dtype=torch.int32, pin_memory=True + ) req_num_generated_tokens = req_num_generation_steps * req_num_beams + req_num_generated_tokens_output = req_num_generation_steps * req_num_beams_output # NB: These offsets consider generated tokens _only_ (draft and target, but not context). # Filter out the context tokens below. 
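The RAW/PROCESSED split handled above can be condensed into a short reference sketch, simplified to a single request with no batching, indexing, or CUDA bookkeeping; the function and argument names are illustrative, not part of the diff:

import torch
import torch.nn.functional as F

def logprobs_for_mode(logits: torch.Tensor, kept_probs: torch.Tensor,
                      temperature: float | None, mode: str = "raw") -> torch.Tensor:
    # "raw": log-softmax over the unmodified model logits.
    # "processed": tokens masked out by the sampling strategy (kept_probs == 0)
    # are forced to -inf and the remaining logits are divided by the temperature
    # before the log-softmax, mirroring the group-wise handling in
    # _sample_batched_by_strategy / _process_logprobs.
    if mode == "processed":
        logits = torch.where(kept_probs > 0, logits, float("-inf"))
        if temperature is not None:
            logits = logits / temperature
    return F.log_softmax(logits.float(), dim=-1)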
req_offsets, sum_num_generated_tokens = _PackedStepIndexer.calculate_request_offsets( @@ -2343,6 +2659,7 @@ def _select_generated_logits( sampling_requests_metadata = SamplingRequestsMetadata( req_num_generated_tokens=req_num_generated_tokens, + req_num_generated_tokens_output=req_num_generated_tokens_output, req_num_beams=req_num_beams, req_num_steps=req_num_generation_steps, req_offsets=req_offsets, @@ -2666,6 +2983,162 @@ def _are_stop_words( return per_step + @nvtx_range("_process_logprobs") + def _process_logprobs( + self, + batched_sampling_result: _BatchedSamplingResult, + seq_slots: torch.Tensor, + requests: list[LlmRequest], + req_num_steps: torch.Tensor, + req_num_generated_tokens: torch.Tensor, + ): + assert batched_sampling_result.batch_logits_for_logprobs_cuda is not None, ( + "batch_logits_for_logprobs_cuda must be a Tensor for _process_logprobs" + ) + + all_req_indices = batched_sampling_result.batch_req_indices.tolist() + # The request indices in the shuffled batch after grouping (NB: Beam search request are handled separately) + local_group_req_indices = torch.tensor( + [ + req_id + for req_id, req_gid in enumerate(all_req_indices) + if requests[req_gid].py_num_logprobs is not None + and requests[req_gid].sampling_config.beam_width == 1 + ], + dtype=torch.int32, + ) + # Index the positions of each token in the padded 2d tensors + # NB: Using all_req_indices to allow reuse for beam search requests + padded_indexer = _PackedStepIndexer( + num_steps=req_num_generated_tokens[batched_sampling_result.batch_req_indices], + max_steps=cast(int, req_num_generated_tokens.max().item()), + req_offsets=seq_slots[batched_sampling_result.batch_req_indices] + * self.max_tokens + * self.max_beam_width, # NB: Currently either max_tokens or max_beam_width is 1 + ) + # indexer for shuffled logits after grouping + logits_cuda_indexer = _PackedStepIndexer( + num_steps=req_num_steps[batched_sampling_result.batch_req_indices], + max_steps=cast(int, req_num_steps.max().item()), + ) + + any_request_without_beam_search = local_group_req_indices.shape[0] > 0 + + if any_request_without_beam_search: + assert self.store.sampled_log_probs is not None, "sampled_log_probs must be provided" + assert self.store.sampled_log_prob_indices is not None, ( + "sampled_log_prob_indices must be provided" + ) + assert self.store.sampled_log_prob_ranks is not None, ( + "sampled_log_prob_ranks must be provided" + ) + # NB: Already begin copy here, to overlap with the remaining host code + padded_indices_cuda = padded_indexer[local_group_req_indices].to( + device=self.store.sampled_log_probs.device, non_blocking=True + ) + + # get indices of the logits after grouping + group_logits_indices_cuda = logits_cuda_indexer[local_group_req_indices].to( + device=batched_sampling_result.batch_logits_for_logprobs_cuda.device, + non_blocking=True, + ) + + # (batch_size, vocab_size) + group_logprobs_cuda = F.log_softmax( + batched_sampling_result.batch_logits_for_logprobs_cuda[group_logits_indices_cuda], + dim=-1, + ) + + # Process the topk logprobs + if self.batch_max_topk_logprobs > 0: + assert self.store.topk_vals is not None, "topk_vals must be provided" + assert self.store.topk_indices is not None, "topk_indices must be provided" + # Get the topk logprobs + # The request indices in the batch before grouping + group_req_indices = batched_sampling_result.batch_req_indices[ + local_group_req_indices + ] + topk_vals_cuda, topk_indices_cuda = torch.topk( + group_logprobs_cuda, + k=max(requests[req_id].py_num_logprobs for req_id in 
group_req_indices), + dim=-1, + ) + expanded_indices_cuda = padded_indices_cuda.view(-1, 1).expand( + -1, topk_vals_cuda.shape[-1] + ) + self.store.topk_vals[..., : self.batch_max_topk_logprobs].view( + self.max_num_sequences * self.max_tokens, self.batch_max_topk_logprobs + ).scatter_(dim=0, index=expanded_indices_cuda, src=topk_vals_cuda) + self.store.topk_indices[..., : self.batch_max_topk_logprobs].view( + self.max_num_sequences * self.max_tokens, self.batch_max_topk_logprobs + ).scatter_( + dim=0, index=expanded_indices_cuda, src=topk_indices_cuda.to(torch.int32) + ) + + # Process the sampled logprobs + # (batch_size, max_beam_width) + group_next_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int[ + group_logits_indices_cuda + ][:, :1] + # Get the sampled logprobs + sampled_vals_cuda = torch.gather( + group_logprobs_cuda, dim=-1, index=group_next_tokens_cuda.view(-1, 1) + ) + # Get the sampled logprobs indices + sampled_indices_cuda = group_next_tokens_cuda.squeeze(1) + + # NB: group_logprobs_cuda is not needed anymore and the storage can be safely reused. + # sampled_rank_cuda contains the 0-based rank, it will be corrected to 1-based in handle_logprobs + group_logprobs_cuda.greater_(sampled_vals_cuda) + sampled_rank_cuda = group_logprobs_cuda.sum(dim=-1).to(torch.int32) + + sampled_vals_cuda = sampled_vals_cuda.squeeze(1) + + self.store.sampled_log_prob_indices.view( + self.max_num_sequences * self.max_tokens * self.max_beam_width + ).scatter_(dim=0, index=padded_indices_cuda, src=sampled_indices_cuda) + self.store.sampled_log_probs.view( + self.max_num_sequences * self.max_tokens * self.max_beam_width + ).scatter_(dim=0, index=padded_indices_cuda, src=sampled_vals_cuda) + self.store.sampled_log_prob_ranks.view( + self.max_num_sequences * self.max_tokens * self.max_beam_width + ).scatter_(dim=0, index=padded_indices_cuda, src=sampled_rank_cuda) + + if self._use_beam_search: + local_group_req_indices_with_beam_search = torch.tensor( + [ + req_id + for req_id, req_gid in enumerate(all_req_indices) + if requests[req_gid].py_num_logprobs is not None + and requests[req_gid].sampling_config.beam_width > 1 + ], + dtype=torch.int32, + ) + any_request_has_beam_search = local_group_req_indices_with_beam_search.shape[0] > 0 + if any_request_has_beam_search: + group_logits_indices_with_beam_search = logits_cuda_indexer[ + local_group_req_indices_with_beam_search + ] + group_logits_indices_with_beam_search_cuda = ( + group_logits_indices_with_beam_search.to( + device=batched_sampling_result.batch_next_tokens_cuda_int.device, + non_blocking=True, + ) + ) + group_next_tokens_with_beam_search_cuda = ( + batched_sampling_result.batch_next_tokens_cuda_int[ + group_logits_indices_with_beam_search_cuda + ].view(-1) + ) + padded_indices_with_beam_search_cuda = padded_indexer[ + local_group_req_indices_with_beam_search + ].to(device=self.store.sampled_log_prob_indices.device, non_blocking=True) + self.store.sampled_log_prob_indices.view(-1).scatter_( + dim=0, + index=padded_indices_with_beam_search_cuda, + src=group_next_tokens_with_beam_search_cuda, + ) + @nvtx_range("_process_requests") def _process_requests( self, @@ -2689,6 +3162,9 @@ def _process_requests( raw_logits_cuda, num_context_logits_prefix_sum=num_context_logits_prefix_sum, ) + return_log_probs = self._return_log_probs(requests) + if return_log_probs: + self._prepare_log_probs(requests) # Handle embedding bias self._apply_embedding_bias(logits_cuda, requests, sampling_requests_metadata.req_num_steps) @@ -2739,13 +3215,6 @@ def 
_process_requests( req_offsets=sampling_requests_metadata.req_offsets, ) - self._handle_log_probs( - requests, - logits_cuda, - logits_cuda_indexer=logits_cuda_indexer, - req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, - ) - # Perform sampling in batches batched_sampling_result = self._sample_batched_by_strategy( logits_cuda, @@ -2759,8 +3228,18 @@ def _process_requests( req_num_generated_tokens=sampling_requests_metadata.req_num_generated_tokens, req_num_steps=sampling_requests_metadata.req_num_steps, token_dtype=new_tokens_cuda.dtype, + return_log_probs=return_log_probs, ) + if return_log_probs: + self._process_logprobs( + batched_sampling_result, + seq_slots, + requests, + sampling_requests_metadata.req_num_steps, + sampling_requests_metadata.req_num_generated_tokens_output, + ) + # Fill results into output buffers new_tokens_host = self._unbatch_sampling_results( batched_sampling_result, diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index b2c660fea73..f2156c07aad 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -266,7 +266,8 @@ def greedy_search_sampling_batch( next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None if return_probs: - softmax = torch.softmax(logits, dim=-1) + softmax = torch.zeros_like(logits) + softmax.scatter_(1, next_tokens.unsqueeze(-1), 1.0) return next_tokens, softmax @@ -471,10 +472,10 @@ def sample( strategy: Strategy, logits: torch.Tensor, *, - generator: Optional[torch.Generator] = None, + generator: torch.Generator | None = None, group_metadata: StrategyMetadata | None = None, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor | None, float | None]: match strategy: case ("top_k", top_k, temperature): tokens, softmax = top_k_sampling_batch( @@ -506,6 +507,7 @@ def sample( ) case ("greedy", None): tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs) + temperature = None case ("beam_search", beam_width_in, beam_width_out, temperature): assert group_metadata is not None and isinstance(group_metadata, BeamSearchMetadata), ( "BeamSearchMetadata is required for beam_search_sampling_batch" @@ -519,7 +521,7 @@ def sample( generator=generator, return_probs=return_probs, ) - return tokens, softmax + return tokens, softmax, temperature GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") @@ -545,11 +547,11 @@ def sample_grouped_strategies( strategies: list[Strategy], logits: torch.Tensor, *, - group_logit_indices: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, + group_logit_indices: torch.Tensor | None = None, + generator: torch.Generator | None = None, return_probs: bool, group_metadata: StrategyMetadata | None = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, float | torch.Tensor | None]: raise NotImplementedError @@ -579,11 +581,11 @@ def sample_grouped_strategies( strategies: list[Strategy], logits: torch.Tensor, *, - group_logit_indices: Optional[torch.Tensor] = None, - generator: Optional[torch.Generator] = None, + group_logit_indices: torch.Tensor | None = None, + generator: torch.Generator | None = None, return_probs: bool, group_metadata: StrategyMetadata | None = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, torch.Tensor | None, float | None]: 
if group_key[0] == "beam_search": beam_width_in = group_key[1] else: diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py index 786c953b0fd..4114c2310f0 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py @@ -141,8 +141,9 @@ def _sample_greedy_with_probs( *, group_logit_indices: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None) - new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False) + if group_logit_indices is not None: + logits = torch.index_select(logits, 0, group_logit_indices) # ensures copy + new_tokens, probs = greedy_search_sampling_batch(logits, return_probs=True) return new_tokens, probs @classmethod @@ -240,6 +241,9 @@ def computes_probs(cls) -> bool: return True class GreedyWithProbs(StrategyImplWithProbs): + def __init__(self): + self._temperature = None + @override @classmethod def from_strategies( @@ -425,6 +429,9 @@ def computes_probs(cls) -> bool: return False class GreedySampleOnly(StrategyImplSampleOnly): + def __init__(self): + self._temperature = None + @override @classmethod def from_strategies( @@ -722,7 +729,7 @@ def sample_grouped_strategies( generator: Optional[torch.Generator] = None, return_probs: bool, group_metadata: StrategyMetadata | None = None, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: if hasattr(group_key, "static_beam_width_in"): beam_width_in = group_key.static_beam_width_in else: @@ -735,9 +742,16 @@ def sample_grouped_strategies( assert return_probs == group_key.computes_probs() strategy_impl_cls = group_key - return strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device).sample( + sampling_object = strategy_impl_cls.from_strategies(strategies, cuda_device=logits.device) + next_tokens, softmax = sampling_object.sample( logits, group_logit_indices=group_logit_indices, generator=generator, group_metadata=group_metadata, ) + temperature = ( + sampling_object._temperature.unsqueeze(-1) + if sampling_object._temperature is not None + else None + ) + return next_tokens, softmax, temperature diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 8d037275a04..10067b8b284 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -215,7 +215,7 @@ class MTPSampler(TorchSampler): SampleState = SampleStateMTP - @dataclass(frozen=True, kw_only=True) + @dataclass(kw_only=True) class Store(TorchSampler.Store): new_tokens: torch.Tensor next_new_tokens: torch.Tensor diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 00f04a1d0a1..df796e7751e 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -561,6 +561,7 @@ def _deduce_max_tokens(request: GenerationRequest, cache_salt_id=request.cache_salt_id) executor_request.py_num_logprobs = request.sampling_params.logprobs executor_request.py_lora_path = py_lora_path + executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode if self._is_pytorch_backend and request.multimodal_params is not None: if request.multimodal_params.multimodal_data is not None: diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 
4c15e657c10..1e4ccc7e915 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -221,7 +221,7 @@ def _get_logprob_params( self, request: GenerationRequest) -> Optional[LogprobParams]: """Store logprobs-related fields from request for the later logprob calculation.""" logprob_params = None - if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs: + if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs: logprob_params = LogprobParams( logprobs=request.sampling_params.logprobs, prompt_logprobs=request.sampling_params.prompt_logprobs, diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 8d33d94a7f7..26e0a3c5fa9 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -933,6 +933,21 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, logits = logits[:len(tokens)] logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1) + + # only return sampled token + if top_k == 0: + results: TokenLogprobs = [] + if tokens is not None: + for t in range(logprobs.size(0)): + token_id = tokens[t] + token_logprob = logprobs[t, token_id].item() + rank = (logprobs[t] > token_logprob).sum().item() + 1 + token_dict = { + token_id: Logprob(logprob=token_logprob, rank=rank) + } + results.append(token_dict) + return results + topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1) results: TokenLogprobs = [] @@ -961,7 +976,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, None) if k_prompt_logprobs and context_logits is not None else None generation_logprobs = _topk_logprobs( generation_logits, k_logprobs, output_token_ids - ) if k_logprobs and generation_logits is not None else None + ) if k_logprobs is not None and generation_logits is not None else None return LogProbsResult(prompt=prompt_logprobs, generation=generation_logprobs) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 6d3410bf3c2..8a4ee0c8b82 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -666,7 +666,7 @@ def _prepare_sampling_params( if sampling_params.prompt_logprobs and not sampling_params.return_context_logits: sampling_params.return_context_logits = True sampling_params._context_logits_auto_enabled = True - if sampling_params.logprobs and not sampling_params.return_generation_logits: + if sampling_params.logprobs is not None and not sampling_params.return_generation_logits: sampling_params.return_generation_logits = True sampling_params._generation_logits_auto_enabled = True @@ -737,7 +737,7 @@ def _check_arguments(self, prompt_len: int, query_len: int, f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))." 
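The rank computation used in _topk_logprobs above (and, in 0-based form, in _process_logprobs via greater_().sum()) counts how many vocabulary entries have a strictly larger log-probability. A standalone sketch:

import torch
import torch.nn.functional as F

def sampled_token_rank(logits_row: torch.Tensor, token_id: int) -> int:
    # 1-based rank: 1 + number of tokens with strictly greater log-probability.
    logprobs = F.log_softmax(logits_row.float(), dim=-1)
    return int((logprobs > logprobs[token_id]).sum().item()) + 1

# The argmax token always has rank 1:
row = torch.tensor([0.1, 2.0, -1.0])
assert sampled_token_rank(row, 1) == 1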
) - if sampling_params.logprobs and not self.args.gather_generation_logits: + if sampling_params.logprobs is not None and not self.args.gather_generation_logits: raise ValueError( f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` " f"to be passed explicitly to the `LLM()` constructor.") diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 57bebba45ed..08de95e1958 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -6,6 +6,7 @@ import torch from pydantic import BaseModel +from strenum import StrEnum from tensorrt_llm.bindings import executor as tllme from tensorrt_llm.logger import logger @@ -46,6 +47,18 @@ class LogprobParams(NamedTuple): drop_generation_logits: bool = False +class LogprobMode(StrEnum): + RAW = "raw" + """ + Return the raw log probabilities, i.e., the log probabilities calculated directly from the model output logits. + """ + PROCESSED = "processed" + """ + Return the processed log probabilities, i.e., the log probabilities after applying sampling parameters, + such as temperature, top-k, top-p, etc. + """ + + class LogitsProcessor(ABC): """Base class for logits processor. @@ -172,7 +185,9 @@ class SamplingParams: min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None. beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None. - logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. + logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability. + When set to K>0, return top-K log probabilities + the sampled token's log probability (last entry) if it's not in the Top-K. Defaults to None. + logprobs_mode (LogprobMode, optional): The mode of log probabilities to return. Defaults to RAW. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. @@ -219,6 +234,7 @@ class SamplingParams: n: int = 1 best_of: Optional[int] = None use_beam_search: bool = False + logprobs_mode: LogprobMode = LogprobMode.RAW # Keep the below fields in sync with tllme.SamplingConfig or maintin the mapping table. top_k: Optional[int] = None @@ -321,6 +337,8 @@ def _validate(self): f"under the greedy decoding." 
) + self.logprobs_mode = LogprobMode(self.logprobs_mode) + if self.truncate_prompt_tokens is not None and self.truncate_prompt_tokens < 1: raise ValueError( f"truncate_prompt_tokens must be >= 1, got {self.truncate_prompt_tokens}" @@ -329,8 +347,11 @@ def _validate(self): if self.guided_decoding is not None: self.guided_decoding._validate() - # correct types as users might pass in logprob=True for Top-1 logprobs - self.logprobs = self.logprobs and int(self.logprobs) + # correct types as users might pass in logprob=True for Top-0 logprobs and logprobs=False for no logprobs + if self.logprobs is False: + self.logprobs = None + if self.logprobs is True: + self.logprobs = 0 self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs) # NB: Static, because downstream code only holds instances of @@ -506,7 +527,7 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo config_kwargs = {f: getattr(self, f) for f in fields} if is_pytorch_backend: - config_kwargs["return_log_probs"] = bool(self.logprobs) + config_kwargs["return_log_probs"] = self.logprobs is not None if self.prompt_logprobs and not self.return_context_logits: logger.info( "Since prompt_logprobs is requested but return_context_logits is False, " diff --git a/tests/integration/test_lists/test-db/l0_a30.yml b/tests/integration/test_lists/test-db/l0_a30.yml index 1a6b95fbb60..f322301eb69 100644 --- a/tests/integration/test_lists/test-db/l0_a30.yml +++ b/tests/integration/test_lists/test-db/l0_a30.yml @@ -22,7 +22,7 @@ l0_a30: - unittest/_torch/modeling -k "modeling_starcoder2" - unittest/_torch/auto_deploy/unit/singlegpu - unittest/_torch/sampler/test_beam_search.py - - unittest/_torch/sampler/test_return_logits.py + - unittest/_torch/sampler/test_logits_logprobs.py - test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler] - test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler] - test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler] diff --git a/tests/unittest/_torch/sampler/test_beam_search.py b/tests/unittest/_torch/sampler/test_beam_search.py index 5a8d0fe248f..f4b1b09da37 100644 --- a/tests/unittest/_torch/sampler/test_beam_search.py +++ b/tests/unittest/_torch/sampler/test_beam_search.py @@ -125,7 +125,7 @@ def check_generation_logits(beam: CompletionOutput, def check_logprobs(beam: CompletionOutput, sampling_params: SamplingParams, valid_tokens: int | None) -> None: """Check if the logprobs have the correct shape""" - if sampling_params.logprobs: + if sampling_params.logprobs is not None: generated_tokens = valid_tokens if valid_tokens is not None else sampling_params.max_tokens assert len( beam.logprobs @@ -345,7 +345,7 @@ class GeneralTestParams: prompt_len = len(input_tokens) num_generated_tokens = 5 seq_len = prompt_len + num_generated_tokens - num_logprobs = 1 + num_logprobs = 0 seq_slot = 4 end_id = 99 batch_size = 2 @@ -541,7 +541,7 @@ def create_default_request(test_params: GeneralTestParams) -> LlmRequest: end_id=test_params.end_id, sampling_config=SamplingConfig( sampling_params._get_sampling_config()), - return_log_probs=test_params.num_logprobs > 0, + return_log_probs=test_params.num_logprobs >= 0, num_logprobs=test_params.num_logprobs, is_streaming=False) @@ -590,7 +590,7 @@ def test_create_beam_history(): num_generated_tokens = test_params.num_generated_tokens seq_slot = test_params.seq_slot vocab_size = test_params.vocab_size - num_logprobs = test_params.num_logprobs + num_logprobs = test_params.num_logprobs + 1 cache_indirection = 
sampler.store.cache_indirection original_tokens = sampler.store.original_tokens original_logprobs = torch.zeros( @@ -635,7 +635,11 @@ def test_create_beam_history(): # set the logprobs in the request: token_logprobs = sampler._convert_logprobs_tensor_to_list( original_logprob_indices[:beam_width, :num_generated_tokens - 1], - original_logprobs[:beam_width, :num_generated_tokens - 1]) + original_logprobs[:beam_width, :num_generated_tokens - 1], + None, + None, + None, + ) request.py_result.set_log_probs( token_logprobs, cum_log_probs=torch.zeros_like( @@ -657,9 +661,10 @@ def test_create_beam_history(): ) > 0, "Deterministic offsets must not only contain zeros. Otherwise change the seed." # set the new log probs and tokens for the beam search sampling - sampler.store.new_log_probs[ + sampler.store.sampled_log_probs[ seq_slot, :beam_width] = original_logprobs[:beam_width, - num_generated_tokens - 1, 0] + num_generated_tokens - 1, + 0:1] sampler.store.new_tokens[ 0, seq_slot, :beam_width] = original_logprob_indices[:beam_width, diff --git a/tests/unittest/_torch/sampler/test_logits_logprobs.py b/tests/unittest/_torch/sampler/test_logits_logprobs.py new file mode 100644 index 00000000000..bd85b13c618 --- /dev/null +++ b/tests/unittest/_torch/sampler/test_logits_logprobs.py @@ -0,0 +1,589 @@ +import os + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils.llm_data import llm_models_root +from utils.util import force_ampere + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm._torch.pyexecutor.sampling_utils import top_k_top_p_sampling_batch +from tensorrt_llm._torch.pyexecutor.sampling_utils_flashinfer import _StrategyImpls +from tensorrt_llm.llmapi.llm_utils import KvCacheConfig + +prompts = ["A B C"] +global_kvcache_config = KvCacheConfig( + max_tokens=10000, + enable_block_reuse=True, +) + + +@pytest.fixture(scope="module", params=[False, True]) +def gather_generation_logits_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=[False, True]) +def gather_context_logits_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=[False, True]) +def disable_overlap_scheduler_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"]) +def sampler_type_fixture(request) -> str: + return request.param + + +class CacheSalter: + _salt = 0 + + @classmethod + def get_salt_unique(cls) -> str: + cls._salt += 1 + return str(cls._salt) + + @classmethod + def get_salt_shared(cls) -> str: + return str(0) + + @classmethod + def get_salt(cls, reuse_cache: bool) -> str: + if reuse_cache: + salt = cls.get_salt_shared() + else: + salt = cls.get_salt_unique() + return salt + + +@pytest.fixture(scope="module") +def llm( + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + sampler_type_fixture: str, + disable_overlap_scheduler_fixture: bool, +): + gather_generation_logits = gather_generation_logits_fixture + sampler_type = sampler_type_fixture + disable_overlap_scheduler = disable_overlap_scheduler_fixture + + llm = LLM( + model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"), + kv_cache_config=global_kvcache_config, + gather_generation_logits=gather_generation_logits, + max_batch_size=128, # reduce buffer sizes, specially for generation logits + sampler_type=sampler_type, + disable_overlap_scheduler=disable_overlap_scheduler, + ) + + # FIXME: Sometimes 
LLM shutdown hangs, might be related to https://nvbugs/5577178. + # Remove patch below once fixed. + old_exit = LLM.__exit__ + + def _exit_with_xfail_on_timeout(self, exc_type, exc_value, traceback) -> bool: + import _pytest.outcomes + + try: + return old_exit(self, exc_type, exc_value, traceback) + except _pytest.outcomes.Failed as e: + if e.msg and "pytest-timeout" in e.msg.lower(): + pytest.xfail("Known LLM shutdown issue (https://nvbugs/5577178).") + else: + raise + + with pytest.MonkeyPatch.context() as patch: + patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout) + + with llm: + yield llm + + +@pytest.fixture(scope="module", params=[False, True]) +def simple_llm(request) -> LLM: + disable_flashinfer_sampling = request.param + llm = LLM( + model=os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0"), + max_batch_size=8, + disable_flashinfer_sampling=disable_flashinfer_sampling, + ) + return llm + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("reuse_cache", [False, True]) +@pytest.mark.parametrize("return_log_probs", [False, True]) +# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 +# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 +@pytest.mark.timeout(120, method="signal") +@pytest.mark.threadleak(enabled=False) +def test_generate_with_return_logits( + llm, + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + reuse_cache: bool, + return_log_probs: bool, +): + gather_context_logits = gather_context_logits_fixture + gather_generation_logits = gather_generation_logits_fixture + + if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space + pytest.skip("Nothing to test") + + sampling_params = SamplingParams( + max_tokens=8, + return_context_logits=gather_context_logits, + return_generation_logits=gather_generation_logits, + logprobs=return_log_probs, + ) + + for output in llm.generate( + prompts, + sampling_params=sampling_params, + cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts], + ): + if gather_context_logits: + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + try: + assert expected_len == output.context_logits.shape[0] + except AssertionError: + # FIXME: Remove this once the bug has been fixed + if gather_context_logits and reuse_cache: + pytest.xfail("Known bug: https://nvbugs/5577178") + raise + else: + assert output.context_logits is None + + for sequence in output.outputs: + assert sequence.length == sampling_params.max_tokens + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == sampling_params.max_tokens + assert torch.argmax(gen_logits, dim=1).tolist() == sequence.token_ids + else: + assert sequence.generation_logits is None + + if return_log_probs: + assert len(sequence.logprobs) == sampling_params.max_tokens + else: + assert len(sequence.logprobs) == 0 + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("reuse_cache", [False, True]) +@pytest.mark.parametrize("return_log_probs", [False, True]) +# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 +# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 +@pytest.mark.timeout(120, method="signal") +@pytest.mark.threadleak(enabled=False) +def 
test_generate_async_with_return_logits( + llm, + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + reuse_cache: bool, + return_log_probs: bool, +): + gather_context_logits = gather_context_logits_fixture + gather_generation_logits = gather_generation_logits_fixture + + if not (gather_context_logits or gather_generation_logits or return_log_probs): # prune space + pytest.skip("Nothing to test") + + sampling_params = SamplingParams( + max_tokens=8, + return_context_logits=gather_context_logits, + return_generation_logits=gather_generation_logits, + logprobs=return_log_probs, + ) + + for idx, output in enumerate( + llm.generate_async( + prompts[0], + sampling_params=sampling_params, + streaming=True, + cache_salt=CacheSalter.get_salt(reuse_cache), + ) + ): + if gather_context_logits: + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + try: + assert expected_len == output.context_logits.shape[0] + except AssertionError: + # FIXME: Remove this once the bug has been fixed + if gather_context_logits and reuse_cache: + pytest.xfail("Known bug: https://nvbugs/5577178") + raise + else: + assert output.context_logits is None + + for sequence in output.outputs: + assert sequence.length == idx + 1 + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == 1 + try: + assert torch.argmax(gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1] + except AssertionError: + # FIXME: Remove xfail once the bug is fixed + pytest.xfail("Known bug: https://nvbugs/5573238") + else: + assert sequence.generation_logits is None + + if return_log_probs: + assert len(sequence.logprobs) == idx + 1 + else: + assert len(sequence.logprobs) == 0 + + +@pytest.mark.parametrize("logprobs_k", [0, 1, 3], ids=["top_0", "top_1", "top_3"]) +@pytest.mark.parametrize("logprobs_mode", ["raw", "processed"]) +@pytest.mark.threadleak(enabled=False) +def test_sampled_token_always_in_logprobs(logprobs_k: int, logprobs_mode: str, simple_llm: LLM): + """Two scenarios: + - logprobs=0: Returns only sampled token (1 element) + - logprobs=K (K>0): Returns top-K tokens + sampled token if not in top-K (up to K+1 elements) + """ + + sampling_params = SamplingParams( + max_tokens=8, + temperature=0.7, + top_p=0.9, + logprobs=logprobs_k, + logprobs_mode=logprobs_mode, + ) + + for output in simple_llm.generate(["The future of AI is"], sampling_params=sampling_params): + print(f"\n{'=' * 80}") + print(f"Generated text: {output.outputs[0].text!r}") + print(f"Generated token IDs: {output.outputs[0].token_ids}") + + logprobs = output.outputs[0].logprobs + token_ids = output.outputs[0].token_ids + + assert len(logprobs) == sampling_params.max_tokens, ( + f"Expected {sampling_params.max_tokens} logprob entries, got {len(logprobs)}" + ) + + for token_idx, (sampled_token_id, token_logprobs) in enumerate(zip(token_ids, logprobs)): + print( + f"\n Token {token_idx}: " + f"ID={sampled_token_id}, " + f"Text={simple_llm.tokenizer.decode([sampled_token_id])!r}" + ) + + assert sampled_token_id in token_logprobs, ( + f"Token {token_idx}: Sampled token ID {sampled_token_id} not in logprobs dict: {token_logprobs.keys()}" + ) + + if logprobs_k == 0: + assert len(token_logprobs) == 1, ( + f"Token {token_idx}: Expected 1 logprob (sampled only), got {len(token_logprobs)}" + ) + else: + assert len(token_logprobs) <= logprobs_k + 1, ( 
+ f"Token {token_idx}: Expected at most {logprobs_k + 1} logprobs, got {len(token_logprobs)}" + ) + assert len(token_logprobs) >= 1 + + sorted_tokens_by_prob = sorted( + token_logprobs.items(), key=lambda x: x[1].logprob, reverse=True + ) + + if logprobs_k > 0: + sampled_token_rank = token_logprobs[sampled_token_id].rank + sampled_in_topk = sampled_token_rank <= logprobs_k + + if not sampled_in_topk: + assert sorted_tokens_by_prob[-1][0] == sampled_token_id, ( + f"Token {token_idx}: Sampled token (ID={sampled_token_id}, rank={sampled_token_rank}) " + f"not in top-{logprobs_k}, should be last in sorted list, " + f"but last token is ID={sorted_tokens_by_prob[-1][0]}" + ) + + for rank_idx, (token_id, logprob_obj) in enumerate(sorted_tokens_by_prob, start=1): + token_text = simple_llm.tokenizer.decode([token_id]) + is_sampled = "← SAMPLED" if token_id == sampled_token_id else "" + print( + f" • Token {token_id:5d} ({token_text:15s}): " + f"logprob={logprob_obj.logprob:8.4f}, " + f"rank={logprob_obj.rank} {is_sampled}" + ) + + if logprobs_k > 0 and sampled_in_topk: + assert logprob_obj.rank == rank_idx, ( + f"Token {token_idx}: Token {token_id} rank mismatch. " + f"Expected rank {rank_idx} (by sorted position), got {logprob_obj.rank}" + ) + + print(f"{'=' * 80}\n") + + +@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"]) +@pytest.mark.threadleak(enabled=False) +def test_logprobs_with_grouped_samplings_strategies(logprobs_k: int, simple_llm: LLM): + """Test logprobs when requests are reordered by sampling strategy grouping""" + + test_prompts = [ + "The capital of France is", + "The future of AI is", + "Hello, my name is", + "Hello, my name is", + "Write a short story about a cat", + ] + + # Causes reordering: [0,1,2,3,4] → [0,2,3,1,4] + sampling_params_list = [ + SamplingParams( + max_tokens=6, + temperature=0.8, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True, + ), + SamplingParams( + max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True, + ), + SamplingParams( + max_tokens=6, + temperature=0.8, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True, + ), + SamplingParams( + max_tokens=6, temperature=0.8, top_k=50, logprobs=None, return_generation_logits=True + ), + SamplingParams( + max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True, + ), + ] + + outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list)) + + for req_idx, output in enumerate(outputs): + generation_logits = output.outputs[0].generation_logits.to(device="cuda") + token_ids = output.outputs[0].token_ids + logprobs = output.outputs[0].logprobs + if sampling_params_list[req_idx].logprobs is None: + assert len(logprobs) == 0 + continue + + assert generation_logits is not None + assert len(logprobs) == len(token_ids), "Logprobs length mismatch" + + # generation_logits might be shorter than token_ids + num_logits = len(generation_logits) + + for token_idx, (sampled_token_id, token_logprobs_dict) in enumerate( + zip(token_ids[:num_logits], logprobs[:num_logits]) + ): + returned_logprob = token_logprobs_dict[sampled_token_id].logprob + + logits_for_token = generation_logits[token_idx] + expected_logprobs = torch.nn.functional.log_softmax(logits_for_token, dim=-1).to( + device="cpu" + ) + expected_logprob = expected_logprobs[sampled_token_id].item() + print( + f"Req {req_idx}, Token {token_idx}: returned={returned_logprob:.6f}, expected={expected_logprob:.6f}" + ) 
+ torch.testing.assert_close(returned_logprob, expected_logprob) + + +@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"]) +@pytest.mark.threadleak(enabled=False) +def test_processed_logprobs_e2e(logprobs_k: int, simple_llm: LLM): + """Check processed-mode logprobs end-to-end: compare returned logprobs against a reference log-softmax over temperature/top-k/top-p-masked logits.""" + test_prompts = [ + "The capital of France is", + "The future of AI is", + "Hello, my name is", + "Write a short story about a cat", + "Hello, my name is", + "Write a short story about a cat", + ] + + sampling_params_list = [ + # greedy decoding + SamplingParams( + max_tokens=6, + temperature=0.0, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + # temperature sampling + SamplingParams( + max_tokens=6, + temperature=0.8, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + # top-p sampling + SamplingParams( + max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + # top-k sampling + SamplingParams( + max_tokens=6, + temperature=0.8, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + # top-p sampling 2 + SamplingParams( + max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + # top-p and top-k sampling + SamplingParams( + max_tokens=6, + temperature=0.8, + top_p=0.9, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True, + logprobs_mode="processed", + ), + ] + + outputs = list(simple_llm.generate(test_prompts, sampling_params=sampling_params_list)) + + for req_idx, output in enumerate(outputs): + generation_logits = output.outputs[0].generation_logits.to(device="cuda") + token_ids = output.outputs[0].token_ids + logprobs = output.outputs[0].logprobs + + assert generation_logits is not None + assert len(logprobs) == len(token_ids), "Logprobs length mismatch" + + # generation_logits might be shorter than token_ids + num_logits = len(generation_logits) + + for token_idx, token_logprobs_dict in enumerate(logprobs[:num_logits]): + assert token_ids[token_idx] in token_logprobs_dict, "Sampled token not in logprobs" + + logits_for_token = generation_logits[token_idx : token_idx + 1] + topk = sampling_params_list[req_idx].top_k + topp = sampling_params_list[req_idx].top_p + temperature = sampling_params_list[req_idx].temperature + if sampling_params_list[req_idx]._greedy_decoding: + probs = torch.zeros_like(logits_for_token) + probs[0, token_ids[token_idx]] = 1.0 + else: + topk = topk if topk is not None else logits_for_token.shape[-1] + topp = topp if topp is not None else 1.0 + temperature = temperature if temperature is not None else 1.0 + + # apply top-k/top-p masking to obtain the processed probabilities + if simple_llm.args.disable_flashinfer_sampling: + _, probs = top_k_top_p_sampling_batch( + logits_for_token, top_k=topk, top_p=topp, temperature=temperature + ) + else: + _, probs = _StrategyImpls.StrategyImplWithProbs._sample_with_probs( + logits_for_token, + group_logit_indices=None, + top_k=torch.tensor([topk], dtype=torch.int32, device="cuda"), + top_p=torch.tensor([topp], dtype=torch.float32, device="cuda"), + temperature=torch.tensor([temperature], dtype=torch.float32, device="cuda"), + generator=None, + ) + + if temperature != 0: + logits_for_token /= temperature + adjusted_logits_for_token = torch.where(probs != 0, logits_for_token, float("-inf"))[0] + expected_logprobs =
torch.nn.functional.log_softmax( + adjusted_logits_for_token, dim=-1 + ).to(device="cpu") + for logprob_token, logprob_values in token_logprobs_dict.items(): + expected_logprob = expected_logprobs[logprob_token].item() + returned_logprob = logprob_values.logprob + print( + f"Req {req_idx}, Token {token_idx}: " + f"returned={returned_logprob:.6f}, expected={expected_logprob:.6f}" + ) + torch.testing.assert_close(returned_logprob, expected_logprob) + + +@force_ampere +@pytest.mark.gpu2 +def test_logprobs_match_hf_tp2(): + model_path = os.path.join(llm_models_root(), "llama-models-v2", "TinyLlama-1.1B-Chat-v1.0") + llm = LLM( + model=model_path, + tensor_parallel_size=2, + ) + + prompts = ["The future of the AI is"] + + sampling_params = SamplingParams( + max_tokens=10, + temperature=1.0, + logprobs=0, + ) + + hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to( + "cuda" + ) + hf_tokenizer = AutoTokenizer.from_pretrained(model_path) + + output = list(llm.generate(prompts, sampling_params=sampling_params))[0] + + trtllm_token_ids = output.outputs[0].token_ids + trtllm_logprobs = torch.tensor( + [list(lp.values())[0].logprob for lp in output.outputs[0].logprobs] + ) + + base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda") + hf_logprobs = [] + + for i, token_id in enumerate(trtllm_token_ids): + if i > 0: + prev_tokens = torch.tensor([trtllm_token_ids[:i]], device="cuda") + input_ids = torch.cat([base_ids, prev_tokens], dim=1) + else: + input_ids = base_ids + with torch.no_grad(): + logits = hf_model(input_ids).logits[0, -1, :] + hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item()) + + hf_logprobs = torch.tensor(hf_logprobs) + + print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}") + print(f"HuggingFace logprobs: {hf_logprobs}") + print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}") + + torch.testing.assert_close(trtllm_logprobs, hf_logprobs, atol=0.15, rtol=0) diff --git a/tests/unittest/_torch/sampler/test_return_logits.py b/tests/unittest/_torch/sampler/test_return_logits.py deleted file mode 100644 index 9552459f516..00000000000 --- a/tests/unittest/_torch/sampler/test_return_logits.py +++ /dev/null @@ -1,239 +0,0 @@ -import os - -import pytest -import torch -from utils.llm_data import llm_models_root -from utils.util import force_ampere - -from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi.llm_utils import KvCacheConfig - -prompts = ["A B C"] -global_kvcache_config = KvCacheConfig( - max_tokens=10000, - enable_block_reuse=True, -) - - -@pytest.fixture(scope="module", params=[False, True]) -def gather_generation_logits_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=[False, True]) -def gather_context_logits_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=[False, True]) -def disable_overlap_scheduler_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"]) -def sampler_type_fixture(request) -> str: - return request.param - - -class CacheSalter: - - _salt = 0 - - @classmethod - def get_salt_unique(cls) -> str: - cls._salt += 1 - return str(cls._salt) - - @classmethod - def get_salt_shared(cls) -> str: - return str(0) - - @classmethod - def get_salt(cls, reuse_cache: bool) -> str: - if reuse_cache: - salt = cls.get_salt_shared() - else: - salt = cls.get_salt_unique() - return salt - - -@pytest.fixture(scope="module") -def llm( - 
gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - sampler_type_fixture: str, - disable_overlap_scheduler_fixture: bool, -): - gather_generation_logits = gather_generation_logits_fixture - sampler_type = sampler_type_fixture - disable_overlap_scheduler = disable_overlap_scheduler_fixture - - llm = LLM( - model=os.path.join(llm_models_root(), "llama-models-v2", - "TinyLlama-1.1B-Chat-v1.0"), - kv_cache_config=global_kvcache_config, - gather_generation_logits=gather_generation_logits, - max_batch_size= - 128, # reduce buffer sizes, specially for generation logits - sampler_type=sampler_type, - disable_overlap_scheduler=disable_overlap_scheduler, - ) - - # FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178. - # Remove patch below once fixed. - old_exit = LLM.__exit__ - - def _exit_with_xfail_on_timeout(self, exc_type, exc_value, - traceback) -> bool: - import _pytest.outcomes - try: - return old_exit(self, exc_type, exc_value, traceback) - except _pytest.outcomes.Failed as e: - if e.msg and "pytest-timeout" in e.msg.lower(): - pytest.xfail( - "Known LLM shutdown issue (https://nvbugs/5577178).") - else: - raise - - with pytest.MonkeyPatch.context() as patch: - patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout) - - with llm: - yield llm - - -@force_ampere # Save H100 resource -@pytest.mark.parametrize("reuse_cache", [False, True]) -@pytest.mark.parametrize("return_log_probs", [False, True]) -# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 -# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 -@pytest.mark.timeout(120, method="signal") -@pytest.mark.threadleak(enabled=False) -def test_generate_with_return_logits( - llm, - gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - reuse_cache: bool, - return_log_probs: bool, -): - gather_context_logits = gather_context_logits_fixture - gather_generation_logits = gather_generation_logits_fixture - - if not (gather_context_logits or gather_generation_logits - or return_log_probs): # prune space - pytest.skip("Nothing to test") - - sampling_params = SamplingParams( - max_tokens=8, - return_context_logits=gather_context_logits, - return_generation_logits=gather_generation_logits, - logprobs=return_log_probs, - ) - - for output in llm.generate( - prompts, - sampling_params=sampling_params, - cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts], - ): - if gather_context_logits: - assert output.context_logits is not None - # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] - expected_len = len(prompts[0].split()) + 1 - try: - assert expected_len == output.context_logits.shape[0] - except AssertionError: - # FIXME: Remove this once the bug has been fixed - if gather_context_logits and reuse_cache: - pytest.xfail("Known bug: https://nvbugs/5577178") - raise - else: - assert output.context_logits is None - - for sequence in output.outputs: - assert sequence.length == sampling_params.max_tokens - - if gather_generation_logits: - gen_logits = sequence.generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == sampling_params.max_tokens - assert torch.argmax(gen_logits, - dim=1).tolist() == sequence.token_ids - else: - assert sequence.generation_logits is None - - if return_log_probs: - assert len(sequence.logprobs) == sampling_params.max_tokens - else: - assert len(sequence.logprobs) == 0 - - -@force_ampere # Save H100 resource 
-@pytest.mark.parametrize("reuse_cache", [False, True]) -@pytest.mark.parametrize("return_log_probs", [False, True]) -# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 -# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 -@pytest.mark.timeout(120, method="signal") -@pytest.mark.threadleak(enabled=False) -def test_generate_async_with_return_logits( - llm, - gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - reuse_cache: bool, - return_log_probs: bool, -): - gather_context_logits = gather_context_logits_fixture - gather_generation_logits = gather_generation_logits_fixture - - if not (gather_context_logits or gather_generation_logits - or return_log_probs): # prune space - pytest.skip("Nothing to test") - - sampling_params = SamplingParams( - max_tokens=8, - return_context_logits=gather_context_logits, - return_generation_logits=gather_generation_logits, - logprobs=return_log_probs) - - for idx, output in enumerate( - llm.generate_async( - prompts[0], - sampling_params=sampling_params, - streaming=True, - cache_salt=CacheSalter.get_salt(reuse_cache), - )): - if gather_context_logits: - assert output.context_logits is not None - # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] - expected_len = len(prompts[0].split()) + 1 - try: - assert expected_len == output.context_logits.shape[0] - except AssertionError: - # FIXME: Remove this once the bug has been fixed - if gather_context_logits and reuse_cache: - pytest.xfail("Known bug: https://nvbugs/5577178") - raise - else: - assert output.context_logits is None - - for sequence in output.outputs: - assert sequence.length == idx + 1 - - if gather_generation_logits: - gen_logits = sequence.generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == 1 - try: - assert torch.argmax( - gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1] - except AssertionError: - # FIXME: Remove xfail once the bug is fixed - pytest.xfail("Known bug: https://nvbugs/5573238") - else: - assert sequence.generation_logits is None - - if return_log_probs: - assert len(sequence.logprobs) == idx + 1 - else: - assert len(sequence.logprobs) == 0 diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index db518e60859..dbfef3f6adb 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -86,7 +86,7 @@ class MockLlmRequest: is_context_init_state: bool # Torch sampler accesses this, but it does not affect this test def get_beam_width_by_iter( - self, for_next_iteration: bool + self, for_next_iteration: bool = False ) -> int: # Torch sampler accesses this, but it does not affect this test return self.sampling_config.beam_width @@ -445,12 +445,22 @@ def __init__(self, is_last_context_chunk: bool, return_context_logits: bool): def py_return_context_logits(self) -> bool: return self._return_context_logits + def get_beam_width_by_iter( + self, for_next_iteration: bool = False + ) -> int: # Torch sampler accesses this, but it does not affect this test + return self.sampling_config.beam_width + class GenRequestMock: def __init__(self, draft_len: int): self.is_context_init_state = False self.py_draft_tokens = torch.empty(draft_len, dtype=torch.int32, device=device) self.sampling_config = SamplingConfig(beam_width=1) + def get_beam_width_by_iter( + self, for_next_iteration: bool = False + ) -> int: 
# Torch sampler accesses this, but it does not affect this test + return self.sampling_config.beam_width + class ScheduledRequestsMock: @property def context_requests(self) -> list[LlmRequest]: diff --git a/tests/unittest/api_stability/references/sampling_params.yaml b/tests/unittest/api_stability/references/sampling_params.yaml index d6b3e6156e3..ba14e81b451 100644 --- a/tests/unittest/api_stability/references/sampling_params.yaml +++ b/tests/unittest/api_stability/references/sampling_params.yaml @@ -15,5 +15,8 @@ methods: prompt_ignore_length: annotation: Optional[int] default: null + logprobs_mode: + annotation: Literal["raw", "processed"] + default: "raw" return_annotation: None properties: {}
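
For readers of this patch, a minimal usage sketch of the new logprobs / logprobs_mode options, mirroring the new tests above. The model path and prompt are illustrative placeholders, not values taken from the patch.

from tensorrt_llm import LLM, SamplingParams

# Placeholder model path; any small causal LM works for illustration.
llm = LLM(model="/path/to/TinyLlama-1.1B-Chat-v1.0")

# logprobs=0   -> only the sampled token's logprob per step.
# logprobs=K>0 -> top-K entries plus the sampled token when it falls outside the top-K.
# logprobs_mode="processed" -> logprobs reflect temperature/top-k/top-p, not the raw logits.
params = SamplingParams(
    max_tokens=8,
    temperature=0.7,
    top_p=0.9,
    logprobs=2,
    logprobs_mode="processed",
)

for out in llm.generate(["The future of AI is"], sampling_params=params):
    seq = out.outputs[0]
    for token_id, step_logprobs in zip(seq.token_ids, seq.logprobs):
        assert token_id in step_logprobs  # the sampled token is always reported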
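
The "processed" mode amounts to taking the log-softmax after the sampling transforms rather than before them. A rough reference computation under simplifying assumptions (single logits row, top-k-only masking; the patch itself also handles top-p and the flashinfer path):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 32000)  # one decoding step over a toy vocabulary
temperature, top_k = 0.8, 50

raw_logprobs = F.log_softmax(logits, dim=-1)  # LogprobMode.RAW

scaled = logits / temperature                             # temperature scaling
kth = torch.topk(scaled, top_k, dim=-1).values[..., -1:]  # k-th largest scaled logit
masked = torch.where(scaled >= kth, scaled, torch.tensor(float("-inf")))
processed_logprobs = F.log_softmax(masked, dim=-1)        # LogprobMode.PROCESSED (top-k only)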