@@ -514,7 +514,12 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         return list(torch.split(t_squeezed, self.sequence_lengths))

     @nvtx_range("ad_update_pos")
-    def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
+    def update_pos(
+        self,
+        seq_len: Union[torch.Tensor, List[int], int],
+        reset: bool = False,
+        update_position_ids: bool = True,
+    ) -> None:
         """Update the starting position for each sequence in the cache.

         If ``reset=True``, ``input_pos`` will be reset to zero before updating.
@@ -528,8 +533,9 @@ def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool =
         else:
             self.input_pos_host[:bs] += seq_len.to(self.device)
Copilot AI Aug 1, 2025

This operation moves seq_len to the device before adding it to the host tensor, which defeats the purpose of keeping the calculation on the host. Consider converting seq_len to CPU first: self.input_pos_host[:bs] += seq_len.cpu()

Suggested change
-            self.input_pos_host[:bs] += seq_len.to(self.device)
+            self.input_pos_host[:bs] += seq_len.cpu()
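As a standalone illustration of the point above (not part of the PR; the shapes and values are made up, only the names mirror the diff), keeping the accumulation on host and doing a single async copy to device looks roughly like this:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
bs = 4
input_pos_host = torch.zeros(bs, dtype=torch.int64)            # host-side bookkeeping
input_pos = torch.zeros(bs, dtype=torch.int64, device=device)  # device-side copy

seq_len = torch.tensor([3, 5, 2, 7], device=device)

# .cpu() keeps the addition a host-side op instead of promoting it to a device op
input_pos_host[:bs] += seq_len.cpu()

# one explicit async host->device transfer, as in the surrounding code
input_pos[:bs].copy_(input_pos_host[:bs], non_blocking=True)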

-        # update position_ids
-        self._update_position_ids()
+        # In ad_executor context, this is done later in nest_sequences, so no need to do it here
+        if update_position_ids:
+            self._update_position_ids()
         self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True)

     @nvtx_range("ad_assign_cache_loc")
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (3 changes: 2 additions & 1 deletion)
@@ -209,7 +209,8 @@ def _prepare_inputs(

         # update the sequence info object now
         si = self.cache_seq_interface.info
-        si.update_pos(input_pos, reset=True)
+        # skip calling _update_position_ids() here, as it will be called in nest_sequences
+        si.update_pos(input_pos, reset=True, update_position_ids=False)
Author

maybe it's better to not call update_pos here at all and introduce a different method that does what update_pos(update_position_ids=False) does? As-is, it is a bit confusing to call update_pos without updating positions.
@galagam

Updating the position ids requires both the input positions and the sequence lengths, so it makes sense to update them whenever either is updated, but doing so is a bit wasteful.
A possible alternative would be to require the user to call it explicitly. That is:

si.update_input_pos()  # renamed update_pos
si.nest_sequences()
si.update_position_ids()

In any case, due to my recent changes, the runtime of update_position_ids decreased by 30x, so it's not as critical to add this specific optimization as I initially believed. I'll run a more exhaustive check and consider keeping this optimization out of this PR for code simplicity.
@suyoggupta
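For context, here is a minimal sketch (a hypothetical helper, not the actual _update_position_ids implementation) of why both the input positions and the sequence lengths are needed to build the flattened position ids:

from typing import List

import torch

def compute_position_ids(input_pos: List[int], sequence_lengths: List[int]) -> torch.Tensor:
    # each sequence contributes positions [start, start + length)
    per_seq = [
        torch.arange(start, start + length)
        for start, length in zip(input_pos, sequence_lengths)
    ]
    return torch.cat(per_seq)

print(compute_position_ids([0, 4], [3, 2]))  # tensor([0, 1, 2, 4, 5])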

         si.assign_cache_loc(page_assignments)
         si.nest_sequences(input_ids)
