22 changes: 14 additions & 8 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -17,7 +17,7 @@
from torch._ops import OpOverloadPacket
from torch.export import Dim
from torch.fx import Node

+from tensorrt_llm._utils import nvtx_range

@dataclass
class CacheConfig:
@@ -122,7 +122,7 @@ def __post_init__(self):
self.max_batch_size,
(total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
)
-self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
+self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
self.input_pos = torch.empty_like(self.seq_len)
Comment on lines +125 to 128

Copilot AI Jul 27, 2025


The tensor is created with `device=self.device`, but `self.device` is a property that depends on `self.input_pos.device`. At this point in `__post_init__`, `input_pos` hasn't been initialized yet, which could cause an `AttributeError`.

Suggested change
-self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
-self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
-self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
-self.input_pos = torch.empty_like(self.seq_len)
+self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
+self.input_pos = torch.empty_like(self.seq_len)
+self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
+self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)

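To see why the ordering matters, here is a minimal standalone sketch, assuming `device` is a property derived from `input_pos` (a hypothetical reconstruction for illustration, not the actual `SequenceInfo` implementation):

```python
from dataclasses import dataclass

import torch


@dataclass
class SeqInfoSketch:
    max_batch_size: int = 8
    max_num_tokens: int = 64

    @property
    def device(self) -> torch.device:
        # Assumed: the device is read off an already-initialized tensor.
        return self.input_pos.device

    def __post_init__(self):
        # Creating input_ids first would raise AttributeError, since
        # self.device dereferences self.input_pos before it exists.
        self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
        self.input_pos = torch.empty_like(self.seq_len)
        # Now self.device is safe to use.
        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
        self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
```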
@@ -336,7 +336,7 @@ def reset(self) -> None:
self.input_pos.zero_()

# set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+self.nest_sequences([[1]] * self.max_batch_size)

# reset cache information
self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
@@ -381,22 +381,24 @@ def set_generate_only_batch(self) -> None:
self.reset()
self.nest_sequences([[1]] * self.max_batch_size)

+@nvtx_range("ad_update_position_ids")
def _update_position_ids(self) -> None:
# set new position_ids as new tensor from input_pos and seq_len via torch.arange
position_ids_list = [
num
for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
for num in range(in_pos, in_pos + seq_len)
]
-self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+self.position_ids = torch.tensor(position_ids_list, dtype=torch.long, pin_memory=True).to(self.device)

# use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
if self.is_generate:
self.position_ids = self.position_ids.view(-1, 1)
else:
self.position_ids = self.position_ids.view(1, -1)

-def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+@nvtx_range("ad_nest_sequences")
+def nest_sequences(self, input_ids: Sequence[Sequence[int]], previous_batch_indices: List[int] = [], new_tokens: Optional[torch.Tensor] = None) -> None:
"""Create and store a flattened list of input_ids from the provided list of sequences.

This i/f will also update any relevant sequence information.
@@ -413,8 +415,10 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
for lst in input_ids
for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
]
-self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)

+self.input_ids = torch.tensor(ids_list, dtype=dtype, pin_memory=True).to(self.device)
+if new_tokens is not None:
+    self.input_ids[self.input_ids == -1] = new_tokens[0,previous_batch_indices,0]

Comment on lines +420 to +421

Copilot AI Jul 27, 2025


This indexing assumes new_tokens has at least 3 dimensions and previous_batch_indices is valid, but there's no bounds checking. If previous_batch_indices contains invalid indices or new_tokens has different dimensions, this will cause a runtime error.

Suggested change
-self.input_ids[self.input_ids == -1] = new_tokens[0,previous_batch_indices,0]
+# Validate new_tokens dimensions
+if new_tokens.dim() < 3:
+    raise ValueError(f"new_tokens must have at least 3 dimensions, but got {new_tokens.dim()}.")
+# Validate previous_batch_indices
+max_index = new_tokens.size(1) - 1
+if any(idx < 0 or idx > max_index for idx in previous_batch_indices):
+    raise IndexError(f"previous_batch_indices contains out-of-bounds indices for new_tokens' second dimension (valid range: 0 to {max_index}).")
+self.input_ids[self.input_ids == -1] = new_tokens[0, previous_batch_indices, 0]

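For readers unfamiliar with the pattern, here is a small self-contained demo of the placeholder scatter the new code performs; the `[beams, batch_slots, 1]` layout of `new_tokens` is an assumption inferred from how it is indexed above:

```python
import torch

# Flattened input_ids with -1 placeholders for overlap-scheduled generate requests.
input_ids = torch.tensor([5, 7, -1, -1], dtype=torch.int)

# Sampler output, assumed layout [beams, batch_slots, 1].
new_tokens = torch.tensor([[[11], [12], [13]]], dtype=torch.int)

# Batch slots of the requests whose placeholders we fill, in order of appearance.
previous_batch_indices = [0, 2]

# Boolean-mask scatter: the two True positions receive the gathered tokens in order.
input_ids[input_ids == -1] = new_tokens[0, previous_batch_indices, 0]
print(input_ids)  # tensor([ 5,  7, 11, 13], dtype=torch.int32)
```

Note that correctness hinges on the `-1` placeholders appearing in the same order as `previous_batch_indices`, which the request loop in `_prepare_inputs` below guarantees.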
# set derivative properties
self._sequence_lengths = seq_lens

@@ -431,6 +435,7 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
return list(torch.split(t_squeezed, self.sequence_lengths))

+@nvtx_range("ad_update_pos")
def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
"""Update the starting position for each sequence in the cache.

@@ -448,10 +453,11 @@ def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
# update position_ids
self._update_position_ids()

+@nvtx_range("ad_assign_cache_loc")
def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
"""Set the cache location and pages_per_seq tensors from page assignments."""
cache_loc_flat = torch.tensor(
-    [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+    [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int, pin_memory=True
)
self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)

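A side note on the recurring `pin_memory=True` / `non_blocking=True` pairing in this file: staging host tensors in page-locked memory is what allows the subsequent host-to-device copy to run asynchronously; with ordinary pageable memory, `non_blocking=True` silently degrades to a synchronous copy. A minimal sketch of the pattern (assumes a CUDA device is available):

```python
import torch

# Allocate the host tensor directly in pinned (page-locked) memory,
# avoiding an extra staging copy before the H2D transfer.
host = torch.tensor([1, 2, 3, 4], dtype=torch.int, pin_memory=True)
dev = torch.empty_like(host, device="cuda")
dev.copy_(host, non_blocking=True)  # enqueued on the current stream, returns immediately
torch.cuda.current_stream().synchronize()  # sync before reusing or inspecting `host`
```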
17 changes: 7 additions & 10 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -101,7 +101,7 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
page_size=attn_page_size,
max_num_tokens=max_num_tokens,
)

print(" in seq_info for device: ", torch.cuda.current_device())

Copilot AI Jul 27, 2025


Debug print statement should be removed before merging to production. This appears to be leftover debugging code.

Suggested change
print(" in seq_info for device: ", torch.cuda.current_device())
ad_logger.info(f"in seq_info for device: {torch.cuda.current_device()}")

# update device to contain the current default device if it's in cuda
device = torch.device(ad_config.device)
if device.type == "cuda" and device.index is None:
@@ -167,16 +167,12 @@ def _prepare_inputs(
context_requests = scheduled_requests.context_requests
gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens]

-# new_tokens is a tensor on the device, we need to convert it to a list of lists.
-# can we avoid this additional gpu->cpu transfer?
-new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None

# info to be extracted
input_ids: List[List[int]] = []
input_pos: List[int] = []
last_logit_only: List[bool] = []
page_assignments: List[List[int]] = []

+previous_batch_indices: List[int] = []
# look at context requests first
for request in context_requests:
# store input ids and pos of first token in sequence
@@ -190,11 +186,13 @@ def _prepare_inputs(
# TODO: we should also handle extend requests (for speculative decoding) here
for request in gen_requests:
# new_tokens are provided when the overlap scheduler is enabled.
-if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None:
+if new_tokens is None or request.is_dummy or request.py_batch_idx is None:
input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
input_pos.append(request.max_beam_num_tokens - 1)
else:
-input_ids.append([new_tokens_list[request.py_batch_idx]])
+# insert a dummy token to indicate the new tokens
+input_ids.append([-1])
+previous_batch_indices.append(request.py_batch_idx)
input_pos.append(request.max_beam_num_tokens)

request.py_batch_idx = request.seq_slot
@@ -207,10 +205,9 @@ def _prepare_inputs(
# get cache indices
cache_indices = kv_cache_manager.get_cache_indices(request)
page_assignments.append(cache_indices)

# update the sequence info object now
si = self.cache_seq_interface.info
-si.nest_sequences(input_ids)
+si.nest_sequences(input_ids, previous_batch_indices, new_tokens)
si.update_pos(input_pos, reset=True)
si.assign_cache_loc(page_assignments)
return last_logit_only
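Finally, the `@nvtx_range(...)` decorators added throughout tag these methods for profiling. The actual implementation of `tensorrt_llm._utils.nvtx_range` is not shown in this diff; a plausible sketch of such a decorator, built on PyTorch's NVTX bindings, might look like this:

```python
import functools

import torch


def nvtx_range(name: str):
    """Hypothetical sketch: bracket a call in an NVTX range so it shows up
    as a named span in Nsight Systems timelines."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(name)
            try:
                return fn(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()
        return wrapper
    return decorator


@nvtx_range("ad_example")
def example():
    return 42
```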