diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 311db068b50..989eda6b54e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1022,7 +1022,7 @@ common-files: &common_files | tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py | tests/unittest/_torch/sampler/test_beam_search.py | tests/unittest/_torch/sampler/test_best_of_n.py | - tests/unittest/_torch/sampler/test_return_logits.py | + tests/unittest/_torch/sampler/test_logits_logprobs.py | tests/unittest/_torch/sampler/test_torch_multi_arange.py | tests/unittest/_torch/sampler/test_trtllm_sampler.py | tests/unittest/_torch/speculative/test_draft_target.py | diff --git a/pyproject.toml b/pyproject.toml index b453d975e36..8aa5782d0b9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1062,7 +1062,7 @@ exclude = [ "tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py", "tests/unittest/_torch/sampler/test_beam_search.py", "tests/unittest/_torch/sampler/test_best_of_n.py", - "tests/unittest/_torch/sampler/test_return_logits.py", + "tests/unittest/_torch/sampler/test_logits_logprobs.py", "tests/unittest/_torch/sampler/test_torch_multi_arange.py", "tests/unittest/_torch/sampler/test_trtllm_sampler.py", "tests/unittest/_torch/speculative/test_draft_target.py", diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 3b8e42d2320..fe69191a6df 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -986,18 +986,94 @@ def handle_logprobs( topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view( beam_width, count, -1 ) + token_log_probs = self._convert_logprobs_tensor_to_list( + topk_log_probs_indices, topk_log_probs_vals + ) else: assert beam_width == 1, "beam width must be 1 for non-beam search" - topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view( - beam_width, count, -1 - ) - topk_log_probs_indices = request.py_topk_logprobs_indices[ - : count * beam_width - ].view(beam_width, count, -1) - token_log_probs = self._convert_logprobs_tensor_to_list( - topk_log_probs_indices, topk_log_probs_vals - ) + sampled_tokens = request.get_tokens(0)[-count:] + + if request.py_num_logprobs == 0: + # Return only the sampled token's logprob + # Compute at least top-1 to determine rank + if ( + hasattr(request, "py_sampled_logprobs") + and request.py_sampled_logprobs is not None + ): + sampled_logprobs = request.py_sampled_logprobs[:count] + topk_log_probs_vals = request.py_topk_logprobs_vals[:count] # At least k=1 + topk_log_probs_indices = request.py_topk_logprobs_indices[:count] + + token_log_probs = [] + for step, ( + sampled_token, + sampled_logprob, + topk_tokens, + topk_logprobs, + ) in enumerate( + zip( + sampled_tokens, + sampled_logprobs, + topk_log_probs_indices, + topk_log_probs_vals, + ) + ): + topk_tokens_list = topk_tokens.tolist() + if sampled_token in topk_tokens_list: + # Sampled token is in top-K, use its rank + rank = topk_tokens_list.index(sampled_token) + 1 + else: + # TODO: fix rank + rank = 2 + + step_dict = { + sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank) + } + token_log_probs.append(step_dict) + else: + raise ValueError( + "py_sampled_logprobs not available when py_num_logprobs == 0" + ) + else: + # Return top-K logprobs + logprob of sampled token + sampled_logprobs = request.py_sampled_logprobs[:count] + topk_log_probs_vals = request.py_topk_logprobs_vals[:count] + 
topk_log_probs_indices = request.py_topk_logprobs_indices[:count] + + token_log_probs = [] + for step, ( + sampled_token, + sampled_logprob, + topk_tokens, + topk_logprobs, + ) in enumerate( + zip( + sampled_tokens, + sampled_logprobs, + topk_log_probs_indices, + topk_log_probs_vals, + ) + ): + step_dict = {} + topk_tokens_list = topk_tokens.tolist() + topk_logprobs_list = topk_logprobs.tolist() + + for rank_idx, (token, logprob) in enumerate( + zip(topk_tokens_list, topk_logprobs_list), start=1 + ): + step_dict[token] = Logprob(logprob=logprob, rank=rank_idx) + + if sampled_token not in step_dict: + # TODO: fix rank + step_dict[sampled_token] = Logprob( + logprob=sampled_logprob.item(), rank=len(topk_tokens_list) + 1 + ) + token_log_probs.append(step_dict) + + # Wrap in list for non-beam search (beam_width=1) + token_log_probs = [token_log_probs] + request.py_result.append_log_probs(token_log_probs) def finish_if_reason( @@ -2461,47 +2537,55 @@ def _process_requests( assert logits_cuda.dim() == 2, "logits should be 2D" logprobs_req_indices = [ - req_id for req_id, req in enumerate(requests) if req.py_num_logprobs + req_id for req_id, req in enumerate(requests) if req.py_num_logprobs is not None ] - logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] - logprobs_logit_indices_cuda = logprobs_logit_indices.to( - device=logits_cuda.device, non_blocking=True - ) - logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), - dim=-1, - ) - topk_vals_cuda, topk_indices_cuda = torch.topk( - logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 - ) - # Use a single D2H copy to reduce overheads - topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) - topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) - topk_vals.copy_(topk_vals_cuda, non_blocking=True) - topk_indices.copy_(topk_indices_cuda, non_blocking=True) - current_offset = 0 - for req_id, steps in zip( - logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist() - ): - req = requests[req_id] - next_offset = current_offset + steps - # NB: Assigning views on memory which is being filled asynchronously - req.py_topk_logprobs_vals = topk_vals[ - current_offset:next_offset, : req.py_num_logprobs - ] - req.py_topk_logprobs_indices = topk_indices[ - current_offset:next_offset, : req.py_num_logprobs - ] - # context requests do not have multiple input beams, but they need multiple output beams - if req.is_context_init_state: - req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand( - req.sampling_config.beam_width, -1 - ) - req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand( - req.sampling_config.beam_width, -1 - ) - current_offset = next_offset + if logprobs_req_indices: + logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] + logprobs_logit_indices_cuda = logprobs_logit_indices.to( + device=logits_cuda.device, non_blocking=True + ) + logprobs_cuda = F.log_softmax( + logits_cuda[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ), + dim=-1, + ) + + max_k = max( + max(1, req.py_num_logprobs) + for req in requests + if req.py_num_logprobs is not None + ) + topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1) + # Use a single D2H copy to reduce overheads + topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) + topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", 
pin_memory=True) + topk_vals.copy_(topk_vals_cuda, non_blocking=True) + topk_indices.copy_(topk_indices_cuda, non_blocking=True) + current_offset = 0 + for req_id, steps in zip( + logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist() + ): + req = requests[req_id] + next_offset = current_offset + steps + # Store at least k=1 for all requests (including logprobs=0) to compute ranks + k_for_req = max(1, req.py_num_logprobs) + # NB: Assigning views on memory which is being filled asynchronously + req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, :k_for_req] + req.py_topk_logprobs_indices = topk_indices[ + current_offset:next_offset, :k_for_req + ] + + # context requests do not have multiple input beams, but they need multiple output beams + if req.is_context_init_state: + req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand( + req.sampling_config.beam_width, -1 + ) + req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand( + req.sampling_config.beam_width, -1 + ) + current_offset = next_offset # Perform sampling in batches batched_sampling_result = self._sample_batched_by_strategy( @@ -2517,6 +2601,59 @@ def _process_requests( token_dtype=new_tokens_cuda.dtype, ) + if return_log_probs and logprobs_req_indices: + sampled_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int + batch_req_indices = batched_sampling_result.batch_req_indices + logprobs_req_set = set(logprobs_req_indices) + sampled_logprobs_list = [] + + # Build offsets for the GROUPED order + grouped_num_steps = req_num_steps[batch_req_indices] + grouped_offsets = torch.cat( + [ + torch.zeros((1,), dtype=torch.int32, pin_memory=True), + grouped_num_steps.cumsum(dim=0)[:-1], + ] + ) + + # Reverse mapping: original_req_id → position in grouped result + req_to_grouped_pos = { + orig_id.item(): grouped_pos for grouped_pos, orig_id in enumerate(batch_req_indices) + } + + for req_id in range(len(requests)): + if req_id in logprobs_req_set: + logprobs_idx = logprobs_req_indices.index(req_id) + + if logprobs_idx == 0: + start_offset = 0 + else: + start_offset = sum( + req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist() + ) + + num_steps_this_req = req_num_steps[req_id].item() + end_offset = start_offset + num_steps_this_req + + grouped_pos = req_to_grouped_pos[req_id] + grouped_start = grouped_offsets[grouped_pos].item() + grouped_end = grouped_start + grouped_num_steps[grouped_pos].item() + + sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end] + + step_indices = torch.arange( + start_offset, end_offset, device=logprobs_cuda.device + ) + sampled_logprobs_cuda = logprobs_cuda[ + step_indices, sampled_tokens_this_req.long() + ] + + sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True) + sampled_logprobs_list.append((req_id, sampled_logprobs_cpu)) + + for req_id, sampled_logprobs in sampled_logprobs_list: + requests[req_id].py_sampled_logprobs = sampled_logprobs + # Fill results into output buffers new_tokens_host = self._unbatch_sampling_results( batched_sampling_result, diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 4c15e657c10..1e4ccc7e915 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -221,7 +221,7 @@ def _get_logprob_params( self, request: GenerationRequest) -> Optional[LogprobParams]: """Store logprobs-related fields from request for the later logprob calculation.""" logprob_params = None - if request.sampling_params.logprobs or 
request.sampling_params.prompt_logprobs: + if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs: logprob_params = LogprobParams( logprobs=request.sampling_params.logprobs, prompt_logprobs=request.sampling_params.prompt_logprobs, diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 28d35c43a75..c0294a53ba5 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -907,6 +907,21 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, logits = logits[:len(tokens)] logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1) + + # only return sampled token + if top_k == 0: + results: TokenLogprobs = [] + if tokens is not None: + for t in range(logprobs.size(0)): + token_id = tokens[t] + token_logprob = logprobs[t, token_id].item() + rank = (logprobs[t] > token_logprob).sum().item() + 1 + token_dict = { + token_id: Logprob(logprob=token_logprob, rank=rank) + } + results.append(token_dict) + return results + topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1) results: TokenLogprobs = [] @@ -935,7 +950,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, None) if k_prompt_logprobs and context_logits is not None else None generation_logprobs = _topk_logprobs( generation_logits, k_logprobs, output_token_ids - ) if k_logprobs and generation_logits is not None else None + ) if k_logprobs is not None and generation_logits is not None else None return LogProbsResult(prompt=prompt_logprobs, generation=generation_logprobs) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index e4b060ee48f..e84a54f8224 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -632,7 +632,7 @@ def _prepare_sampling_params( if sampling_params.prompt_logprobs and not sampling_params.return_context_logits: sampling_params.return_context_logits = True sampling_params._context_logits_auto_enabled = True - if sampling_params.logprobs and not sampling_params.return_generation_logits: + if sampling_params.logprobs is not None and not sampling_params.return_generation_logits: sampling_params.return_generation_logits = True sampling_params._generation_logits_auto_enabled = True @@ -703,7 +703,7 @@ def _check_arguments(self, prompt_len: int, query_len: int, f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))." ) - if sampling_params.logprobs and not self.args.gather_generation_logits: + if sampling_params.logprobs is not None and not self.args.gather_generation_logits: raise ValueError( f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` " f"to be passed explicitly to the `LLM()` constructor.") diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index c9d6e1f44b2..b3a5f08dc58 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -172,7 +172,8 @@ class SamplingParams: min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None. beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None. - logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. + logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability. 
+ When set to K>0, return top-K log probabilities + the sampled token's log probability (last entry) if it's not in the Top-K. Defaults to None. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. @@ -501,7 +502,7 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo config_kwargs = {f: getattr(self, f) for f in fields} if is_pytorch_backend: - config_kwargs["return_log_probs"] = bool(self.logprobs) + config_kwargs["return_log_probs"] = self.logprobs is not None if self.prompt_logprobs and not self.return_context_logits: logger.info( "Since prompt_logprobs is requested but return_context_logits is False, " diff --git a/tests/integration/test_lists/test-db/l0_a30.yml b/tests/integration/test_lists/test-db/l0_a30.yml index b63ea04b5fe..73dd8d4e389 100644 --- a/tests/integration/test_lists/test-db/l0_a30.yml +++ b/tests/integration/test_lists/test-db/l0_a30.yml @@ -22,7 +22,7 @@ l0_a30: - unittest/_torch/modeling -k "modeling_starcoder2" - unittest/_torch/auto_deploy/unit/singlegpu - unittest/_torch/sampler/test_beam_search.py - - unittest/_torch/sampler/test_return_logits.py + - unittest/_torch/sampler/test_logits_logprobs.py - test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler] - test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler] - test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler] diff --git a/tests/unittest/_torch/sampler/test_logits_logprobs.py b/tests/unittest/_torch/sampler/test_logits_logprobs.py new file mode 100644 index 00000000000..ec17f51b406 --- /dev/null +++ b/tests/unittest/_torch/sampler/test_logits_logprobs.py @@ -0,0 +1,441 @@ +import os + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from utils.llm_data import llm_models_root +from utils.util import force_ampere + +from tensorrt_llm import LLM, SamplingParams +from tensorrt_llm.llmapi.llm_utils import KvCacheConfig + +prompts = ["A B C"] +global_kvcache_config = KvCacheConfig( + max_tokens=10000, + enable_block_reuse=True, +) + + +@pytest.fixture(scope="module", params=[False, True]) +def gather_generation_logits_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=[False, True]) +def gather_context_logits_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=[False, True]) +def disable_overlap_scheduler_fixture(request) -> bool: + return request.param + + +@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"]) +def sampler_type_fixture(request) -> str: + return request.param + + +class CacheSalter: + + _salt = 0 + + @classmethod + def get_salt_unique(cls) -> str: + cls._salt += 1 + return str(cls._salt) + + @classmethod + def get_salt_shared(cls) -> str: + return str(0) + + @classmethod + def get_salt(cls, reuse_cache: bool) -> str: + if reuse_cache: + salt = cls.get_salt_shared() + else: + salt = cls.get_salt_unique() + return salt + + +@pytest.fixture(scope="module") +def llm( + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + sampler_type_fixture: str, + disable_overlap_scheduler_fixture: bool, +): + gather_generation_logits = gather_generation_logits_fixture + sampler_type = sampler_type_fixture + 
disable_overlap_scheduler = disable_overlap_scheduler_fixture + + llm = LLM( + model=os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0"), + kv_cache_config=global_kvcache_config, + gather_generation_logits=gather_generation_logits, + max_batch_size= + 128, # reduce buffer sizes, specially for generation logits + sampler_type=sampler_type, + disable_overlap_scheduler=disable_overlap_scheduler, + ) + + # FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178. + # Remove patch below once fixed. + old_exit = LLM.__exit__ + + def _exit_with_xfail_on_timeout(self, exc_type, exc_value, + traceback) -> bool: + import _pytest.outcomes + try: + return old_exit(self, exc_type, exc_value, traceback) + except _pytest.outcomes.Failed as e: + if e.msg and "pytest-timeout" in e.msg.lower(): + pytest.xfail( + "Known LLM shutdown issue (https://nvbugs/5577178).") + else: + raise + + with pytest.MonkeyPatch.context() as patch: + patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout) + + with llm: + yield llm + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("reuse_cache", [False, True]) +@pytest.mark.parametrize("return_log_probs", [False, True]) +# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 +# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 +@pytest.mark.timeout(120, method="signal") +@pytest.mark.threadleak(enabled=False) +def test_generate_with_return_logits( + llm, + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + reuse_cache: bool, + return_log_probs: bool, +): + gather_context_logits = gather_context_logits_fixture + gather_generation_logits = gather_generation_logits_fixture + + if not (gather_context_logits or gather_generation_logits + or return_log_probs): # prune space + pytest.skip("Nothing to test") + + sampling_params = SamplingParams( + max_tokens=8, + return_context_logits=gather_context_logits, + return_generation_logits=gather_generation_logits, + logprobs=return_log_probs, + ) + + for output in llm.generate( + prompts, + sampling_params=sampling_params, + cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts], + ): + if gather_context_logits: + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + try: + assert expected_len == output.context_logits.shape[0] + except AssertionError: + # FIXME: Remove this once the bug has been fixed + if gather_context_logits and reuse_cache: + pytest.xfail("Known bug: https://nvbugs/5577178") + raise + else: + assert output.context_logits is None + + for sequence in output.outputs: + assert sequence.length == sampling_params.max_tokens + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == sampling_params.max_tokens + assert torch.argmax(gen_logits, + dim=1).tolist() == sequence.token_ids + else: + assert sequence.generation_logits is None + + if return_log_probs: + assert len(sequence.logprobs) == sampling_params.max_tokens + else: + assert len(sequence.logprobs) == 0 + + +@force_ampere # Save H100 resource +@pytest.mark.parametrize("reuse_cache", [False, True]) +@pytest.mark.parametrize("return_log_probs", [False, True]) +# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 +# NB: Timeout covers fixtures 
https://github.com/pytest-dev/pytest-timeout/issues/134 +@pytest.mark.timeout(120, method="signal") +@pytest.mark.threadleak(enabled=False) +def test_generate_async_with_return_logits( + llm, + gather_context_logits_fixture: bool, + gather_generation_logits_fixture: bool, + reuse_cache: bool, + return_log_probs: bool, +): + gather_context_logits = gather_context_logits_fixture + gather_generation_logits = gather_generation_logits_fixture + + if not (gather_context_logits or gather_generation_logits + or return_log_probs): # prune space + pytest.skip("Nothing to test") + + sampling_params = SamplingParams( + max_tokens=8, + return_context_logits=gather_context_logits, + return_generation_logits=gather_generation_logits, + logprobs=return_log_probs) + + for idx, output in enumerate( + llm.generate_async( + prompts[0], + sampling_params=sampling_params, + streaming=True, + cache_salt=CacheSalter.get_salt(reuse_cache), + )): + if gather_context_logits: + assert output.context_logits is not None + # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] + expected_len = len(prompts[0].split()) + 1 + try: + assert expected_len == output.context_logits.shape[0] + except AssertionError: + # FIXME: Remove this once the bug has been fixed + if gather_context_logits and reuse_cache: + pytest.xfail("Known bug: https://nvbugs/5577178") + raise + else: + assert output.context_logits is None + + for sequence in output.outputs: + assert sequence.length == idx + 1 + + if gather_generation_logits: + gen_logits = sequence.generation_logits + assert gen_logits is not None + assert gen_logits.ndim == 2 + assert gen_logits.shape[0] == 1 + try: + assert torch.argmax( + gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1] + except AssertionError: + # FIXME: Remove xfail once the bug is fixed + pytest.xfail("Known bug: https://nvbugs/5573238") + else: + assert sequence.generation_logits is None + + if return_log_probs: + assert len(sequence.logprobs) == idx + 1 + else: + assert len(sequence.logprobs) == 0 + + +@pytest.mark.parametrize("logprobs_k", [0, 1, 3], + ids=["top_0", "top_1", "top_3"]) +def test_sampled_token_always_in_logprobs(logprobs_k: int): + """Two scenarios: + - logprobs=0: Returns only sampled token (1 element) + - logprobs=K (K>0): Returns top-K tokens + sampled token if not in top-K (up to K+1 elements) + """ + llm = LLM(model=os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0"), ) + + sampling_params = SamplingParams( + max_tokens=8, + temperature=0.7, + top_p=0.9, + logprobs=logprobs_k, + ) + + for output in llm.generate(["The future of AI is"], + sampling_params=sampling_params): + print(f"\n{'='*80}") + print(f"Generated text: {output.outputs[0].text!r}") + print(f"Generated token IDs: {output.outputs[0].token_ids}") + + logprobs = output.outputs[0].logprobs + token_ids = output.outputs[0].token_ids + + assert len(logprobs) == sampling_params.max_tokens, \ + f"Expected {sampling_params.max_tokens} logprob entries, got {len(logprobs)}" + + for token_idx, (sampled_token_id, + token_logprobs) in enumerate(zip(token_ids, logprobs)): + print( + f"\n Token {token_idx}: ID={sampled_token_id}, Text={llm.tokenizer.decode([sampled_token_id])!r}" + ) + + assert sampled_token_id in token_logprobs, \ + f"Token {token_idx}: Sampled token ID {sampled_token_id} not in logprobs dict: {token_logprobs.keys()}" + + if logprobs_k == 0: + assert len(token_logprobs) == 1, \ + f"Token {token_idx}: Expected 1 logprob (sampled only), got {len(token_logprobs)}" + else: + assert 
len(token_logprobs) <= logprobs_k + 1, \ + f"Token {token_idx}: Expected at most {logprobs_k + 1} logprobs, got {len(token_logprobs)}" + assert len(token_logprobs) >= 1 + + sorted_tokens_by_prob = sorted(token_logprobs.items(), + key=lambda x: x[1].logprob, + reverse=True) + + if logprobs_k > 0: + sampled_token_rank = token_logprobs[sampled_token_id].rank + sampled_in_topk = sampled_token_rank <= logprobs_k + + if not sampled_in_topk: + assert sorted_tokens_by_prob[-1][0] == sampled_token_id, \ + f"Token {token_idx}: Sampled token (ID={sampled_token_id}, rank={sampled_token_rank}) not in top-{logprobs_k}, " \ + f"should be last in sorted list, but last token is ID={sorted_tokens_by_prob[-1][0]}" + + for rank_idx, (token_id, + logprob_obj) in enumerate(sorted_tokens_by_prob, + start=1): + token_text = llm.tokenizer.decode([token_id]) + is_sampled = "← SAMPLED" if token_id == sampled_token_id else "" + print(f" • Token {token_id:5d} ({token_text:15s}): " + f"logprob={logprob_obj.logprob:8.4f}, " + f"rank={logprob_obj.rank} {is_sampled}") + + if logprobs_k > 0 and sampled_in_topk: + assert logprob_obj.rank == rank_idx, \ + f"Token {token_idx}: Token {token_id} rank mismatch. " \ + f"Expected rank {rank_idx} (by sorted position), got {logprob_obj.rank}" + + print(f"{'='*80}\n") + + +@pytest.mark.parametrize("logprobs_k", [0, 2], ids=["top_0", "top_2"]) +def test_logprobs_with_grouped_samplings_strategies(logprobs_k: int): + """Test logprobs when requests are reordered by sampling strategy grouping""" + llm = LLM( + model=os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0"), + max_batch_size=8, + ) + + test_prompts = [ + "The capital of France is", + "The future of AI is", + "Hello, my name is", + "Write a short story about a cat", + ] + + # Causes reordering: [0,1,2,3] → [0,2,1,3] + sampling_params_list = [ + SamplingParams(max_tokens=6, + temperature=0.8, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True), + SamplingParams(max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True), + SamplingParams(max_tokens=6, + temperature=0.8, + top_k=50, + logprobs=logprobs_k, + return_generation_logits=True), + SamplingParams(max_tokens=6, + temperature=0.8, + top_p=0.9, + logprobs=logprobs_k, + return_generation_logits=True), + ] + + outputs = list( + llm.generate(test_prompts, sampling_params=sampling_params_list)) + + for req_idx, output in enumerate(outputs): + generation_logits = output.outputs[0].generation_logits + token_ids = output.outputs[0].token_ids + logprobs = output.outputs[0].logprobs + + assert generation_logits is not None + assert len(logprobs) == len(token_ids), "Logprobs length mismatch" + + # generation_logits might be shorter than token_ids + num_logits = len(generation_logits) + + for token_idx, (sampled_token_id, token_logprobs_dict) in enumerate( + zip(token_ids[:num_logits], logprobs[:num_logits])): + returned_logprob = token_logprobs_dict[sampled_token_id].logprob + + logits_for_token = generation_logits[token_idx] + expected_logprobs = torch.nn.functional.log_softmax( + logits_for_token.to(dtype=torch.float32), dim=-1) + expected_logprob = expected_logprobs[sampled_token_id].item() + + logprob_diff = abs(returned_logprob - expected_logprob) + print( + f"Req {req_idx}, Token {token_idx}: returned={returned_logprob:.6f}, expected={expected_logprob:.6f}, diff={logprob_diff:.6f}" + ) + + assert logprob_diff < 1e-4, \ + f"Req {req_idx}, Token {token_idx}: Logprob mismatch! 
" \ + f"Returned {returned_logprob:.6f} but expected {expected_logprob:.6f} " \ + f"(diff={logprob_diff:.6f}). This indicates the logprob might be extracted from " \ + f"the wrong token position." + +@force_ampere +@pytest.mark.gpu2 +def test_logprobs_match_hf_tp2(): + model_path = os.path.join(llm_models_root(), "llama-models-v2", + "TinyLlama-1.1B-Chat-v1.0") + llm = LLM( + model=model_path, + tensor_parallel_size=2, + ) + + prompts = ["The future of the AI is"] + + sampling_params = SamplingParams( + max_tokens=10, + temperature=1.0, + logprobs=0, + ) + + hf_model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.bfloat16).to("cuda") + hf_tokenizer = AutoTokenizer.from_pretrained(model_path) + + output = list(llm.generate(prompts, sampling_params=sampling_params))[0] + + trtllm_token_ids = output.outputs[0].token_ids + trtllm_logprobs = torch.tensor( + [list(lp.values())[0].logprob for lp in output.outputs[0].logprobs]) + + base_ids = hf_tokenizer.encode(prompts[0], return_tensors="pt").to("cuda") + hf_logprobs = [] + + for i, token_id in enumerate(trtllm_token_ids): + if i > 0: + prev_tokens = torch.tensor([trtllm_token_ids[:i]], device="cuda") + input_ids = torch.cat([base_ids, prev_tokens], dim=1) + else: + input_ids = base_ids + with torch.no_grad(): + logits = hf_model(input_ids).logits[0, -1, :] + hf_logprobs.append(torch.log_softmax(logits, dim=-1)[token_id].item()) + + hf_logprobs = torch.tensor(hf_logprobs) + + print(f"\nTensorRT-LLM logprobs: {trtllm_logprobs}") + print(f"HuggingFace logprobs: {hf_logprobs}") + print(f"Diff: {(trtllm_logprobs - hf_logprobs).abs()}") + + max_diff = (trtllm_logprobs - hf_logprobs).abs().max().item() + assert max_diff < 0.15, f"Max logprob diff {max_diff:.4f} exceeds threshold" diff --git a/tests/unittest/_torch/sampler/test_return_logits.py b/tests/unittest/_torch/sampler/test_return_logits.py deleted file mode 100644 index 9552459f516..00000000000 --- a/tests/unittest/_torch/sampler/test_return_logits.py +++ /dev/null @@ -1,239 +0,0 @@ -import os - -import pytest -import torch -from utils.llm_data import llm_models_root -from utils.util import force_ampere - -from tensorrt_llm import LLM, SamplingParams -from tensorrt_llm.llmapi.llm_utils import KvCacheConfig - -prompts = ["A B C"] -global_kvcache_config = KvCacheConfig( - max_tokens=10000, - enable_block_reuse=True, -) - - -@pytest.fixture(scope="module", params=[False, True]) -def gather_generation_logits_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=[False, True]) -def gather_context_logits_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=[False, True]) -def disable_overlap_scheduler_fixture(request) -> bool: - return request.param - - -@pytest.fixture(scope="module", params=["TRTLLMSampler", "TorchSampler"]) -def sampler_type_fixture(request) -> str: - return request.param - - -class CacheSalter: - - _salt = 0 - - @classmethod - def get_salt_unique(cls) -> str: - cls._salt += 1 - return str(cls._salt) - - @classmethod - def get_salt_shared(cls) -> str: - return str(0) - - @classmethod - def get_salt(cls, reuse_cache: bool) -> str: - if reuse_cache: - salt = cls.get_salt_shared() - else: - salt = cls.get_salt_unique() - return salt - - -@pytest.fixture(scope="module") -def llm( - gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - sampler_type_fixture: str, - disable_overlap_scheduler_fixture: bool, -): - gather_generation_logits = 
gather_generation_logits_fixture - sampler_type = sampler_type_fixture - disable_overlap_scheduler = disable_overlap_scheduler_fixture - - llm = LLM( - model=os.path.join(llm_models_root(), "llama-models-v2", - "TinyLlama-1.1B-Chat-v1.0"), - kv_cache_config=global_kvcache_config, - gather_generation_logits=gather_generation_logits, - max_batch_size= - 128, # reduce buffer sizes, specially for generation logits - sampler_type=sampler_type, - disable_overlap_scheduler=disable_overlap_scheduler, - ) - - # FIXME: Sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178. - # Remove patch below once fixed. - old_exit = LLM.__exit__ - - def _exit_with_xfail_on_timeout(self, exc_type, exc_value, - traceback) -> bool: - import _pytest.outcomes - try: - return old_exit(self, exc_type, exc_value, traceback) - except _pytest.outcomes.Failed as e: - if e.msg and "pytest-timeout" in e.msg.lower(): - pytest.xfail( - "Known LLM shutdown issue (https://nvbugs/5577178).") - else: - raise - - with pytest.MonkeyPatch.context() as patch: - patch.setattr(LLM, "__exit__", _exit_with_xfail_on_timeout) - - with llm: - yield llm - - -@force_ampere # Save H100 resource -@pytest.mark.parametrize("reuse_cache", [False, True]) -@pytest.mark.parametrize("return_log_probs", [False, True]) -# FIXME: sometimes LLM shutdown hangs, might be related to https://nvbugs/5577178 -# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 -@pytest.mark.timeout(120, method="signal") -@pytest.mark.threadleak(enabled=False) -def test_generate_with_return_logits( - llm, - gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - reuse_cache: bool, - return_log_probs: bool, -): - gather_context_logits = gather_context_logits_fixture - gather_generation_logits = gather_generation_logits_fixture - - if not (gather_context_logits or gather_generation_logits - or return_log_probs): # prune space - pytest.skip("Nothing to test") - - sampling_params = SamplingParams( - max_tokens=8, - return_context_logits=gather_context_logits, - return_generation_logits=gather_generation_logits, - logprobs=return_log_probs, - ) - - for output in llm.generate( - prompts, - sampling_params=sampling_params, - cache_salt=[CacheSalter.get_salt(reuse_cache) for _ in prompts], - ): - if gather_context_logits: - assert output.context_logits is not None - # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] - expected_len = len(prompts[0].split()) + 1 - try: - assert expected_len == output.context_logits.shape[0] - except AssertionError: - # FIXME: Remove this once the bug has been fixed - if gather_context_logits and reuse_cache: - pytest.xfail("Known bug: https://nvbugs/5577178") - raise - else: - assert output.context_logits is None - - for sequence in output.outputs: - assert sequence.length == sampling_params.max_tokens - - if gather_generation_logits: - gen_logits = sequence.generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == sampling_params.max_tokens - assert torch.argmax(gen_logits, - dim=1).tolist() == sequence.token_ids - else: - assert sequence.generation_logits is None - - if return_log_probs: - assert len(sequence.logprobs) == sampling_params.max_tokens - else: - assert len(sequence.logprobs) == 0 - - -@force_ampere # Save H100 resource -@pytest.mark.parametrize("reuse_cache", [False, True]) -@pytest.mark.parametrize("return_log_probs", [False, True]) -# FIXME: sometimes LLM shutdown hangs, might be related to 
https://nvbugs/5577178 -# NB: Timeout covers fixtures https://github.com/pytest-dev/pytest-timeout/issues/134 -@pytest.mark.timeout(120, method="signal") -@pytest.mark.threadleak(enabled=False) -def test_generate_async_with_return_logits( - llm, - gather_context_logits_fixture: bool, - gather_generation_logits_fixture: bool, - reuse_cache: bool, - return_log_probs: bool, -): - gather_context_logits = gather_context_logits_fixture - gather_generation_logits = gather_generation_logits_fixture - - if not (gather_context_logits or gather_generation_logits - or return_log_probs): # prune space - pytest.skip("Nothing to test") - - sampling_params = SamplingParams( - max_tokens=8, - return_context_logits=gather_context_logits, - return_generation_logits=gather_generation_logits, - logprobs=return_log_probs) - - for idx, output in enumerate( - llm.generate_async( - prompts[0], - sampling_params=sampling_params, - streaming=True, - cache_salt=CacheSalter.get_salt(reuse_cache), - )): - if gather_context_logits: - assert output.context_logits is not None - # NOTE: prompt_token_ids of "A B C" becomes [1, 319, 350, 315] - expected_len = len(prompts[0].split()) + 1 - try: - assert expected_len == output.context_logits.shape[0] - except AssertionError: - # FIXME: Remove this once the bug has been fixed - if gather_context_logits and reuse_cache: - pytest.xfail("Known bug: https://nvbugs/5577178") - raise - else: - assert output.context_logits is None - - for sequence in output.outputs: - assert sequence.length == idx + 1 - - if gather_generation_logits: - gen_logits = sequence.generation_logits - assert gen_logits is not None - assert gen_logits.ndim == 2 - assert gen_logits.shape[0] == 1 - try: - assert torch.argmax( - gen_logits, dim=1).tolist()[0] == sequence.token_ids[-1] - except AssertionError: - # FIXME: Remove xfail once the bug is fixed - pytest.xfail("Known bug: https://nvbugs/5573238") - else: - assert sequence.generation_logits is None - - if return_log_probs: - assert len(sequence.logprobs) == idx + 1 - else: - assert len(sequence.logprobs) == 0
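
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of the logprob/rank convention this change adopts, as described in the updated SamplingParams docstring and asserted by test_sampled_token_always_in_logprobs. The helper name sketch_step_logprobs and the toy tensors are hypothetical; only the rank rule (1 + number of tokens with strictly higher logprob, the same rule used by the top_k == 0 branch added to _topk_logprobs in executor/result.py) is taken from the diff.

import torch
import torch.nn.functional as F


def sketch_step_logprobs(step_logits: torch.Tensor, sampled_token: int, k: int) -> dict:
    """Return {token_id: (logprob, rank)} for a single generation step.

    k == 0 -> only the sampled token's logprob, with its rank against the full row.
    k > 0  -> the top-k tokens, plus the sampled token appended if it fell outside the top-k.
    """
    logprobs = F.log_softmax(step_logits.float(), dim=-1)
    sampled_logprob = logprobs[sampled_token].item()
    # Rank = 1 + count of tokens with strictly higher logprob.
    sampled_rank = int((logprobs > sampled_logprob).sum().item()) + 1

    if k == 0:
        return {sampled_token: (sampled_logprob, sampled_rank)}

    topk_vals, topk_ids = torch.topk(logprobs, k=k)
    out = {
        int(t): (float(v), r)
        for r, (t, v) in enumerate(zip(topk_ids.tolist(), topk_vals.tolist()), start=1)
    }
    # The sampled token is always reported, even when it is not in the top-k.
    out.setdefault(sampled_token, (sampled_logprob, sampled_rank))
    return out


if __name__ == "__main__":
    torch.manual_seed(0)
    fake_logits = torch.randn(32)   # one step over a toy 32-token vocabulary
    sampled = 7                     # pretend token 7 was sampled
    print(sketch_step_logprobs(fake_logits, sampled, k=0))  # exactly one entry
    print(sketch_step_logprobs(fake_logits, sampled, k=3))  # 3 or 4 entries

Through the LLM API, the same behavior is requested with SamplingParams(logprobs=0) or SamplingParams(logprobs=K); each entry of output.outputs[0].logprobs is then a dict mapping token id to a Logprob(logprob, rank), which is what the new tests in tests/unittest/_torch/sampler/test_logits_logprobs.py assert.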