[None][fix] Change PyT to always include sampled logprob #9374
```diff
@@ -986,18 +986,94 @@ def handle_logprobs(
             topk_log_probs_indices = self.store.new_tokens[0, request.py_seq_slot].view(
                 beam_width, count, -1
             )
+            token_log_probs = self._convert_logprobs_tensor_to_list(
+                topk_log_probs_indices, topk_log_probs_vals
+            )
         else:
             assert beam_width == 1, "beam width must be 1 for non-beam search"
-            topk_log_probs_vals = request.py_topk_logprobs_vals[: count * beam_width].view(
-                beam_width, count, -1
-            )
-            topk_log_probs_indices = request.py_topk_logprobs_indices[
-                : count * beam_width
-            ].view(beam_width, count, -1)
-
-            token_log_probs = self._convert_logprobs_tensor_to_list(
-                topk_log_probs_indices, topk_log_probs_vals
-            )
+            sampled_tokens = request.get_tokens(0)[-count:]
+
+            if request.py_num_logprobs == 0:
+                # Return only the sampled token's logprob
+                # Compute at least top-1 to determine rank
+                if (
+                    hasattr(request, "py_sampled_logprobs")
+                    and request.py_sampled_logprobs is not None
+                ):
+                    sampled_logprobs = request.py_sampled_logprobs[:count]
+                    topk_log_probs_vals = request.py_topk_logprobs_vals[:count]  # At least k=1
+                    topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                    token_log_probs = []
+                    for step, (
+                        sampled_token,
+                        sampled_logprob,
+                        topk_tokens,
+                        topk_logprobs,
+                    ) in enumerate(
+                        zip(
+                            sampled_tokens,
+                            sampled_logprobs,
+                            topk_log_probs_indices,
+                            topk_log_probs_vals,
+                        )
+                    ):
+                        topk_tokens_list = topk_tokens.tolist()
+                        if sampled_token in topk_tokens_list:
+                            # Sampled token is in top-K, use its rank
+                            rank = topk_tokens_list.index(sampled_token) + 1
+                        else:
+                            # TODO: fix rank
+                            rank = 2
+
+                        step_dict = {
+                            sampled_token: Logprob(logprob=sampled_logprob.item(), rank=rank)
+                        }
+                        token_log_probs.append(step_dict)
+                else:
+                    raise ValueError(
+                        "py_sampled_logprobs not available when py_num_logprobs == 0"
+                    )
+            else:
+                # Return top-K logprobs + logprob of sampled token
+                sampled_logprobs = request.py_sampled_logprobs[:count]
+                topk_log_probs_vals = request.py_topk_logprobs_vals[:count]
+                topk_log_probs_indices = request.py_topk_logprobs_indices[:count]
+
+                token_log_probs = []
+                for step, (
+                    sampled_token,
+                    sampled_logprob,
+                    topk_tokens,
+                    topk_logprobs,
+                ) in enumerate(
+                    zip(
+                        sampled_tokens,
+                        sampled_logprobs,
+                        topk_log_probs_indices,
+                        topk_log_probs_vals,
+                    )
+                ):
+                    step_dict = {}
+                    topk_tokens_list = topk_tokens.tolist()
+                    topk_logprobs_list = topk_logprobs.tolist()
```
Collaborator: I think you could merge the case …
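Sketching that merge suggestion (hypothetical, reusing the names in scope in `handle_logprobs`, not the PR's code): with the top-k buffers guaranteed to hold at least `k=1`, the `py_num_logprobs == 0` case reduces to "emit only the sampled token" and the two loops could collapse into one:

```python
# Hypothetical merged loop (sketch only): the py_num_logprobs == 0 case falls
# out naturally because topk_tokens_list[:0] is empty.
token_log_probs = []
for sampled_token, sampled_logprob, topk_tokens, topk_logprobs in zip(
    sampled_tokens, sampled_logprobs, topk_log_probs_indices, topk_log_probs_vals
):
    topk_tokens_list = topk_tokens.tolist()
    # Emit top-K entries only when the user asked for them.
    step_dict = {
        token: Logprob(logprob=logprob, rank=rank)
        for rank, (token, logprob) in enumerate(
            zip(topk_tokens_list[: request.py_num_logprobs], topk_logprobs.tolist()),
            start=1,
        )
    }
    if sampled_token not in step_dict:
        # Exact rank if the sampled token is within the computed top-k;
        # otherwise only a lower bound is known (same TODO as in the PR).
        rank = (
            topk_tokens_list.index(sampled_token) + 1
            if sampled_token in topk_tokens_list
            else len(topk_tokens_list) + 1
        )
        step_dict[sampled_token] = Logprob(logprob=sampled_logprob.item(), rank=rank)
    token_log_probs.append(step_dict)
```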
```diff
+                    for rank_idx, (token, logprob) in enumerate(
+                        zip(topk_tokens_list, topk_logprobs_list), start=1
+                    ):
+                        step_dict[token] = Logprob(logprob=logprob, rank=rank_idx)
+
+                    if sampled_token not in step_dict:
+                        # TODO: fix rank
+                        step_dict[sampled_token] = Logprob(
+                            logprob=sampled_logprob.item(), rank=len(topk_tokens_list) + 1
+                        )
+                    token_log_probs.append(step_dict)
+
+            # Wrap in list for non-beam search (beam_width=1)
+            token_log_probs = [token_log_probs]

         request.py_result.append_log_probs(token_log_probs)

     def finish_if_reason(
```
```diff
@@ -2461,47 +2537,55 @@ def _process_requests(
             assert logits_cuda.dim() == 2, "logits should be 2D"

             logprobs_req_indices = [
-                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
+                req_id for req_id, req in enumerate(requests) if req.py_num_logprobs is not None
             ]
-            logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
-            logprobs_logit_indices_cuda = logprobs_logit_indices.to(
-                device=logits_cuda.device, non_blocking=True
-            )
-            logprobs_cuda = F.log_softmax(
-                logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
-                dim=-1,
-            )
-            topk_vals_cuda, topk_indices_cuda = torch.topk(
-                logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
-            )
-            # Use a single D2H copy to reduce overheads
-            topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
-            topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
-            topk_vals.copy_(topk_vals_cuda, non_blocking=True)
-            topk_indices.copy_(topk_indices_cuda, non_blocking=True)
-            current_offset = 0
-            for req_id, steps in zip(
-                logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
-            ):
-                req = requests[req_id]
-                next_offset = current_offset + steps
-                # NB: Assigning views on memory which is being filled asynchronously
-                req.py_topk_logprobs_vals = topk_vals[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-                req.py_topk_logprobs_indices = topk_indices[
-                    current_offset:next_offset, : req.py_num_logprobs
-                ]
-
-                # context requests do not have multiple input beams, but they need multiple output beams
-                if req.is_context_init_state:
-                    req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                    req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
-                        req.sampling_config.beam_width, -1
-                    )
-                current_offset = next_offset
+            if logprobs_req_indices:
+                logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices]
+                logprobs_logit_indices_cuda = logprobs_logit_indices.to(
+                    device=logits_cuda.device, non_blocking=True
+                )
+                logprobs_cuda = F.log_softmax(
+                    logits_cuda[logprobs_logit_indices_cuda].to(
+                        dtype=torch.float32, non_blocking=True
+                    ),
+                    dim=-1,
+                )
+
+                max_k = max(
+                    max(1, req.py_num_logprobs)
+                    for req in requests
+                    if req.py_num_logprobs is not None
+                )
+                topk_vals_cuda, topk_indices_cuda = torch.topk(logprobs_cuda, k=max_k, dim=-1)
+                # Use a single D2H copy to reduce overheads
+                topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
+                topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True)
+                topk_vals.copy_(topk_vals_cuda, non_blocking=True)
+                topk_indices.copy_(topk_indices_cuda, non_blocking=True)
+                current_offset = 0
+                for req_id, steps in zip(
+                    logprobs_req_indices, req_num_generated_tokens[logprobs_req_indices].tolist()
+                ):
+                    req = requests[req_id]
+                    next_offset = current_offset + steps
+                    # Store at least k=1 for all requests (including logprobs=0) to compute ranks
```
Collaborator: Could you elaborate why it is necessary to use k>=1 for rank computation? I would have expected that it should work with 0 as well.
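One possible reading of the `# At least k=1` choice, illustrated standalone (toy values, not the PR's tensors): with `k=0` there would be no top-k buffer at all, so `handle_logprobs` would have nothing to compare the sampled token against when assigning a rank, whereas `k=1` at least distinguishes rank 1 from rank > 1.

```python
import torch

# Toy illustration with made-up logits.
logprobs = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5]]), dim=-1)
top1_vals, top1_indices = torch.topk(logprobs, k=1, dim=-1)

sampled_token = 0
# With k=1 we can tell whether the sampled token is the argmax (rank 1);
# k=0 would leave no reference point short of keeping the full logprob row.
is_rank_one = sampled_token == top1_indices[0, 0].item()
print(is_rank_one)  # True: the sampled token is the top-1 token
```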
```diff
+                    k_for_req = max(1, req.py_num_logprobs)
+                    # NB: Assigning views on memory which is being filled asynchronously
+                    req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, :k_for_req]
+                    req.py_topk_logprobs_indices = topk_indices[
+                        current_offset:next_offset, :k_for_req
+                    ]
+
+                    # context requests do not have multiple input beams, but they need multiple output beams
```
Collaborator: I think the beam search part is obsolete and may be removed.
```diff
+                    if req.is_context_init_state:
+                        req.py_topk_logprobs_vals = req.py_topk_logprobs_vals.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                        req.py_topk_logprobs_indices = req.py_topk_logprobs_indices.expand(
+                            req.sampling_config.beam_width, -1
+                        )
+                    current_offset = next_offset

         # Perform sampling in batches
         batched_sampling_result = self._sample_batched_by_strategy(
```
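As an aside on the `# Use a single D2H copy` comment kept in the new block: staging all top-k results in one pinned host buffer and issuing one asynchronous copy amortizes the transfer overhead across requests, and per-request views then alias that buffer. A standalone sketch of the same pattern (shapes and values are illustrative, not the PR's):

```python
import torch

if torch.cuda.is_available():
    topk_vals_cuda = torch.randn(32, 8, device="cuda")
    # One pinned host buffer and one non-blocking copy instead of a copy per request.
    topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True)
    topk_vals.copy_(topk_vals_cuda, non_blocking=True)
    # Per-request results are views into this buffer while the copy may still be
    # in flight, so consumers must synchronize before reading.
    per_request_view = topk_vals[0:4, :2]
    torch.cuda.current_stream().synchronize()
    print(per_request_view)
```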
```diff
@@ -2517,6 +2601,59 @@ def _process_requests(
             token_dtype=new_tokens_cuda.dtype,
         )

+        if return_log_probs and logprobs_req_indices:
+            sampled_tokens_cuda = batched_sampling_result.batch_next_tokens_cuda_int
+            batch_req_indices = batched_sampling_result.batch_req_indices
+            logprobs_req_set = set(logprobs_req_indices)
+            sampled_logprobs_list = []
+
+            # Build offsets for the GROUPED order
+            grouped_num_steps = req_num_steps[batch_req_indices]
+            grouped_offsets = torch.cat(
+                [
+                    torch.zeros((1,), dtype=torch.int32, pin_memory=True),
+                    grouped_num_steps.cumsum(dim=0)[:-1],
+                ]
+            )
+
+            # Reverse mapping: original_req_id → position in grouped result
+            req_to_grouped_pos = {
+                orig_id.item(): grouped_pos
+                for grouped_pos, orig_id in enumerate(batch_req_indices)
+            }
+
+            for req_id in range(len(requests)):
+                if req_id in logprobs_req_set:
+                    logprobs_idx = logprobs_req_indices.index(req_id)
+
+                    if logprobs_idx == 0:
+                        start_offset = 0
+                    else:
+                        start_offset = sum(
+                            req_num_steps[logprobs_req_indices[:logprobs_idx]].tolist()
```
Collaborator: I believe using …
```diff
+                        )
+
+                    num_steps_this_req = req_num_steps[req_id].item()
+                    end_offset = start_offset + num_steps_this_req
+
+                    grouped_pos = req_to_grouped_pos[req_id]
+                    grouped_start = grouped_offsets[grouped_pos].item()
+                    grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()
+
+                    sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]
+
+                    step_indices = torch.arange(
+                        start_offset, end_offset, device=logprobs_cuda.device
+                    )
+                    sampled_logprobs_cuda = logprobs_cuda[
+                        step_indices, sampled_tokens_this_req.long()
+                    ]
+
+                    sampled_logprobs_cpu = sampled_logprobs_cuda.to(device="cpu", non_blocking=True)
+                    sampled_logprobs_list.append((req_id, sampled_logprobs_cpu))
+
+            for req_id, sampled_logprobs in sampled_logprobs_list:
```
Collaborator: You might integrate the second loop into the first one and drop the …
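One way to act on that, sketched with the names in scope above (hypothetical restructuring, not the PR's code); an exclusive prefix sum also replaces the repeated `sum(...)` over prefixes:

```python
# Hypothetical single-pass variant (sketch). Offsets into logprobs_cuda for the
# logprobs requests, computed once as an exclusive prefix sum.
logprobs_num_steps = req_num_steps[logprobs_req_indices]
logprobs_offsets = torch.cumsum(logprobs_num_steps, dim=0) - logprobs_num_steps

for logprobs_idx, req_id in enumerate(logprobs_req_indices):
    start_offset = logprobs_offsets[logprobs_idx].item()
    end_offset = start_offset + logprobs_num_steps[logprobs_idx].item()

    grouped_pos = req_to_grouped_pos[req_id]
    grouped_start = grouped_offsets[grouped_pos].item()
    grouped_end = grouped_start + grouped_num_steps[grouped_pos].item()
    sampled_tokens_this_req = sampled_tokens_cuda[grouped_start:grouped_end]

    step_indices = torch.arange(start_offset, end_offset, device=logprobs_cuda.device)
    sampled_logprobs_cuda = logprobs_cuda[step_indices, sampled_tokens_this_req.long()]
    # Assign directly; no second loop or sampled_logprobs_list needed.
    requests[req_id].py_sampled_logprobs = sampled_logprobs_cuda.to(
        device="cpu", non_blocking=True
    )
```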
```diff
+                requests[req_id].py_sampled_logprobs = sampled_logprobs

         # Fill results into output buffers
         new_tokens_host = self._unbatch_sampling_results(
             batched_sampling_result,
```
The rank calculation could be done by calculating the rank in `_process_requests`, where you have access to all logprobs. You can then pass it forward as `request.py_sampled_rank`, similar to how you pass `request.py_sampled_logprobs`.
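A rough sketch of that suggestion, reusing the names in scope in `_process_requests` above; `py_sampled_rank` is the attribute proposed in the comment, not an existing field:

```python
# Hypothetical rank computation on GPU (sketch). The rank of each sampled
# token is 1 + the number of vocabulary entries with strictly higher logprob.
sampled_ranks_cuda = (
    logprobs_cuda[step_indices] > sampled_logprobs_cuda.unsqueeze(-1)
).sum(dim=-1) + 1
# Forwarded like py_sampled_logprobs, for handle_logprobs to consume.
requests[req_id].py_sampled_rank = sampled_ranks_cuda.to(device="cpu", non_blocking=True)
```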