diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
index 13c91652bff..f25a6e53d33 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
@@ -17,7 +17,7 @@
 from torch._ops import OpOverloadPacket
 from torch.export import Dim
 from torch.fx import Node
-
+from tensorrt_llm._utils import nvtx_range
 
 @dataclass
 class CacheConfig:
@@ -122,7 +122,7 @@ def __post_init__(self):
             self.max_batch_size,
             (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
         )
-        self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
+        self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device)
         self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
         self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
         self.input_pos = torch.empty_like(self.seq_len)
@@ -336,7 +336,7 @@ def reset(self) -> None:
         self.input_pos.zero_()
 
         # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids)
-        self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int))
+        self.nest_sequences([[1]] * self.max_batch_size)
 
         # reset cache information
         self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device)
@@ -381,6 +381,7 @@ def set_generate_only_batch(self) -> None:
         self.reset()
         self.nest_sequences([[1]] * self.max_batch_size)
 
+    @nvtx_range("ad_update_position_ids")
     def _update_position_ids(self) -> None:
         # set new position_ids as new tensor from input_pos and seq_len via torch.arange
         position_ids_list = [
@@ -388,7 +389,7 @@ def _update_position_ids(self) -> None:
             for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths)
             for num in range(in_pos, in_pos + seq_len)
         ]
-        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device)
+        self.position_ids = torch.tensor(position_ids_list, dtype=torch.long, pin_memory=True).to(self.device)
 
         # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len]
         if self.is_generate:
@@ -396,7 +397,8 @@ def _update_position_ids(self) -> None:
         else:
             self.position_ids = self.position_ids.view(1, -1)
 
-    def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
+    @nvtx_range("ad_nest_sequences")
+    def nest_sequences(self, input_ids: Sequence[Sequence[int]], previous_batch_indices: List[int] = [], new_tokens: Optional[torch.Tensor] = None) -> None:
         """Create and store a flattened list of input_ids from the provided list of sequences.
 
         This i/f will also update any relevant sequence information.
@@ -413,8 +415,10 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None:
             for lst in input_ids
             for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst)
         ]
-        self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device)
-
+        self.input_ids = torch.tensor(ids_list, dtype=dtype, pin_memory=True).to(self.device)
+        if new_tokens is not None:
+            self.input_ids[self.input_ids == -1] = new_tokens[0, previous_batch_indices, 0]
+
         # set derivative properties
         self._sequence_lengths = seq_lens
 
@@ -431,6 +435,7 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]:
         t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0)
         return list(torch.split(t_squeezed, self.sequence_lengths))
 
+    @nvtx_range("ad_update_pos")
     def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None:
         """Update the starting position for each sequence in the cache.
 
@@ -448,10 +453,11 @@ def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool =
         # update position_ids
         self._update_position_ids()
 
+    @nvtx_range("ad_assign_cache_loc")
     def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None:
         """Set the cache location and pages_per_seq tensors from page assignments."""
         cache_loc_flat = torch.tensor(
-            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int
+            [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int, pin_memory=True
         )
         self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True)
 
diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
index 7f759d6796d..748163baeea 100644
--- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
+++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -101,7 +101,7 @@ def build_from_config(cls, ad_config: AutoDeployConfig):
             page_size=attn_page_size,
             max_num_tokens=max_num_tokens,
         )
-
+        print(" in seq_info for device: ", torch.cuda.current_device())
         # update device to contain the current default device if it's in cuda
         device = torch.device(ad_config.device)
         if device.type == "cuda" and device.index is None:
@@ -167,16 +167,12 @@ def _prepare_inputs(
         context_requests = scheduled_requests.context_requests
         gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens]
 
-        # new_tokens is a tensor on the device, we need to convert it to a list of lists.
-        # can we avoid this additional gpu->cpu transfer?
-        new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None
-
         # info to be extracted
         input_ids: List[List[int]] = []
         input_pos: List[int] = []
         last_logit_only: List[bool] = []
         page_assignments: List[List[int]] = []
-
+        previous_batch_indices: List[int] = []
         # look at context requests first
         for request in context_requests:
             # store input ids and pos of first token in sequence
@@ -190,11 +186,13 @@ def _prepare_inputs(
         # TODO: we should also handle extend requests (for speculative decoding) here
         for request in gen_requests:
             # new_tokens are provided when the overlap scheduler is enabled.
-            if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None:
+            if new_tokens is None or request.is_dummy or request.py_batch_idx is None:
                 input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)])
                 input_pos.append(request.max_beam_num_tokens - 1)
             else:
-                input_ids.append([new_tokens_list[request.py_batch_idx]])
+                # placeholder token; nest_sequences fills in the new token on the device
+                input_ids.append([-1])
+                previous_batch_indices.append(request.py_batch_idx)
                 input_pos.append(request.max_beam_num_tokens)
             request.py_batch_idx = request.seq_slot
 
@@ -207,10 +205,9 @@ def _prepare_inputs(
             # get cache indices
             cache_indices = kv_cache_manager.get_cache_indices(request)
             page_assignments.append(cache_indices)
-
         # update the sequence info object now
         si = self.cache_seq_interface.info
-        si.nest_sequences(input_ids)
+        si.nest_sequences(input_ids, previous_batch_indices, new_tokens)
         si.update_pos(input_pos, reset=True)
         si.assign_cache_loc(page_assignments)
         return last_logit_only
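
The `pin_memory=True` changes above all follow the same host-to-device staging pattern: build the small metadata tensor (position ids, flattened input ids, cache locations) in page-locked host memory so the copy to the GPU is faster and can be issued asynchronously. A minimal sketch of that pattern, assuming CUDA is available (the helper name and the explicit `non_blocking=True` flag are illustrative, not part of the diff):

```python
import torch


def stage_and_copy(values, dtype, device):
    # Allocate the staging tensor in pinned (page-locked) host memory.
    staged = torch.tensor(values, dtype=dtype, pin_memory=True)
    # Copies from pinned memory can overlap with other stream work; the diff
    # relies on this for position_ids, input_ids, and cache_loc.
    return staged.to(device, non_blocking=True)


# e.g. position_ids = stage_and_copy(position_ids_list, torch.long, "cuda")
```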
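
The generation path no longer materializes `new_tokens` on the host (the `new_tokens.flatten().cpu().tolist()` call is removed). Instead, each generation request contributes a `-1` placeholder plus its previous batch index, and `nest_sequences` scatters the device-resident `new_tokens` into the flattened `input_ids` on the GPU. A self-contained sketch of that substitution, assuming `new_tokens` can be indexed as `new_tokens[0, batch_idx, 0]` as the diff does:

```python
import torch


def fill_placeholders(input_ids, new_tokens, previous_batch_indices):
    # input_ids: flattened token ids with -1 where the next token for a
    # generation request is only known on the device (overlap scheduler).
    # new_tokens: tokens produced by the previous step, kept on the device.
    input_ids[input_ids == -1] = new_tokens[0, previous_batch_indices, 0]
    return input_ids


# Two generation requests reuse the tokens from batch slots 3 and 1:
ids = torch.tensor([5, 7, -1, -1])
new = torch.arange(8).view(1, 8, 1)
print(fill_placeholders(ids, new, [3, 1]))  # tensor([5, 7, 3, 1])
```

This keeps `_prepare_inputs` free of a GPU->CPU synchronization, which is exactly what the removed "can we avoid this additional gpu->cpu transfer?" comment asked for.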
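
The new `@nvtx_range("ad_...")` decorators imported from `tensorrt_llm._utils` mark the host-side preparation steps so they show up as named ranges in Nsight Systems traces. A rough, assumed equivalent (not the library's actual implementation) built on `torch.cuda.nvtx`:

```python
import functools

import torch


def nvtx_range_sketch(msg: str):
    # Decorator factory: wrap each call to the decorated function in an
    # NVTX range named `msg`.
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            torch.cuda.nvtx.range_push(msg)
            try:
                return fn(*args, **kwargs)
            finally:
                torch.cuda.nvtx.range_pop()
        return wrapper
    return decorator


@nvtx_range_sketch("ad_nest_sequences")
def nest_sequences_stub(*args, **kwargs):
    ...
```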