2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1022,7 +1022,7 @@ common-files: &common_files |
tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py |
tests/unittest/_torch/sampler/test_beam_search.py |
tests/unittest/_torch/sampler/test_best_of_n.py |
tests/unittest/_torch/sampler/test_return_logits.py |
tests/unittest/_torch/sampler/test_logits_logprobs.py |
tests/unittest/_torch/sampler/test_torch_multi_arange.py |
tests/unittest/_torch/sampler/test_trtllm_sampler.py |
tests/unittest/_torch/speculative/test_draft_target.py |
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1062,7 +1062,7 @@ exclude = [
"tests/unittest/_torch/ray_orchestrator/single_gpu/test_cache_transceiver_comm.py",
"tests/unittest/_torch/sampler/test_beam_search.py",
"tests/unittest/_torch/sampler/test_best_of_n.py",
"tests/unittest/_torch/sampler/test_return_logits.py",
"tests/unittest/_torch/sampler/test_logits_logprobs.py",
"tests/unittest/_torch/sampler/test_torch_multi_arange.py",
"tests/unittest/_torch/sampler/test_trtllm_sampler.py",
"tests/unittest/_torch/speculative/test_draft_target.py",
233 changes: 185 additions & 48 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -986,18 +986,94 @@ def handle_logprobs(
topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
beam_width, count, -1
)
token_log_probs = self._convert_logprobs_tensor_to_list(
topk_log_probs_indices, topk_log_probs_vals
)
else:
assert beam_width == 1, "beam width must be 1 for non-beam search"
topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
beam_width, count, -1
)
topk_log_probs_indices = request.py_topk_logprobs_indices[
: count * beam_width
].view(beam_width, count, -1)

token_log_probs = self._convert_logprobs_tensor_to_list(
topk_log_probs_indices, topk_log_probs_vals
)
sampled_tokens = request.get_tokens(0)[-count:]

if request.py_num_logprobs == 0:
# Return only the sampled token's logprob
# Compute at least top-1 to determine rank
if (
hasattr(request, "py_sampled_logprobs")
and request.py_sampled_logprobs is not None
):
sampled_logprobs = request.py_sampled_logprobs[:count]
topk_log_probs_vals = request.py_topk_logprobs_vals[:count] # At least k=1
topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

token_log_probs = []
for step, (
sampled_token,
sampled_logprob,
topk_tokens,
topk_logprobs,
) in enumerate(
zip(
sampled_tokens,
sampled_logprobs,
topk_log_probs_indices,
topk_log_probs_vals,
)
):
topk_tokens_list = topk_tokens.tolist()
if sampled_token in topk_tokens_list:
# Sampled token is in top-K, use its rank
rank = topk_tokens_list.index(sampled_token) + 1
else:
# TODO: fix rank
Reviewer comment (Collaborator): The rank calculation could be done by computing the rank in _process_requests, where you have access to all logprobs. You can then pass it forward as request.py_sampled_rank, similar to how you pass request.py_sampled_logprobs. (A hedged sketch of this idea follows the _process_requests hunk below.)

rank = 2

step_dict = {
sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)
}
token_log_probs.append(step_dict)
else:
raise ValueError(
"py_sampled_logprobs not available when py_num_logprobs == 0"
)
else:
# Return top-K logprobs + logprob of sampled token
sampled_logprobs = request.py_sampled_logprobs[:count]
topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
topk_log_probs_indices = request.py_topk_logprobs_indices[:count]

token_log_probs = []
for step, (
sampled_token,
sampled_logprob,
topk_tokens,
topk_logprobs,
) in enumerate(
zip(
sampled_tokens,
sampled_logprobs,
topk_log_probs_indices,
topk_log_probs_vals,
)
):
step_dict = {}
topk_tokens_list = topk_tokens.tolist()
topk_logprobs_list = topk_logprobs.tolist()
Reviewer comment (Collaborator): I think you could merge the request.py_num_logprobs == 0 case with this one by iterating for only request.py_num_logprobs entries; essentially, this loop would then be empty. (A sketch of that merged form follows this hunk.)


for rank_idx, (token, logprob) in enumerate(
zip(topk_tokens_list, topk_logprobs_list), start=1
):
step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)

if sampled_token not in step_dict:
# TODO: fix rank
step_dict[sampled_token] = Logprob(
logprob=sampled_logprob.item(), rank=len(topk_tokens_list) + 1
)
token_log_probs.append(step_dict)

# Wrap in list for non-beam search (beam_width=1)
token_log_probs = [token_log_probs]

request.py_result.append_log_probs(token_log_probs)

def finish_if_reason(
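A minimal sketch of the merge suggested in the reviewer comment above: one loop builds each step's logprob dict for any k >= 0, so the k == 0 branch falls out naturally because the top-K loop is empty. Names such as sampled_tokens, sampled_logprobs, the top-k tensors, and Logprob follow the diff; the helper name build_token_log_probs and the placeholder rank are assumptions, not part of the PR.

def build_token_log_probs(sampled_tokens, sampled_logprobs,
                          topk_indices, topk_vals, num_logprobs):
    """Per-step {token_id: Logprob} dicts for k >= 0 (k == 0 keeps only the sampled token)."""
    token_log_probs = []
    for sampled_token, sampled_logprob, topk_tokens, topk_logprobs in zip(
            sampled_tokens, sampled_logprobs, topk_indices, topk_vals):
        step_dict = {}
        # Empty when num_logprobs == 0, so the k == 0 case needs no separate branch.
        for rank_idx, (token, logprob) in enumerate(
                zip(topk_tokens.tolist()[:num_logprobs],
                    topk_logprobs.tolist()[:num_logprobs]), start=1):
            step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)
        if sampled_token not in step_dict:
            # Placeholder rank, as in the diff; a true rank would come from
            # _process_requests (see the reviewer suggestion above).
            step_dict[sampled_token] = Logprob(
                logprob=sampled_logprob.item(), rank=num_logprobs + 1)
        token_log_probs.append(step_dict)
    return token_log_probs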
@@ -2461,47 +2537,55 @@ def _process_requests(
assert logits_cuda.dim() == 2, "logits should be 2D"

logprobs_req_indices = [
req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
req_id for req_id, req in enumerate(requests) if req.py_num_logprobs is not None
]
logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
logprobs_logit_indices_cuda = logprobs_logit_indices.to(
device=logits_cuda.device, non_blocking=True
)
logprobs_cuda = F.log_softmax(
logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
dim=-1,
)
topk_vals_cuda, topk_indices_cuda = torch.topk(
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
)
# Use a single D2H copy to reduce overheads
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
current_offset = 0
for req_id, steps in zip(
logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
):
req = requests[req_id]
next_offset = current_offset + steps
# NB: Assigning views on memory which is being filled asynchronously
req.py_topk_logprobs_vals = topk_vals[
current_offset:next_offset, : req.py_num_logprobs
]
req.py_topk_logprobs_indices = topk_indices[
current_offset:next_offset, : req.py_num_logprobs
]

# context requests do not have multiple input beams, but they need multiple output beams
if req.is_context_init_state:
req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
req.sampling_config.beam_width, -1
)
req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
req.sampling_config.beam_width, -1
)
current_offset = next_offset
if logprobs_req_indices:
logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
logprobs_logit_indices_cuda = logprobs_logit_indices.to(
device=logits_cuda.device, non_blocking=True
)
logprobs_cuda = F.log_softmax(
logits_cuda[logprobs_logit_indices_cuda].to(
dtype=torch.float32, non_blocking=True
),
dim=-1,
)

max_k = max(
max(1, req.py_num_logprobs)
for req in requests
if req.py_num_logprobs is not None
)
topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)
# Use a single D2H copy to reduce overheads
topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
topk_vals.copy_(topk_vals_cuda, non_blocking=True)
topk_indices.copy_(topk_indices_cuda, non_blocking=True)
current_offset = 0
for req_id, steps in zip(
logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
):
req = requests[req_id]
next_offset = current_offset + steps
# Store at least k=1 for all requests (including logprobs=0) to compute ranks
Reviewer comment (Collaborator): Could you elaborate on why it is necessary to use k >= 1 for the rank computation? I would have expected it to work with 0 as well. (The rank sketch after this hunk shows one way to compute ranks without any top-k.)

k_for_req = max(1, req.py_num_logprobs)
# NB: Assigning views on memory which is being filled asynchronously
req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, :k_for_req]
req.py_topk_logprobs_indices = topk_indices[
current_offset:next_offset, :k_for_req
]

# context requests do not have multiple input beams, but they need multiple output beams
Reviewer comment (Collaborator): I think the beam search part is obsolete and may be removed.

if req.is_context_init_state:
req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
req.sampling_config.beam_width, -1
)
req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
req.sampling_config.beam_width, -1
)
current_offset = next_offset

# Perform sampling in batches
batched_sampling_result = self._sample_batched_by_strategy(
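A hedged sketch of the rank computation raised in the reviewer comments above: with the full log-softmax already available in _process_requests, the sampled token's rank can be counted directly against the whole vocabulary, which would avoid both the placeholder ranks and the max(1, ...) clamp. The attribute name py_sampled_rank follows the reviewer's comment and is not part of this diff.

import torch

def sampled_token_ranks(logprobs: torch.Tensor, sampled_tokens: torch.Tensor) -> torch.Tensor:
    """Rank (1 = most likely) of each sampled token within the full distribution.

    logprobs:       [num_steps, vocab_size] log-probabilities for one request
    sampled_tokens: [num_steps] sampled token ids
    """
    step_idx = torch.arange(logprobs.size(0), device=logprobs.device)
    sampled_vals = logprobs[step_idx, sampled_tokens.long()]            # [num_steps]
    # Tokens scoring strictly higher than the sampled one determine its rank.
    return (logprobs > sampled_vals.unsqueeze(-1)).sum(dim=-1) + 1

# Hypothetical wiring, mirroring how py_sampled_logprobs is attached below:
# req.py_sampled_rank = sampled_token_ranks(
#     logprobs_cuda[step_indices], sampled_tokens_this_req
# ).to("cpu", non_blocking=True)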
@@ -2517,6 +2601,59 @@
token_dtype=new_tokens_cuda.dtype,
)

if return_log_probs and logprobs_req_indices:
sampled_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
batch_req_indices = batched_sampling_result.batch_req_indices
logprobs_req_set = set(logprobs_req_indices)
sampled_logprobs_list = []

# Build offsets for the GROUPED order
grouped_num_steps = req_num_steps[batch_req_indices]
grouped_offsets = torch.cat(
[
torch.zeros((1,), dtype=torch.int32, pin_memory=True),
grouped_num_steps.cumsum(dim=0)[:-1],
]
)

# Reverse mapping: original_req_id → position in grouped result
req_to_grouped_pos = {
orig_id.item(): grouped_pos for grouped_pos, orig_id in enumerate(batch_req_indices)
}

for req_id in range(len(requests)):
if req_id in logprobs_req_set:
logprobs_idx = logprobs_req_indices.index(req_id)

if logprobs_idx == 0:
start_offset = 0
else:
start_offset = sum(
req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist()
Reviewer comment (Collaborator): I believe using req_num_steps[logprobs_req_indices[:logprobs_idx]].sum().item() might be more efficient. This may also allow you to remove the if-else condition. (A combined sketch follows this hunk.)

)

num_steps_this_req = req_num_steps[req_id].item()
end_offset = start_offset + num_steps_this_req

grouped_pos = req_to_grouped_pos[req_id]
grouped_start = grouped_offsets[grouped_pos].item()
grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()

sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]

step_indices = torch.arange(
start_offset, end_offset, device=logprobs_cuda.device
)
sampled_logprobs_cuda = logprobs_cuda[
step_indices, sampled_tokens_this_req.long()
]

sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))

for req_id, sampled_logprobs in sampled_logprobs_list:
Reviewer comment (Collaborator): You might integrate the second loop into the first one and drop sampled_logprobs_list by directly assigning requests[req_id].py_sampled_logprobs = sampled_logprobs_cpu. (See the combined sketch after this hunk.)

requests[req_id].py_sampled_logprobs = sampled_logprobs

# Fill results into output buffers
new_tokens_host = self._unbatch_sampling_results(
batched_sampling_result,
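A minimal sketch combining the two reviewer suggestions above: compute the start offset with a single sum().item() (which is 0 for the first request, removing the if-else) and attach py_sampled_logprobs inside the same loop, dropping sampled_logprobs_list. Variable names follow the diff; the function boundary and signature are assumptions for illustration only.

import torch

def attach_sampled_logprobs(requests, logprobs_req_indices, req_num_steps,
                            logprobs_cuda, sampled_tokens_cuda,
                            req_to_grouped_pos, grouped_offsets, grouped_num_steps):
    """Single pass: slice each request's steps and assign py_sampled_logprobs directly."""
    for logprobs_idx, req_id in enumerate(logprobs_req_indices):
        # Empty slice sums to 0, so this also covers logprobs_idx == 0.
        start_offset = req_num_steps[logprobs_req_indices[:logprobs_idx]].sum().item()
        end_offset = start_offset + req_num_steps[req_id].item()

        grouped_pos = req_to_grouped_pos[req_id]
        grouped_start = grouped_offsets[grouped_pos].item()
        grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()
        sampled = sampled_tokens_cuda[grouped_start:grouped_end]

        step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
        requests[req_id].py_sampled_logprobs = logprobs_cuda[
            step_indices, sampled.long()
        ].to(device="cpu", non_blocking=True)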
2 changes: 1 addition & 1 deletion tensorrt_llm/executor/executor.py
@@ -221,7 +221,7 @@ def _get_logprob_params(
self, request: GenerationRequest) -> Optional[LogprobParams]:
"""Store logprobs-related fields from request for the later logprob calculation."""
logprob_params = None
if request.sampling_params.logprobs or request.sampling_params.prompt_logprobs:
if request.sampling_params.logprobs is not None or request.sampling_params.prompt_logprobs:
logprob_params = LogprobParams(
logprobs=request.sampling_params.logprobs,
prompt_logprobs=request.sampling_params.prompt_logprobs,
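The executor change above swaps a truthiness test for an explicit None check; the distinction matters because logprobs=0 (the new "sampled token only" mode) is falsy. A quick, standalone illustration:

logprobs = 0  # new mode: return only the sampled token's logprob

if logprobs:                # False for 0 -> the k=0 request would be silently skipped
    print("old check: build LogprobParams")

if logprobs is not None:    # True for 0 -> k=0 is handled like any other value
    print("new check: build LogprobParams")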
17 changes: 16 additions & 1 deletion tensorrt_llm/executor/result.py
@@ -907,6 +907,21 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
logits = logits[:len(tokens)]

logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)

# only return sampled token
if top_k == 0:
results: TokenLogprobs = []
if tokens is not None:
for t in range(logprobs.size(0)):
token_id = tokens[t]
token_logprob = logprobs[t, token_id].item()
rank = (logprobs[t] > token_logprob).sum().item() + 1
token_dict = {
token_id: Logprob(logprob=token_logprob, rank=rank)
}
results.append(token_dict)
return results

topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)

results: TokenLogprobs = []
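A small, self-contained illustration of what the top_k == 0 branch above computes for toy logits. Logprob stands in as a plain namedtuple here; that is an assumption for the example, the real class lives elsewhere in the codebase, and the real code also moves logits to CUDA.

import torch
import torch.nn.functional as F
from collections import namedtuple

Logprob = namedtuple("Logprob", ["logprob", "rank"])  # stand-in for the real class

logits = torch.tensor([[2.0, 1.0, 0.5],    # step 0
                       [0.1, 3.0, 0.2]])   # step 1
tokens = [1, 1]                            # sampled token ids per step

logprobs = F.log_softmax(logits.float(), dim=-1)
results = []
for t in range(logprobs.size(0)):
    token_id = tokens[t]
    token_logprob = logprobs[t, token_id].item()
    rank = (logprobs[t] > token_logprob).sum().item() + 1
    results.append({token_id: Logprob(logprob=token_logprob, rank=rank)})

# results[0][1].rank == 2 (token 1 was the second most likely at step 0)
# results[1][1].rank == 1 (token 1 was the most likely at step 1)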
@@ -935,7 +950,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
None) if k_prompt_logprobs and context_logits is not None else None
generation_logprobs = _topk_logprobs(
generation_logits, k_logprobs, output_token_ids
) if k_logprobs and generation_logits is not None else None
) if k_logprobs is not None and generation_logits is not None else None

return LogProbsResult(prompt=prompt_logprobs,
generation=generation_logprobs)
4 changes: 2 additions & 2 deletions tensorrt_llm/llmapi/llm.py
@@ -632,7 +632,7 @@ def _prepare_sampling_params(
if sampling_params.prompt_logprobs and not sampling_params.return_context_logits:
sampling_params.return_context_logits = True
sampling_params._context_logits_auto_enabled = True
if sampling_params.logprobs and not sampling_params.return_generation_logits:
if sampling_params.logprobs is not None and not sampling_params.return_generation_logits:
sampling_params.return_generation_logits = True
sampling_params._generation_logits_auto_enabled = True

@@ -703,7 +703,7 @@ def _check_arguments(self, prompt_len: int, query_len: int,
f"Example: LLM(..., build_config=BuildConfig(gather_context_logits=True))."
)

if sampling_params.logprobs and not self.args.gather_generation_logits:
if sampling_params.logprobs is not None and not self.args.gather_generation_logits:
raise ValueError(
f"`sampling_params.logprobs={sampling_params.logprobs}` requires `gather_generation_logits=True` "
f"to be passed explicitly to the `LLM()` constructor.")
5 changes: 3 additions & 2 deletions tensorrt_llm/sampling_params.py
@@ -172,7 +172,8 @@ class SamplingParams:
min_p (float, optional): scale the most likely token to determine the minimum token probability. None means using C++ runtime default 0.0. Defaults to None.
beam_width_array (List[int], optional): The array of beam width using in Variable-Beam-Width-Search. Defaults to None.

logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
logprobs (int, optional): Number of log probabilities to return per output token. When set to 0, return only the sampled token's log probability.
When set to K > 0, return the top-K log probabilities, plus the sampled token's log probability (appended as the last entry) if it is not already among the top K. Defaults to None.
prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
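A hedged usage sketch of the semantics documented above. The model path is a placeholder, gather_generation_logits follows the requirement enforced in the llm.py change above, and the attribute access on the outputs follows the usual LLM API as I understand it.

from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="<model-path>", gather_generation_logits=True)

# logprobs=0: each output token carries only its own log probability.
params_sampled_only = SamplingParams(max_tokens=16, logprobs=0)

# logprobs=2: top-2 log probabilities per token, plus the sampled token's
# entry if it fell outside the top 2.
params_top2 = SamplingParams(max_tokens=16, logprobs=2)

for output in llm.generate(["The capital of France is"], params_top2):
    print(output.outputs[0].logprobs)  # list of {token_id: Logprob} per step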
@@ -501,7 +502,7 @@ def _get_output_config(self, is_pytorch_backend: bool = False) -> tllme.OutputCo
config_kwargs = {f: getattr(self, f) for f in fields}

if is_pytorch_backend:
config_kwargs["return_log_probs"] = bool(self.logprobs)
config_kwargs["return_log_probs"] = self.logprobs is not None
if self.prompt_logprobs and not self.return_context_logits:
logger.info(
"Since prompt_logprobs is requested but return_context_logits is False, "
2 changes: 1 addition & 1 deletion tests/integration/test_lists/test-db/l0_a30.yml
@@ -22,7 +22,7 @@ l0_a30:
- unittest/_torch/modeling -k "modeling_starcoder2"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/sampler/test_beam_search.py
- unittest/_torch/sampler/test_return_logits.py
- unittest/_torch/sampler/test_logits_logprobs.py
- test_e2e.py::test_openai_completions_with_logit_bias[torch_sampler]
- test_e2e.py::test_openai_chat_with_logit_bias[torch_sampler]
- test_e2e.py::test_openai_completions_with_logit_bias[trtllm_sampler]