diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py
index 3475963bb01..4acbe1b6335 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampler.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -832,6 +832,12 @@ def is_generation_model(self) -> bool:
     class Store:
         new_tokens: torch.Tensor
         """Shape: See cpp DecoderState.getAllNewTokens()"""
+        max_lengths_tensor: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the maximum sequence length for each request"""
+        end_ids: torch.Tensor
+        """Shape: batch_size
+        Usage: Stores the end id for each request"""
         finish_reasons: torch.Tensor
         """Shape: max_tokens, batch_size, beam_width
         Usage: Stores the currently estimated finish_reasons for each request"""
@@ -881,10 +887,14 @@ def _create_store(self) -> Store:
                 first_finish_reasons=int_tensor(
                     self.CACHE_INDIRECTION_SHAPE[:-1],
                 ),
+                max_lengths_tensor=int_tensor(self.max_num_sequences),
+                end_ids=int_tensor(self.max_num_sequences),
             )
         else:
             return self.Store(
                 new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
+                max_lengths_tensor=int_tensor(self.max_num_sequences),
+                end_ids=int_tensor(self.max_num_sequences),
                 finish_reasons=int_tensor(self.NEW_TOKENS_SHAPE),
             )
@@ -1316,14 +1326,33 @@ def _process_draft_tokens_tree(
         return num_accepted_draft_tokens - 1
 
-    def setup_sampler_step(self, requests: ScheduledRequests):
+    def _is_new_request(self, request: LlmRequest) -> bool:
+        return (
+            not request.is_finished
+            and not request.py_is_draft
+            and (
+                (request.is_context_init_state and request.is_last_context_chunk)
+                or request.is_disagg_generation_transmission_complete
+            )
+        )
+
+    @override
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
         """Setup the sampler step for the requests
 
         Args:
-            requests: list[LlmRequest]. The requests to setup the sampler step for
+            scheduled_requests: The scheduled requests to set up the sampler step for
         """
         if self._use_beam_search:
-            self._prepare_beam_search(requests.all_requests())
+            self._prepare_beam_search(scheduled_requests.all_requests())
+        for request in scheduled_requests.all_requests():
+            if self._is_new_request(request):
+                self.store.max_lengths_tensor[request.py_seq_slot].fill_(
+                    min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
+                )
+                self.store.end_ids[request.py_seq_slot].fill_(
+                    request.py_end_id if request.py_end_id is not None else -1
+                )
 
     def _prepare_beam_search(
         self,
@@ -1335,10 +1364,7 @@ def _prepare_beam_search(
         initialize/reset the buffers for the request
         """
         for request in requests:
-            if not request.is_finished and (
-                (request.is_context_init_state and request.is_last_context_chunk)
-                or request.is_disagg_generation_transmission_complete
-            ):
+            if self._is_new_request(request):
                 if request.py_return_log_probs and request.py_num_logprobs > 1:
                     raise ValueError("Beam search does not support multiple logprobs")
                 self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
@@ -1848,13 +1874,9 @@ def sample_async(
             dtype=torch.int64,  # for index_fill_
             pin_memory=True,
         )
-        # necessary for beam search
-        seq_lens_host = (
-            torch.tensor(
-                [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
-            )
-            if self._use_beam_search
-            else None
+        # necessary for beam search and max_length checks
+        seq_lens_host = torch.tensor(
+            [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
         )
         new_tokens_host = self._process_requests(
             scheduled_requests,
@@ -1867,12 +1889,14 @@ def sample_async(
         finish_reasons = self.store.finish_reasons
         seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
+        seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
         first_finish_reasons = self.store.first_finish_reasons if self._use_beam_search else None
         self._write_finish_reasons(
             requests,
             finish_reasons=finish_reasons,
             seq_slots=seq_slots,
+            seq_lens=seq_lens,
             new_tokens=new_tokens,
             first_finish_reasons=first_finish_reasons,
             predecessor_beams=self.store.predecessor_beams,
@@ -2443,6 +2467,7 @@ def _write_finish_reasons(
         *,
         finish_reasons: torch.Tensor,
         seq_slots: torch.Tensor,
+        seq_lens: torch.Tensor,
         new_tokens: torch.Tensor,
         first_finish_reasons: torch.Tensor | None = None,
         predecessor_beams: torch.Tensor | None = None,
@@ -2458,7 +2483,11 @@
             new_tokens: a buffer containing the newly generated tokens.
                 Shape: (max_tokens, max_batch_size, max_beam_width)
         """
-        tokens = new_tokens[:, seq_slots.to(device=new_tokens.device, non_blocking=True)]
+
+        # seq_slots and seq_lens should be on the same device as new_tokens
+        assert seq_slots.device == new_tokens.device
+        assert seq_lens.device == new_tokens.device
+        tokens = new_tokens[:, seq_slots]
 
         # we need to fill with NOT_FINISHED so we can differentiate between previous requests that had the same seq slot
         finish_reasons.index_fill_(1, seq_slots, FinishReason.NOT_FINISHED.value)
@@ -2484,12 +2513,12 @@
         )
         batched_finish_reasons = torch.where(
-            self._are_max_length(requests),
+            self._are_max_length(seq_lens, self.store.max_lengths_tensor[seq_slots]),
             self._reason_tensors[FinishReason.LENGTH],
             batched_finish_reasons,
         )
         batched_finish_reasons = torch.where(
-            self._are_end_id(requests, tokens),
+            self._are_end_id(self.store.end_ids[seq_slots], tokens),
             self._reason_tensors[FinishReason.END_ID],
             batched_finish_reasons,
         )
@@ -2505,57 +2534,29 @@
         )
         first_finish_reasons[seq_slots] = batched_first_finish_reasons
 
-    def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
-        end_ids_tensor = (
-            torch.tensor(
-                [
-                    ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
-                    for req in requests
-                ]
-                * self.max_tokens,
-                pin_memory=True,
-                dtype=tokens.dtype,
-            )
-            .view(self.max_tokens, len(requests), self.max_beam_width)
-            .to(device="cuda", non_blocking=True)
-        )
-        return tokens == end_ids_tensor
+    def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+        return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
 
-    def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+    def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
         """Checks which sequences are at or beyond the max length
 
         Args:
-            requests: the requests to check the max length of
-
+            seq_lens: the sequence lengths of the requests to check the max length of
+            max_seq_lens: the maximum sequence lengths of the requests to check the max length of
        Returns:
             A tensor of shape (max_tokens, len(requests), max_beam_width) where each element is True
             if the sequence is at or beyond the max length, False otherwise
         """
-        lengths_tensor = torch.tensor(
-            [
-                [
-                    [
-                        (req.get_num_tokens(beam_idx) + num_tokens)
-                        for beam_idx in range(self.max_beam_width)
-                    ]
-                    for req in requests
-                ]
-                for num_tokens in range(1, self.max_tokens + 1)
-            ]
-        )
-        max_lengths_tensor = torch.tensor(
-            [
-                (
-                    [min(req.py_max_new_tokens + req.orig_prompt_len, self.max_seq_len)]
-                    * self.max_beam_width
-                )
-                for req in requests
-            ]
-            * self.max_tokens
-        ).view(self.max_tokens, len(requests), self.max_beam_width)
-        return (
-            (lengths_tensor >= max_lengths_tensor).pin_memory().to(device="cuda", non_blocking=True)
+        lengths_tensor = (
+            seq_lens.view(1, -1, 1)
+            + torch.arange(
+                1, self.max_tokens + 1, device=seq_lens.device, dtype=seq_lens.dtype
+            ).view(-1, 1, 1)
+        ).expand(self.max_tokens, -1, self.max_beam_width)
+        max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
+            self.max_tokens, -1, self.max_beam_width
        )
+        return lengths_tensor >= max_lengths_tensor
 
     _PAD_ID = -1
     """Pad with negative, doesn't matter what"""
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 8d037275a04..64a9cc4dd16 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -222,11 +222,18 @@ class Store(TorchSampler.Store):
         next_draft_tokens: torch.Tensor
         new_tokens_lens: torch.Tensor
         max_total_draft_tokens: torch.Tensor
-        finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+        # Necessary to satisfy the interface of TorchSampler.Store
+        finish_reasons: None = None
+        end_ids: None = None
+        max_lengths_tensor: None = None
 
         def __post_init__(self):
             pass  # finish_reasons has no size to compare against new_tokens in MTPSampler
 
+    def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
+        # MTPSampler does not need to set up additional buffers before the sampler step
+        pass
+
     def __init__(self, args: TorchSampler.Args, *, nextn: int):
         self.mapping = None
         self.draft_len = nextn
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index db518e60859..08f61f56d91 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -691,16 +691,40 @@ def setup(requests: list["RequestCase"]):
         seq_slots = torch.tensor(
             [req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
         )
+        seq_lens = torch.tensor(
+            [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+        )
         new_tokens = torch.tensor(
             [req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
         ).T
         sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+        max_seq_lens = torch.tensor(
+            [
+                min(
+                    sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+                )
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        end_ids = torch.tensor(
+            [
+                req.request.py_end_id if req.request.py_end_id is not None else -1
+                for req in requests
+            ],
+            dtype=torch.int32,
+            device="cuda",
+        )
+        sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+        sampler.store.end_ids[seq_slots] = end_ids
 
         def run():
             sampler._write_finish_reasons(
                 [req.request for req in requests],
                 finish_reasons=sampler.store.finish_reasons,
                 new_tokens=sampler.store.new_tokens,
+                seq_lens=seq_lens,
                 seq_slots=seq_slots,
             )
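
The core of this change is replacing per-step pinned host tensors built from Python list comprehensions with persistent Store buffers (max_lengths_tensor, end_ids) that are filled once per request in setup_sampler_step, plus broadcasted tensor comparisons in _write_finish_reasons. The following standalone sketch is not part of the patch; the sizes and values are invented purely to illustrate the broadcasting pattern the new _are_end_id and _are_max_length helpers rely on:

# Standalone illustration (not part of the patch): the broadcasting pattern behind
# the new _are_end_id / _are_max_length helpers. Sizes and values are invented.
import torch

max_tokens, max_beam_width = 3, 1

# Per-request state that the patch keeps in persistent Store buffers (written once
# per request in setup_sampler_step instead of being rebuilt on every step).
seq_lens = torch.tensor([5, 7], dtype=torch.int32)       # current sequence length per request
max_seq_lens = torch.tensor([6, 16], dtype=torch.int32)  # min(max_seq_len, prompt_len + max_new_tokens)
end_ids = torch.tensor([2, -1], dtype=torch.int32)       # -1 means "no end id"

# Newly generated tokens, shape (max_tokens, batch, beam).
tokens = torch.tensor([[[9], [2]], [[2], [9]], [[9], [9]]], dtype=torch.int32)

# END_ID check: broadcast the per-request end id over the token and beam dimensions.
is_end_id = tokens == end_ids.view(1, -1, 1).expand(max_tokens, -1, max_beam_width)

# LENGTH check: after the k-th new token (k = 1..max_tokens) a sequence holds seq_len + k tokens.
lengths = (
    seq_lens.view(1, -1, 1)
    + torch.arange(1, max_tokens + 1, dtype=seq_lens.dtype).view(-1, 1, 1)
).expand(max_tokens, -1, max_beam_width)
is_max_length = lengths >= max_seq_lens.view(1, -1, 1).expand(max_tokens, -1, max_beam_width)

print(is_end_id.squeeze(-1))      # [[False, False], [True, False], [False, False]]
print(is_max_length.squeeze(-1))  # [[True, False], [True, False], [True, False]]

Because these comparisons operate entirely on device-resident buffers indexed by seq_slots, the per-step finish-reason computation no longer needs to materialize and copy host tensors for every batch.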