diff --git a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py index 0b309ae2bf8..c2081e00df8 100644 --- a/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py +++ b/tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py @@ -162,7 +162,7 @@ def forward(self, *args, **kwargs) -> Any: # copy inputs to input buffers for i, input_tensor in enumerate(args_batched): - self._input_buffers[i][: input_tensor.shape[0]] = input_tensor + self._input_buffers[i][: input_tensor.shape[0]].copy_(input_tensor, non_blocking=True) # run forward pass via graph self.graphs[combined_shape].replay() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 13c91652bff..ab1c8a263dd 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -18,6 +18,8 @@ from torch.export import Dim from torch.fx import Node +from tensorrt_llm._utils import nvtx_range + @dataclass class CacheConfig: @@ -87,11 +89,13 @@ class SequenceInfo: # Similarly, if a batch is composed of generate-only requests, # then the maximum number of sequences possible in the batch is min (max_batch_size, max_num_tokens). max_num_tokens: Optional[int] = None + # device is the device on which the sequence info is stored. + device: str = "cuda" ## [UPDATE WITH CARE] TENSOR FIELDS THAT WILL BE PASSED TO PREPARE_METADATA OP ################# # input_ids MUST ALWAYS BE THE FIRST FIELD - input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.int)) - position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, 1, dtype=torch.long)) + input_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) + position_ids: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.long)) seq_len: torch.Tensor = field(default_factory=lambda: torch.ones(1, dtype=torch.int)) input_pos: torch.Tensor = field(default_factory=lambda: torch.zeros(1, dtype=torch.int)) @@ -110,24 +114,44 @@ def __post_init__(self): # NOTE (lucaslie): WAR to address issue when using flashinfer attention with # (max_batch_size, max_seq_len) input in trtllm runtime. # see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 - max_seq_len_adjusted = self.max_seq_len + 1 + self.max_seq_len_adjusted = self.max_seq_len + 1 if self.max_num_tokens is None or self.max_num_tokens < 1: - self.max_num_tokens = self.max_batch_size * max_seq_len_adjusted + self.max_num_tokens = self.max_batch_size * self.max_seq_len_adjusted # if the provided max_num_tokens is less than the max_batch_size * max_seq_len, # we use the provided max_num_tokens to calculate the number of pages - total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted) + total_tokens = min(self.max_num_tokens, self.max_batch_size * self.max_seq_len_adjusted) # Num pages can not be less than max_batch_size. self._num_pages = max( self.max_batch_size, (total_tokens) // self.page_size + (total_tokens % self.page_size > 0), ) - self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int) - self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long) - self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int) - self.input_pos = torch.empty_like(self.seq_len) - self.cache_loc = torch.empty(self.num_pages, dtype=torch.int) - self.pages_per_seq = torch.empty_like(self.seq_len) + # Ensure that the device is set before initializing the tensors. + self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) + + # Consumers of the sequence info args require input_ids and position_ids to be truncated. + # We maintain a full version of the input_ids and position_ids to avoid overheads of tensor + # creation in every forward pass. + self.input_ids_full = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids_full = torch.zeros( + self.max_num_tokens, dtype=torch.long, device=self.device + ) + + self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int, device=self.device) + self.input_pos = torch.empty_like(self.seq_len, device=self.device) + + # Allocated host tensors for sequence lengths and input positions so that + # position_ids calculation can be done on host. + self.seq_len_host = torch.empty(self.max_batch_size, dtype=torch.int) + self.input_pos_host = torch.empty_like(self.seq_len_host) + + self.cache_loc = torch.empty(self.num_pages, dtype=torch.int, device=self.device) + self.pages_per_seq = torch.empty_like(self.seq_len, device=self.device) + + self.previous_batch_indices_cuda = torch.empty( + self.max_num_tokens, dtype=torch.long, device=self.device + ) assert self.num_pages >= self.max_batch_size, ( "num_pages must be greater than max_batch_size" ) @@ -140,13 +164,12 @@ def __post_init__(self): # indicator if extra args are activated that are needed for cached attention backends self._is_cached_attn = False + # total number of tokens in the current batch + self.num_tokens: int = 0 + # call reset once to initialize the tensors self.reset() - @property - def device(self) -> torch.device: - return self.input_pos.device - @property def args(self) -> Tuple[torch.Tensor, ...]: args = [] @@ -156,11 +179,14 @@ def args(self) -> Tuple[torch.Tensor, ...]: args.append(val) if len(args) >= self._num_uncached_attn_args and not self._is_cached_attn: break + return tuple(args) @property def _num_uncached_attn_args(self) -> int: - """Return the number of original graph arguments expected by the model.""" + """Return the number of original graph arguments expected by the model. + This is 2 because we have input_ids and position_ids as the original graph arguments. + """ return 2 @property @@ -185,7 +211,7 @@ def dynamic_shapes(self) -> Tuple[Dict[str, Dim]]: dynamic_shapes = ({}, {}) if self.max_batch_size > 1: dynamic_shapes[0][0] = Dim("batch_size", max=self.max_batch_size) - dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len) + dynamic_shapes[0][1] = Dim("seq_len", max=self.max_seq_len_adjusted) # set up shape for position_ids (same as input_ids) dynamic_shapes[1].update(dynamic_shapes[0]) # set up shape for extra args @@ -202,10 +228,6 @@ def num_sequences(self) -> int: def sequence_lengths(self) -> List[int]: return self._sequence_lengths - @property - def input_positions(self) -> List[int]: - return self.input_pos[: self.num_sequences].tolist() - @property def is_generate(self) -> bool: return all(sl == 1 for sl in self.sequence_lengths) @@ -336,12 +358,16 @@ def reset(self) -> None: self.input_pos.zero_() # set a dummy sequence corresponding to a generate-only batch (will also reset position_ids) - self.nest_sequences(torch.zeros(self.max_batch_size, 1, dtype=torch.int)) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) # reset cache information self.cache_loc[:] = torch.arange(self.num_pages, dtype=torch.int, device=self.device) self.pages_per_seq.fill_(1) + # let's also reset the input_ids and position_ids tensors to their max shapes (max_num_tokens) + self.input_ids = torch.ones(self.max_num_tokens, dtype=torch.int, device=self.device) + self.position_ids = torch.zeros(self.max_num_tokens, dtype=torch.long, device=self.device) + def set_example_sequence(self) -> None: """Set an example sequence useful for testing and export purposes.""" self.reset() @@ -352,7 +378,7 @@ def set_example_sequence(self) -> None: dtype=torch.int, device=self.device, ) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) # unflatten if we are not yet using cached+flattened attention if not self._is_cached_attn: @@ -370,7 +396,7 @@ def _set_max_num_tokens_sample(self) -> None: device=self.device, ) self.pages_per_seq.fill_(seq_len // self.page_size) - self.nest_sequences(input_ids) + self.nest_sequences(input_ids, allow_realloc=True) def set_generate_only_batch(self) -> None: """Set an example sequence for generate-only batch. @@ -379,32 +405,96 @@ def set_generate_only_batch(self) -> None: mode. So we don't need to do anything mode-specific here. """ self.reset() - self.nest_sequences([[1]] * self.max_batch_size) - - def _update_position_ids(self) -> None: - # set new position_ids as new tensor from input_pos and seq_len via torch.arange - position_ids_list = [ - num - for in_pos, seq_len in zip(self.input_positions, self.sequence_lengths) - for num in range(in_pos, in_pos + seq_len) - ] - self.position_ids = torch.tensor(position_ids_list, dtype=torch.long).to(self.device) + self.nest_sequences([[1]] * self.max_batch_size, allow_realloc=True) + def maybe_reshape_for_generate(self, tensor: torch.Tensor) -> torch.Tensor: # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] if self.is_generate: - self.position_ids = self.position_ids.view(-1, 1) + return tensor.view(-1, 1, *tensor.shape[1:]) else: - self.position_ids = self.position_ids.view(1, -1) + return tensor.view(1, -1, *tensor.shape[1:]) + + @nvtx_range("ad_update_position_ids") + def _update_position_ids(self, allow_realloc: bool = False) -> None: + # set new position_ids from input_pos and seq_len + # Make sure this is done on host to avoid host-device copies. + with nvtx_range("prepare_list"): + # Optimize for the common case where all seq_len values are 1 (generation mode) + if torch.all(self.seq_len_host == 1): + # Fast path: when all seq_len are 1, position_ids is just input_pos_host + position_ids_host = ( + self.input_pos_host[: self.num_tokens].to(dtype=torch.long).pin_memory() + ) + else: + # General case - can probably be optimized too, but overall impact will be minor. + position_ids_list = [] + for in_pos, seq_len in zip(self.input_pos_host, self.seq_len_host): + position_ids_list.extend(range(in_pos, in_pos + seq_len)) + position_ids_host = torch.tensor( + position_ids_list, dtype=torch.long, pin_memory=True + ) + with nvtx_range("copy_to_device"): + if allow_realloc: + # Create a new position_ids tensor on the device + self.position_ids = position_ids_host.to(self.device).clone() + else: + self.position_ids_full = self.position_ids_full.flatten() + self.position_ids_full[: self.num_tokens].copy_( + position_ids_host, non_blocking=True + ) + with nvtx_range("maybe_reshape"): + self.position_ids = self.maybe_reshape_for_generate( + self.position_ids if allow_realloc else self.position_ids_full[: self.num_tokens] + ) - def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: + @nvtx_range("ad_update_sequence_lengths") + def _update_sequence_lengths(self, sequence_lengths: List[int]) -> None: + self._sequence_lengths = sequence_lengths + self.num_tokens = sum(self._sequence_lengths) + self.seq_len.zero_() + self.seq_len_host = torch.tensor(self._sequence_lengths, pin_memory=True) + self.seq_len[: len(self._sequence_lengths)].copy_(self.seq_len_host, non_blocking=True) + + def update_input_ids_with_new_tokens( + self, new_tokens: torch.Tensor, previous_batch_indices: List[int] + ) -> None: + """Update the input_ids with new tokens. + + This function will update the input_ids with new tokens and previous batch indices. + """ + # 1) flatten once + original_shape = self.input_ids.shape + flat = self.input_ids.flatten() + + # copy indices to the GPU + host_idx = torch.tensor(previous_batch_indices, dtype=torch.int, pin_memory=True) + idx = self.previous_batch_indices_cuda[: len(previous_batch_indices)] + idx.copy_(host_idx, non_blocking=True) + + # sort them so that masked_scatter_ lines up correctly + idx, _ = idx.sort() + + # gather the exact values you want to write + src = new_tokens[0, idx, 0] + + # in‐place fill every slot where flat == -1 with src, in order + flat.masked_scatter_(flat == -1, src) + + # 4) reshape back + self.input_ids = flat.view(original_shape) + + @nvtx_range("ad_nest_sequences") + def nest_sequences( + self, input_ids: Sequence[Sequence[int]], allow_realloc: bool = False + ) -> None: """Create and store a flattened list of input_ids from the provided list of sequences. + When allow_realloc is True, the input_ids will be reallocated on the device. This i/f will also update any relevant sequence information. """ # set new sequence lengths - seq_lens = [len(ids) for ids in input_ids] - self.seq_len.zero_() - self.seq_len[: len(seq_lens)].copy_(torch.tensor(seq_lens), non_blocking=True) + self._update_sequence_lengths([len(ids) for ids in input_ids]) + # We'll preserve the dtype of the input_ids tensor if it is a tensor, otherwise we'll use int dtype = input_ids.dtype if isinstance(input_ids, torch.Tensor) else torch.int # set new input_ids as new tensor from flattened input_ids @@ -413,49 +503,63 @@ def nest_sequences(self, input_ids: Sequence[Sequence[int]]) -> None: for lst in input_ids for val in (lst.detach().tolist() if isinstance(lst, torch.Tensor) else lst) ] - self.input_ids = torch.tensor(ids_list, dtype=dtype).to(self.device) - - # set derivative properties - self._sequence_lengths = seq_lens + input_ids_host = torch.tensor(ids_list, dtype=dtype, pin_memory=True) - # use [b,1] shape to indicate generate-only batch, otherwise use [1,total_len] - if self.is_generate: - self.input_ids = self.input_ids.view(-1, 1, *self.input_ids.shape[1:]) + if allow_realloc: + self.input_ids = input_ids_host.to(self.device).clone() else: - self.input_ids = self.input_ids.view(1, -1, *self.input_ids.shape[1:]) + self.input_ids_full = self.input_ids_full.flatten() + self.input_ids_full[: self.num_tokens].copy_(input_ids_host, non_blocking=True) + self.input_ids = self.maybe_reshape_for_generate( + self.input_ids if allow_realloc else self.input_ids_full[: self.num_tokens] + ) # update position_ids - self._update_position_ids() + self._update_position_ids(allow_realloc=allow_realloc) + @nvtx_range("ad_unnest_sequences") def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: t_squeezed = t_nested.squeeze(1) if self.is_generate else t_nested.squeeze(0) return list(torch.split(t_squeezed, self.sequence_lengths)) - def update_pos(self, seq_len: Union[torch.Tensor, List[int], int], reset: bool = False) -> None: + @nvtx_range("ad_update_pos") + def update_pos( + self, + seq_len: Union[torch.Tensor, List[int], int], + reset: bool = False, + update_position_ids: bool = True, + ) -> None: """Update the starting position for each sequence in the cache. If ``reset=True`, ``input_pos`` will be reset to zero before updating. """ if not isinstance(seq_len, torch.Tensor): - seq_len = torch.tensor(seq_len, dtype=torch.int) + seq_len = torch.tensor(seq_len, dtype=torch.int, pin_memory=True) bs = len(seq_len) if seq_len.dim() > 0 else self.max_batch_size if reset: - self.input_pos[:bs] = seq_len.to(self.device) + self.input_pos_host[:bs].copy_(seq_len, non_blocking=True) else: - self.input_pos[:bs] += seq_len.to(self.device) + self.input_pos_host[:bs] += seq_len.to(self.device) - # update position_ids - self._update_position_ids() + # In ad_executor context, this is done later in nest_sequences, so no need to do it here + if update_position_ids: + self._update_position_ids() + self.input_pos[:bs].copy_(self.input_pos_host[:bs], non_blocking=True) + @nvtx_range("ad_assign_cache_loc") def assign_cache_loc(self, page_assignments: Sequence[Sequence[int]]) -> None: """Set the cache location and pages_per_seq tensors from page assignments.""" cache_loc_flat = torch.tensor( - [p_idx for pages in page_assignments for p_idx in pages], dtype=torch.int + [p_idx for pages in page_assignments for p_idx in pages], + dtype=torch.int, + pin_memory=True, ) self.cache_loc[: len(cache_loc_flat)].copy_(cache_loc_flat, non_blocking=True) - pages_per_seq = torch.tensor([len(p) for p in page_assignments], dtype=torch.int) + pages_per_seq = torch.tensor( + [len(p) for p in page_assignments], dtype=torch.int, pin_memory=True + ) self.pages_per_seq[: len(pages_per_seq)].copy_(pages_per_seq, non_blocking=True) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 7f759d6796d..ea6570364ca 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -94,20 +94,21 @@ def build_from_config(cls, ad_config: AutoDeployConfig): f"{max_seq_len=}, {max_batch_size=}, {attn_page_size=}, {max_num_tokens=}, {max_beam_width=}" ) + # update device to contain the current default device if it's in cuda + device = torch.device(ad_config.device) + if device.type == "cuda" and device.index is None: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + device = str(device) + # initialize seq info object seq_info = SequenceInfo( max_seq_len=max_seq_len, max_batch_size=max_batch_size, page_size=attn_page_size, max_num_tokens=max_num_tokens, + device=device, ) - # update device to contain the current default device if it's in cuda - device = torch.device(ad_config.device) - if device.type == "cuda" and device.index is None: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - device = str(device) - # construct inference optimizer build_and_optimize = InferenceOptimizer( factory=ad_config.create_factory(), ad_config=ad_config @@ -167,16 +168,12 @@ def _prepare_inputs( context_requests = scheduled_requests.context_requests gen_requests = [r for r in scheduled_requests.generation_requests if not r.draft_tokens] - # new_tokens is a tensor on the device, we need to convert it to a list of lists. - # can we avoid this additional gpu->cpu transfer? - new_tokens_list = new_tokens.flatten().cpu().tolist() if new_tokens is not None else None - # info to be extracted input_ids: List[List[int]] = [] input_pos: List[int] = [] last_logit_only: List[bool] = [] page_assignments: List[List[int]] = [] - + previous_batch_indices: List[int] = [] # look at context requests first for request in context_requests: # store input ids and pos of first token in sequence @@ -190,11 +187,13 @@ def _prepare_inputs( # TODO: we should also handle extend requests (for speculative decoding) here for request in gen_requests: # new_tokens are provided when the overlap scheduler is enabled. - if new_tokens_list is None or request.is_dummy or request.py_batch_idx is None: + if new_tokens is None or request.is_dummy or request.py_batch_idx is None: input_ids.append([request.get_token(0, request.get_num_tokens(0) - 1)]) input_pos.append(request.max_beam_num_tokens - 1) else: - input_ids.append([new_tokens_list[request.py_batch_idx]]) + # insert a dummy token to indicate the new tokens + input_ids.append([-1]) + previous_batch_indices.append(request.py_batch_idx) input_pos.append(request.max_beam_num_tokens) request.py_batch_idx = request.seq_slot @@ -210,11 +209,16 @@ def _prepare_inputs( # update the sequence info object now si = self.cache_seq_interface.info - si.nest_sequences(input_ids) - si.update_pos(input_pos, reset=True) + # skip calling _update_position_ids() here, as it will be called in nest_sequences + si.update_pos(input_pos, reset=True, update_position_ids=False) si.assign_cache_loc(page_assignments) + si.nest_sequences(input_ids) + + if new_tokens is not None: + si.update_input_ids_with_new_tokens(new_tokens, previous_batch_indices) return last_logit_only + @nvtx_range("ad_compute_logits") def _compute_logits(self) -> List[torch.Tensor]: # run the model logits: torch.Tensor = self.model(*self.cache_seq_interface.args)[0] @@ -231,13 +235,13 @@ def forward( self, scheduled_requests: ScheduledRequests, resource_manager: ResourceManager, - new_tokens_device: Optional[torch.Tensor] = None, + new_tensors_device: Optional[torch.Tensor] = None, gather_context_logits: bool = False, cache_indirection_buffer: Optional[torch.Tensor] = None, ): """Run forward from scheduled requests; main entrypoint that gets called by the executor.""" # convert requests and store in sequence info object - new_tokens = getattr(new_tokens_device, "new_tokens", None) + new_tokens = getattr(new_tensors_device, "new_tokens", None) last_logit_only = self._prepare_inputs(scheduled_requests, resource_manager, new_tokens) # compute all logits