117 changes: 59 additions & 58 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -832,6 +832,12 @@ def is_generation_model(self) -> bool:
class Store:
new_tokens: torch.Tensor
"""Shape: See cpp DecoderState.getAllNewTokens()"""
+ max_lengths_tensor: torch.Tensor
+ """Shape: batch_size
+ Usage: Stores the maximum lengths for each request"""
+ end_ids: torch.Tensor
+ """Shape: batch_size
+ Usage: Stores the end ids for each request"""
finish_reasons: torch.Tensor
"""Shape: max_tokens, batch_size, beam_width
Usage: Stores the currently estimated finish_reasons for each request"""
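Not part of the diff: a minimal sketch, with made-up sizes and CPU tensors, of how per-slot buffers like max_lengths_tensor and end_ids are written once per request and later gathered by slot index (the real buffers are allocated by int_tensor and live on the GPU).

import torch

# Hypothetical capacity; the real buffers are sized by max_num_sequences.
max_num_sequences = 8
end_ids = torch.zeros(max_num_sequences, dtype=torch.int32)
max_lengths_tensor = torch.zeros(max_num_sequences, dtype=torch.int32)

# Written once when a request is assigned a sequence slot ...
slot = 3
end_ids[slot].fill_(2)
max_lengths_tensor[slot].fill_(128)

# ... then gathered batch-wise by slot index, with no per-step host tensor build.
seq_slots = torch.tensor([3, 5])
print(end_ids[seq_slots])             # tensor([2, 0], dtype=torch.int32)
print(max_lengths_tensor[seq_slots])  # tensor([128, 0], dtype=torch.int32)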
@@ -881,10 +887,14 @@ def _create_store(self) -> Store:
first_finish_reasons=int_tensor(
self.CACHE_INDIRECTION_SHAPE[:-1],
),
+ max_lengths_tensor=int_tensor(self.max_num_sequences),
+ end_ids=int_tensor(self.max_num_sequences),
)
else:
return self.Store(
new_tokens=int_tensor(self.NEW_TOKENS_SHAPE),
+ max_lengths_tensor=int_tensor(self.max_num_sequences),
+ end_ids=int_tensor(self.max_num_sequences),
finish_reasons=int_tensor(self.NEW_TOKENS_SHAPE),
)

@@ -1316,14 +1326,33 @@ def _process_draft_tokens_tree(

return num_accepted_draft_tokens - 1

- def setup_sampler_step(self, requests: ScheduledRequests):
+ def _is_new_request(self, request: LlmRequest) -> bool:
+ return (
+ not request.is_finished
+ and not request.py_is_draft
+ and (
+ (request.is_context_init_state and request.is_last_context_chunk)
+ or request.is_disagg_generation_transmission_complete
+ )
+ )
+
+ @override
+ def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
"""Setup the sampler step for the requests

Args:
requests: list[LlmRequest]. The requests to setup the sampler step for
"""
if self._use_beam_search:
- self._prepare_beam_search(requests.all_requests())
+ self._prepare_beam_search(scheduled_requests.all_requests())
+ for request in scheduled_requests.all_requests():
+ if self._is_new_request(request):
+ self.store.max_lengths_tensor[request.py_seq_slot].fill_(
+ min(self.max_seq_len, request.orig_prompt_len + request.py_max_new_tokens)
+ )
+ self.store.end_ids[request.py_seq_slot].fill_(
+ request.py_end_id if request.py_end_id is not None else -1
+ )
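For illustration only (the numbers are made up): the value written into max_lengths_tensor above is the per-request stopping point, clamped by the sampler's global max_seq_len.

# Hypothetical request: 100 prompt tokens, allowed up to 50 new tokens,
# on a sampler configured with max_seq_len = 128.
max_seq_len = 128
orig_prompt_len, py_max_new_tokens = 100, 50
print(min(max_seq_len, orig_prompt_len + py_max_new_tokens))  # 128, not 150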

def _prepare_beam_search(
self,
@@ -1335,10 +1364,7 @@ def _prepare_beam_search(
initialize/reset the buffers for the request
"""
for request in requests:
- if not request.is_finished and (
- (request.is_context_init_state and request.is_last_context_chunk)
- or request.is_disagg_generation_transmission_complete
- ):
+ if self._is_new_request(request):
if request.py_return_log_probs and request.py_num_logprobs > 1:
raise ValueError("Beam search does not support multiple logprobs")
self.store.cache_indirection[request.py_seq_slot, :, request.py_prompt_len].fill_(0)
@@ -1848,13 +1874,9 @@ def sample_async(
dtype=torch.int64, # for index_fill_
pin_memory=True,
)
- # necessary for beam search
- seq_lens_host = (
- torch.tensor(
- [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
- )
- if self._use_beam_search
- else None
+ # necessary for beam search and max_length checks
+ seq_lens_host = torch.tensor(
+ [r.max_beam_num_tokens for r in requests], dtype=torch.int32, pin_memory=True
)
new_tokens_host = self._process_requests(
scheduled_requests,
@@ -1867,12 +1889,14 @@ def sample_async(

finish_reasons = self.store.finish_reasons
seq_slots = seq_slots_host.to(device="cuda", non_blocking=True)
+ seq_lens = seq_lens_host.to(device="cuda", non_blocking=True)
first_finish_reasons = self.store.first_finish_reasons if self._use_beam_search else None

self._write_finish_reasons(
requests,
finish_reasons=finish_reasons,
seq_slots=seq_slots,
+ seq_lens=seq_lens,
new_tokens=new_tokens,
first_finish_reasons=first_finish_reasons,
predecessor_beams=self.store.predecessor_beams,
@@ -2443,6 +2467,7 @@ def _write_finish_reasons(
*,
finish_reasons: torch.Tensor,
seq_slots: torch.Tensor,
+ seq_lens: torch.Tensor,
new_tokens: torch.Tensor,
first_finish_reasons: torch.Tensor | None = None,
predecessor_beams: torch.Tensor | None = None,
@@ -2458,7 +2483,11 @@ def _write_finish_reasons(
new_tokens: a buffer containing the newly generated tokens.
Shape: (max_tokens, max_batch_size, max_beam_width)
"""
- tokens = new_tokens[:, seq_slots.to(device=new_tokens.device, non_blocking=True)]
+
+ # Seq Slots should be on the same device as new_tokens
+ assert seq_slots.device == new_tokens.device
+ assert seq_lens.device == new_tokens.device
+ tokens = new_tokens[:, seq_slots]

# we need to fill with NOT_FINISHED so we can differentiate between previous requests that had the same seq slot
finish_reasons.index_fill_(1, seq_slots, FinishReason.NOT_FINISHED.value)
@@ -2484,12 +2513,12 @@
)

batched_finish_reasons = torch.where(
- self._are_max_length(requests),
+ self._are_max_length(seq_lens, self.store.max_lengths_tensor[seq_slots]),
self._reason_tensors[FinishReason.LENGTH],
batched_finish_reasons,
)
batched_finish_reasons = torch.where(
- self._are_end_id(requests, tokens),
+ self._are_end_id(self.store.end_ids[seq_slots], tokens),
self._reason_tensors[FinishReason.END_ID],
batched_finish_reasons,
)
@@ -2505,57 +2534,29 @@
)
first_finish_reasons[seq_slots] = batched_first_finish_reasons

- def _are_end_id(self, requests: list[LlmRequest], tokens: torch.Tensor) -> torch.Tensor:
- end_ids_tensor = (
- torch.tensor(
- [
- ([req.py_end_id if req.py_end_id is not None else -1] * self.max_beam_width)
- for req in requests
- ]
- * self.max_tokens,
- pin_memory=True,
- dtype=tokens.dtype,
- )
- .view(self.max_tokens, len(requests), self.max_beam_width)
- .to(device="cuda", non_blocking=True)
- )
- return tokens == end_ids_tensor
+ def _are_end_id(self, end_ids: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
+ return tokens == end_ids.view(1, -1, 1).expand(self.max_tokens, -1, self.max_beam_width)
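Not part of the diff: a small sketch of the broadcast this comparison relies on, using made-up shapes (max_tokens=2, three requests, max_beam_width=2).

import torch

max_tokens, max_beam_width = 2, 2
# tokens has shape (max_tokens, batch_size, max_beam_width)
tokens = torch.tensor([[[5, 7], [2, 2], [9, 1]],
                       [[2, 5], [2, 8], [1, 1]]])
# one end id per request, as gathered from store.end_ids[seq_slots]
end_ids = torch.tensor([2, 2, 1])

# view(1, -1, 1) + expand lines each request's end id up against every
# candidate step and beam without building a new host-side tensor.
mask = tokens == end_ids.view(1, -1, 1).expand(max_tokens, -1, max_beam_width)
print(mask[:, 1])  # request 1 matched its end id on both beams at step 0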

- def _are_max_length(self, requests: list[LlmRequest]) -> torch.Tensor:
+ def _are_max_length(self, seq_lens: torch.Tensor, max_seq_lens: torch.Tensor) -> torch.Tensor:
"""Checks which sequences are at or beyond the max length

Args:
- requests: the requests to check the max length of
-
+ seq_lens: the sequence lengths of the requests to check the max length of
+ max_seq_lens: the maximum sequence lengths of the requests to check the max length of
Returns:
A tensor of shape (max_tokens, len(requests), max_beam_width)
where each element is True if the sequence is at or beyond the max length, False otherwise
"""
- lengths_tensor = torch.tensor(
- [
- [
- [
- (req.get_num_tokens(beam_idx) + num_tokens)
- for beam_idx in range(self.max_beam_width)
- ]
- for req in requests
- ]
- for num_tokens in range(1, self.max_tokens + 1)
- ]
- )
- max_lengths_tensor = torch.tensor(
- [
- (
- [min(req.py_max_new_tokens + req.orig_prompt_len, self.max_seq_len)]
- * self.max_beam_width
- )
- for req in requests
- ]
- * self.max_tokens
- ).view(self.max_tokens, len(requests), self.max_beam_width)
- return (
- (lengths_tensor >= max_lengths_tensor).pin_memory().to(device="cuda", non_blocking=True)
- )
+ lengths_tensor = (
+ seq_lens.view(1, -1, 1)
+ + torch.arange(
+ 1, self.max_tokens + 1, device=seq_lens.device, dtype=seq_lens.dtype
+ ).view(-1, 1, 1)
+ ).expand(self.max_tokens, -1, self.max_beam_width)
+ max_lengths_tensor = max_seq_lens.view(1, -1, 1).expand(
+ self.max_tokens, -1, self.max_beam_width
+ )
+ return lengths_tensor >= max_lengths_tensor
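Not part of the diff: the same length check on made-up numbers, showing what the arange/view/expand combination computes per future step.

import torch

max_tokens, max_beam_width = 3, 2
seq_lens = torch.tensor([10, 4], dtype=torch.int32)      # tokens generated so far
max_seq_lens = torch.tensor([12, 5], dtype=torch.int32)  # per-request caps

steps = torch.arange(1, max_tokens + 1, dtype=seq_lens.dtype).view(-1, 1, 1)
lengths = (seq_lens.view(1, -1, 1) + steps).expand(max_tokens, -1, max_beam_width)
caps = max_seq_lens.view(1, -1, 1).expand(max_tokens, -1, max_beam_width)

# Row i answers: after i+1 more tokens, which requests hit their cap?
print((lengths >= caps)[:, :, 0])
# tensor([[False,  True],
#         [ True,  True],
#         [ True,  True]])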

_PAD_ID = -1
"""Pad with negative, doesn't matter what"""
9 changes: 8 additions & 1 deletion tensorrt_llm/_torch/speculative/mtp.py
@@ -222,11 +222,18 @@ class Store(TorchSampler.Store):
next_draft_tokens: torch.Tensor
new_tokens_lens: torch.Tensor
max_total_draft_tokens: torch.Tensor
- finish_reasons: None = None  # Necessary to satisfy the interface of TorchSampler.Store
+ # Necessary to satisfy the interface of TorchSampler.Store
+ finish_reasons: None = None
+ end_ids: None = None
+ max_lengths_tensor: None = None

def __post_init__(self):
pass # finish_reasons has no size to compare against new_tokens in MTPSampler

+ def setup_sampler_step(self, scheduled_requests: ScheduledRequests):
+ # MTPSampler does not need to setup additional buffers before the sampler step
+ pass
+
def __init__(self, args: TorchSampler.Args, *, nextn: int):
self.mapping = None
self.draft_len = nextn
24 changes: 24 additions & 0 deletions tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -691,16 +691,40 @@ def setup(requests: list["RequestCase"]):
seq_slots = torch.tensor(
[req.request.py_seq_slot for req in requests], device="cuda", dtype=torch.int64
)
+ seq_lens = torch.tensor(
+ [req.request.max_beam_num_tokens for req in requests], dtype=torch.int32, device="cuda"
+ )
new_tokens = torch.tensor(
[req.new_tokens for req in requests], dtype=torch.int32, device="cuda"
).T
sampler.store.new_tokens[:, seq_slots, BEAM] = new_tokens
+ max_seq_lens = torch.tensor(
+ [
+ min(
+ sampler.max_seq_len, req.request.orig_prompt_len + req.request.py_max_new_tokens
+ )
+ for req in requests
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ end_ids = torch.tensor(
+ [
+ req.request.py_end_id if req.request.py_end_id is not None else -1
+ for req in requests
+ ],
+ dtype=torch.int32,
+ device="cuda",
+ )
+ sampler.store.max_lengths_tensor[seq_slots] = max_seq_lens
+ sampler.store.end_ids[seq_slots] = end_ids

def run():
sampler._write_finish_reasons(
[req.request for req in requests],
finish_reasons=sampler.store.finish_reasons,
new_tokens=sampler.store.new_tokens,
+ seq_lens=seq_lens,
seq_slots=seq_slots,
)
