From cde34ad295d7d0a55eb7044c6bbfe16742dd5233 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 23 Dec 2025 08:44:01 -0800 Subject: [PATCH 1/4] attention i/f providing device and host argument Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 82 +++---------------- .../custom_ops/fla/fla_backend_delta.py | 8 +- .../custom_ops/flashinfer_attention.py | 14 ++-- .../mamba/cuda_backend_causal_conv.py | 8 +- .../mamba/torch_backend_causal_conv.py | 8 +- .../custom_ops/mamba/torch_backend_mamba.py | 8 +- .../custom_ops/mamba/triton_backend_mamba.py | 14 ++-- .../_torch/auto_deploy/custom_ops/mla.py | 8 +- .../custom_ops/torch_backend_attention.py | 8 +- .../custom_ops/torch_gather_logits.py | 8 +- .../custom_ops/triton_attention.py | 8 +- .../auto_deploy/models/patches/bamba.py | 10 +-- .../library/gather_logits_before_lm_head.py | 6 +- .../defs/accuracy/test_llm_api_autodeploy.py | 5 +- .../_utils_test/torch_attention_reference.py | 22 ++--- .../singlegpu/custom_ops/test_attention_op.py | 8 +- .../test_cuda_causal_conv_cached_op.py | 6 +- .../test_flashinfer_attention_op.py | 48 +++++------ .../custom_ops/test_torch_attention_op.py | 10 ++- .../test_torch_causal_conv_cached_op.py | 14 ++-- .../custom_ops/test_torch_mamba_cached_op.py | 14 ++-- .../custom_ops/test_triton_mamba_cached_op.py | 8 +- .../test_gather_logits_before_lm_head.py | 28 +++---- 23 files changed, 149 insertions(+), 204 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index a724b2e7bcb..750c9085674 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -508,6 +508,9 @@ def __init__( # Create the InputBuffer that manages contiguous host and device memory # Starts on default device; use to() to move to target device self._input_buffer = InputBuffer(tensor_specs) + self._available_args = set(self._input_buffer.tensor_names) | { + f"{name}_host" for name in self._input_buffer.tensor_names + } # Initialize args_list from tensor specs self._args_list: Dict[str, List[int]] = { @@ -515,9 +518,7 @@ def __init__( } self._active_args = ("input_ids", "position_ids") - self._shapeable_args = ("input_ids", "position_ids") - # Args that should be returned from host (pinned memory) instead of device in _named_args - self._host_return_args = ("batch_info", "logits_gather_info") + self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host") ############################################################################################ # EXTRA TENSOR FIELDS ###################################################################### @@ -558,14 +559,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor: def _get_arg(self, name: str) -> torch.Tensor: """Get the argument from the input buffer either on device or host.""" - if name in self._host_return_args: - arg = self._input_buffer.get_host_view(name) + if name.endswith("_host"): + arg = self._input_buffer.get_host_view(name.replace("_host", "")) else: arg = self._input_buffer.get_view(name) return self._shape_for_forward(arg) if name in self._shapeable_args else arg def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]: - # Build args dict, using host views for _host_return_args, device views 
otherwise args = {k: self._get_arg(k) for k in self._active_args} # check other args to include @@ -577,7 +577,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor @property def available_args(self) -> Set[str]: """Return a list of available arguments.""" - return set(self._input_buffer.tensor_names) + return self._available_args @property def named_args(self) -> Dict[str, torch.Tensor]: @@ -697,68 +697,6 @@ def _get_cache_locations_and_pages_per_sequence( pages_per_seq = [len(p) for p in page_assignments] return cache_loc_flat, pages_per_seq - # TODO: remove after updating all cached backends - @classmethod - def _get_sanitized_seq_len( - cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor - ) -> torch.Tensor: - """Sanitize sequence lengths. - - We want to cover the following scenarios with this function: - - 1. Pre-fill: - input_ids: [1, s_total, ...] - seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0] - ---> returns [s_0, s_1, ..., s_{b-1}] - 2. Decode: - input_ids: [b, 1, ...] - seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0] - |---- b ----|--- (max_batch_size - b) ---| - --> returns [1,] * b - 3. Decode in Cudagraph: - input_ids: [b_cudagraph, 1, ...] - seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0] - |---- b ----|--- (max_batch_size - b) ---| - - --> returns [1,] * b_cudagraph - Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to - b_cudagraph. - - # TODO: I could see one possible issue with this approach in the future. - # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location - # information. What could happen is that the for the padded sequences the cache location - # tensors point to allocated pages. This could lead to a situation where we write into - # allocated cache pages polluting the cache of other sequences. Now this is not an issue - # if we write the dummy sequences into unallocated cache pages... One fix could be to - # pad not only the seq len but also pad the cache locations by just repeating the last - # valid cache location in the batch. This would ensure that the dummy sequences just - # repeats valid computation... - """ - _, s = input_or_position_ids.shape[:2] - num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len) - if s > 1: - return seq_len[:num_seq].clone() - else: - return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device) - - @staticmethod - def _get_sanitized_num_sequences( - input_or_position_ids: torch.Tensor, seq_len: torch.Tensor - ) -> int: - """Get number of sequences. - - We makes sure that this function is compatible with both torch graph capture and cudagraph. - Both can be a bit temparamental when trying to extract the number of sequences from a tensor - with max_batch_size or max_batch_size*max_seq_len. - """ - b, s = input_or_position_ids.shape[:2] - if s > 1: - num_seq = torch.sum(seq_len > 0) - assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded" - else: - num_seq = b - return num_seq - def activate_arg(self, arg_name: str) -> bool: """Activate a desired argument. 
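# For illustration, a minimal self-contained sketch of the "<name>" vs. "<name>_host"
# argument convention introduced in the hunks above: every metadata tensor can be
# requested either as a device view or, via the "_host" suffix, as a pinned-host view,
# and each backend picks the variant it needs in get_standard_metadata_args(). The names
# _ToyInputBuffer and _ToySequenceInfo below are hypothetical, simplified stand-ins for
# the real InputBuffer / SequenceInfo classes, not the actual implementation.
from typing import Dict, Set

import torch


class _ToyInputBuffer:
    """Keep a host copy (pinned if CUDA is available) and a device copy per named tensor."""

    def __init__(self, specs: Dict[str, torch.Tensor]):
        use_cuda = torch.cuda.is_available()
        self._host = {k: (v.pin_memory() if use_cuda else v.clone()) for k, v in specs.items()}
        self._device = {k: v.to("cuda" if use_cuda else "cpu") for k, v in specs.items()}

    @property
    def tensor_names(self) -> Set[str]:
        return set(self._host)

    def get_view(self, name: str) -> torch.Tensor:
        return self._device[name]

    def get_host_view(self, name: str) -> torch.Tensor:
        return self._host[name]


class _ToySequenceInfo:
    def __init__(self, buffer: _ToyInputBuffer):
        self._input_buffer = buffer
        # each tensor is exposed twice: "<name>" (device) and "<name>_host" (pinned host)
        self._available_args = set(buffer.tensor_names) | {
            f"{n}_host" for n in buffer.tensor_names
        }

    def get_arg(self, name: str) -> torch.Tensor:
        assert name in self._available_args, f"unknown arg: {name}"
        if name.endswith("_host"):
            # the "_host" suffix resolves to the host view; reading it needs no device sync
            return self._input_buffer.get_host_view(name[: -len("_host")])
        return self._input_buffer.get_view(name)


# usage: a backend that lists "batch_info_host" in get_standard_metadata_args() can call
# .tolist() on it without triggering a device-to-host copy
buf = _ToyInputBuffer({"batch_info": torch.tensor([2, 10, 3], dtype=torch.int32)})
si = _ToySequenceInfo(buf)
num_prefill, num_prefill_tokens, num_decode = si.get_arg("batch_info_host").tolist()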
@@ -869,7 +807,7 @@ def _store_arg( self._args_list[name] = tnsr_like.copy() # Only store to buffer when the argument is active or force_copy is True - if not (name in self._active_args or force_copy): + if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy): return # Store to the InputBuffer's pinned host memory @@ -1090,12 +1028,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor): def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor: """Maybe gather the logits if logits have not been gathered yet.""" num_tokens = logits.shape[0] * logits.shape[1] - num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist() + num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist() if gather_required and num_tokens_to_gather < num_tokens: logits = torch.ops.auto_deploy.gather_logits_before_lm_head( logits, self._get_arg("logits_gather_indices"), - self._get_arg("logits_gather_info"), + self._get_arg("logits_gather_info_host"), ) return logits.squeeze(int(self.is_generate)) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py index 5cf4a4149ce..757aff042f2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py @@ -35,7 +35,7 @@ def fla_cached_delta_rule( v: torch.Tensor, beta: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -58,7 +58,7 @@ def fla_cached_delta_rule( y = torch.empty_like(v, memory_format=torch.contiguous_format) y_flat = y.view(b * s, num_heads, -1) - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode # clean up metadata @@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake( v: torch.Tensor, beta: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -160,7 +160,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"] + return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 2410ee27ea8..b8281c2caca 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -359,7 +359,7 @@ def _plan_decode( @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=()) def prepare_flashinfer_metadata( position_ids: torch.Tensor, - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, seq_len_with_cache: torch.Tensor, ) -> List[torch.Tensor]: @@ -370,7 +370,7 @@ def prepare_flashinfer_metadata( to understand the convention. 
""" # retrieve host-side metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode num_tokens = num_prefill_tokens + num_decode @@ -393,7 +393,7 @@ def prepare_flashinfer_metadata( @prepare_flashinfer_metadata.register_fake def prepare_flashinfer_metadata_fake( position_ids: torch.Tensor, - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, seq_len_with_cache: torch.Tensor, ): @@ -411,7 +411,7 @@ def flashinfer_mha_with_cache( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, cu_num_pages: torch.Tensor, cache_loc: torch.Tensor, @@ -439,7 +439,7 @@ def flashinfer_mha_with_cache( v = v.reshape(b * s, -1, head_dim) # convert to flashinfer-style metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode qo_indptr = cu_seqlen[: num_seq + 1] @@ -506,7 +506,7 @@ def flashinfer_mha_with_cache_fake( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, cu_num_pages: torch.Tensor, cache_loc: torch.Tensor, @@ -559,7 +559,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"] + return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"] @classmethod def get_prepare_extra_metadata_info( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py index 29f62814c4b..bc7752df521 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py @@ -53,7 +53,7 @@ def _cuda_cached_causal_conv1d( weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k] bias: Optional[torch.Tensor], # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -80,7 +80,7 @@ def _cuda_cached_causal_conv1d( """ b, s = input.shape[:2] - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode num_total_tokens = num_prefill_tokens + num_decode @@ -138,7 +138,7 @@ def _cuda_cached_causal_conv1d_fake( weight: torch.Tensor, # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k] bias: Optional[torch.Tensor], # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -189,7 +189,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"] + return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py 
b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py index b055f22dedc..fd5c7170230 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py @@ -147,7 +147,7 @@ def _torch_cached_causal_conv1d( weight: torch.Tensor, # [c_out, c_in/groups, k] bias: Optional[torch.Tensor], # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, @@ -174,7 +174,7 @@ def _torch_cached_causal_conv1d( num_seq = seq_len.shape[0] # get cleaned up metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode seq_len = seq_len[:num_seq] seq_start = cu_seqlen[:num_seq] @@ -247,7 +247,7 @@ def _torch_cached_causal_conv1d_fake( weight: torch.Tensor, # [c_out, c_in/groups, k] bias: Optional[torch.Tensor], # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, @@ -296,7 +296,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] + return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py index e9518050133..03f403e0f69 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py @@ -121,7 +121,7 @@ def _torch_cached_ssm( dt: torch.Tensor, # [b, s, num_heads] dt_bias: torch.Tensor, # [num_heads] # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, @@ -145,7 +145,7 @@ def _torch_cached_ssm( num_seq = seq_len.shape[0] # get cleaned up metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode seq_len = seq_len[:num_seq] seq_start = cu_seqlen[:num_seq] @@ -246,7 +246,7 @@ def _torch_cached_ssm_fake( dt: torch.Tensor, # [b, s, num_heads] dt_bias: torch.Tensor, # [num_heads] # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, @@ -293,7 +293,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] + return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index d3ea70221ba..d119ddeece5 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -42,7 +42,7 @@ def _triton_ssm_prepare_metadata( # INPUTS position_ids: 
torch.Tensor, - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, # EXTRA METADATA PROVIDED BY THE DESCRIPTOR @@ -53,7 +53,7 @@ def _triton_ssm_prepare_metadata( Returns a tuple of (seq_len_sanitized, seq_start, slot_idx_sanitized). """ device = cu_seqlen.device - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() if num_prefill > 0: chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( @@ -74,7 +74,7 @@ def _triton_ssm_prepare_metadata( def _triton_ssm_prepare_metadata_fake( # INPUTS position_ids: torch.Tensor, - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, cu_seqlen: torch.Tensor, # EXTRA METADATA PROVIDED BY THE DESCRIPTOR @@ -110,7 +110,7 @@ def _triton_cached_ssm( dt: torch.Tensor, # [b, s, num_heads] dt_bias: torch.Tensor, # [num_heads] # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -140,7 +140,7 @@ def _triton_cached_ssm( ssm_state_size = B.shape[3] - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode num_total_tokens = num_prefill_tokens + num_decode @@ -245,7 +245,7 @@ def _triton_cached_ssm_fake( dt: torch.Tensor, # [b, s, num_heads] dt_bias: torch.Tensor, # [num_heads] # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, @@ -294,7 +294,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"] + return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] @classmethod def get_prepare_extra_metadata_info( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py index 05212151002..716cda7d1be 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mla.py @@ -31,7 +31,7 @@ def fused_flattened_mla_with_cache( kv: torch.Tensor, k_pe: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -55,7 +55,7 @@ def fused_flattened_mla_with_cache( # and number of tokens per sequence are encoded in seq_len and seq_start. 
# check for sequence info and truncate metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode seq_len = seq_len[:num_seq] @@ -166,7 +166,7 @@ def fused_flattened_mla_with_cache_fake( kv: torch.Tensor, k_pe: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -212,7 +212,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] + return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index cab0a0302b0..09bc253708b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -253,7 +253,7 @@ def torch_backend_mha_with_cache( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -278,7 +278,7 @@ def torch_backend_mha_with_cache( b, s = q.shape[:2] # get cleaned up metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode seq_len = seq_len[:num_seq] input_pos = input_pos[:num_seq] @@ -352,7 +352,7 @@ def torch_backend_mha_with_cache_fake( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -400,7 +400,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] + return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_gather_logits.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_gather_logits.py index 355a1fbccd2..7669ea8966e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_gather_logits.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_gather_logits.py @@ -5,14 +5,14 @@ def gather_logits_before_lm_head( hidden_states: torch.Tensor, logits_gather_indices: torch.Tensor, # long tensor - logits_gather_info: torch.Tensor, # int tensor + logits_gather_info_host: torch.Tensor, # int tensor ) -> torch.Tensor: """Gather hidden states using logits_gather_indices before LM head. Args: hidden_states: Hidden states tensor [b, 1, hidden] or [1, s_total, hidden] logits_gather_indices: indices for gathering logits. - logits_gather_info: info for gathering logits. + logits_gather_info_host: info for gathering logits. 
Returns: Gathered and flattened hidden states [num_gathered_tokens, hidden] """ @@ -21,7 +21,7 @@ def gather_logits_before_lm_head( hidden_states = hidden_states.squeeze(int(is_decode_only)) # info object - num_tokens_to_gather, gather_required = logits_gather_info.tolist() + num_tokens_to_gather, gather_required = logits_gather_info_host.tolist() if gather_required: out = hidden_states.index_select(0, logits_gather_indices[:num_tokens_to_gather]) @@ -34,7 +34,7 @@ def gather_logits_before_lm_head( def gather_logits_before_lm_head_fake( hidden_states: torch.Tensor, logits_gather_indices: torch.Tensor, - logits_gather_info: torch.Tensor, + logits_gather_info_host: torch.Tensor, ) -> torch.Tensor: # NOTE: shape is not correct in fake mode # see https://github.com/NVIDIA/TensorRT-LLM/issues/9878 diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py index 5a25b1f1c93..8a9daf75232 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/triton_attention.py @@ -188,7 +188,7 @@ def flattened_mha_with_cache( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -210,7 +210,7 @@ def flattened_mha_with_cache( NOTE: this op can also handle seq_len==0, which might be useful for CUDAGRAPH. """ # check for sequence info and truncate metadata - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode seq_len = seq_len[:num_seq] @@ -290,7 +290,7 @@ def flattened_mha_fake( k: torch.Tensor, v: torch.Tensor, # STANDARD METADATA - batch_info: torch.Tensor, + batch_info_host: torch.Tensor, seq_len: torch.Tensor, input_pos: torch.Tensor, cache_loc: torch.Tensor, @@ -337,7 +337,7 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] + return ["batch_info_host", "seq_len", "input_pos", "cache_loc", "cu_seqlen"] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py index 93090a87783..c237cf9bddd 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/bamba.py @@ -49,15 +49,15 @@ def _bamba_mixer_torch_forward( ) slot_idx_t = torch.arange(batch_size, device=input_states.device, dtype=torch.long) use_initial_states_t = torch.zeros(batch_size, device=input_states.device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0] # For generate phase (seq_len == 1): [0, 0, batch_size] if seq_len == 1: - batch_info_t = torch.tensor( + batch_info_host_t = torch.tensor( [0, 0, batch_size], device=input_states.device, dtype=torch.int32 ) else: - batch_info_t = torch.tensor( + batch_info_host_t = torch.tensor( [batch_size, batch_size * seq_len, 0], device=input_states.device, dtype=torch.int32 ) if use_caching: @@ -68,7 +68,7 @@ def _bamba_mixer_torch_forward( self.conv1d.weight, self.conv1d.bias, # STANDARD METADATA - batch_info_t, + batch_info_host_t, 
seq_len_t, cu_seqlen_t, slot_idx_t, @@ -123,7 +123,7 @@ def _bamba_mixer_torch_forward( dt=dt, dt_bias=self.dt_bias, # STANDARD METADATA - batch_info=batch_info_t, + batch_info_host=batch_info_host_t, seq_len=seq_len_t, cu_seqlen=cu_seqlen_t, slot_idx=slot_idx_t, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py index 9af1d9aa168..690457d6994 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/gather_logits_before_lm_head.py @@ -67,12 +67,14 @@ def _apply( # Add logits_gather_mask as input in the graph and the sequence info interface logits_gather_indices_node = self._add_or_retrieve_input(gm, cm, "logits_gather_indices") - logits_gather_info_node = self._add_or_retrieve_input(gm, cm, "logits_gather_info") + logits_gather_info_host_node = self._add_or_retrieve_input( + gm, cm, "logits_gather_info_host" + ) with gm.graph.inserting_after(node_to_gather): gathered_node = gm.graph.call_function( torch.ops.auto_deploy.gather_logits_before_lm_head.default, - args=(node_to_gather, logits_gather_indices_node, logits_gather_info_node), + args=(node_to_gather, logits_gather_indices_node, logits_gather_info_host_node), ) node_to_gather.replace_all_uses_with(gathered_node) gathered_node.replace_input_with(gathered_node, node_to_gather) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index c8adaa96849..e30df35a3d4 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -220,7 +220,6 @@ def test_bf16(self): @pytest.mark.skip_less_device_memory(32000) def test_fp8(self): kwargs = self.get_default_kwargs() - sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_PATH_FP8, tokenizer=self.MODEL_PATH_FP8, **kwargs) as llm: @@ -228,8 +227,8 @@ def test_fp8(self): llm.args.quant_config.quant_algo = QuantAlgo.FP8 llm.args.quant_config.kv_cache_quant_algo = QuantAlgo.FP8 - task = MMLU(self.MODEL_NAME) - task.evaluate(llm, sampling_params=sampling_params) + # task = MMLU(self.MODEL_NAME) + # task.evaluate(llm, sampling_params=sampling_params) task = GSM8K(self.MODEL_NAME) task.evaluate(llm) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py b/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py index 66fa4a59c18..3720959cd7d 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/torch_attention_reference.py @@ -40,13 +40,13 @@ def basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions, scale=None) 0, batch_size * seq_len, seq_len, device=q.device, dtype=torch.int32 ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For context phase (seq_len > 1): [batch_size, batch_size * seq_len, 0] # For generate phase (seq_len == 1): [0, 0, batch_size] if seq_len == 1: - batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32) + batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32) else: - batch_info = torch.tensor( + batch_info_host = torch.tensor( [batch_size, batch_size * seq_len, 0], 
device=q.device, dtype=torch.int32 ) @@ -60,7 +60,7 @@ def basic_mha_with_cache(q, k, v, k_cache, v_cache, input_positions, scale=None) q_flat, k_flat, v_flat, - batch_info, + batch_info_host, seq_len_tensor, input_positions, cache_loc, @@ -84,7 +84,7 @@ def flattened_mha_with_cache( q, k, v, - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, @@ -101,7 +101,7 @@ def flattened_mha_with_cache( q, k, v, - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, @@ -144,15 +144,15 @@ def decode_with_prefilled_cache(q, k_ref, v_ref, k_cache, v_cache, prefill_lengt k_flat = k_new.view(1, batch_size, -1) v_flat = v_new.view(1, batch_size, -1) - # Create batch_info for decode phase: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32) + # Create batch_info_host for decode phase: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([0, 0, batch_size], device=q.device, dtype=torch.int32) # Call torch backend via custom op registry output_flat = torch.ops.auto_deploy.torch_cached_attention_with_cache( q_flat, k_flat, v_flat, - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, @@ -170,7 +170,7 @@ def mha_with_features( q, k, v, - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, @@ -189,7 +189,7 @@ def mha_with_features( q, k, v, - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py index 15b9eb77c5b..2b143049055 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_attention_op.py @@ -125,9 +125,9 @@ def test_flat_gqa_op( k = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs) v = torch.randn(1, seq_len.sum(), n_kv_heads * D_HEAD, **dtype_kwargs) - # create batch_info: [num_prefill, num_prefill_tokens, num_decode] + # create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] num_prefill_tokens = seq_len[:num_context].sum() - batch_info = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs) + batch_info_host = torch.tensor([num_context, num_prefill_tokens, num_generate], **int_kwargs) # run op output = torch.ops.auto_deploy.triton_attention_flattened_mha_with_cache( @@ -136,7 +136,7 @@ def test_flat_gqa_op( k, v, # STANDARD METADATA - batch_info, + batch_info_host, seq_len, input_positions, cache_loc, @@ -150,7 +150,7 @@ def test_flat_gqa_op( # Use torch backend as clean reference ref_flat = TorchAttentionReference.flattened_mha_with_cache( - q, k, v, batch_info, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache + q, k, v, batch_info_host, seq_len, input_positions, cache_loc, seq_start, k_cache, v_cache ) assert torch.allclose( diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py index 4e30efdb73b..948daea4cd0 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_cuda_causal_conv_cached_op.py @@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping_cuda(conv_env): # Metadata (not used in 
generate-only op entry, but required by the interface) cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For generate-only: num_decode = batch, num_prefill = 0 - batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) + batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) # Snapshot caches for reference before running op (op mutates caches) gathered_before = conv_state_cache.clone().index_select(0, slot_idx) x_ref = x.clone() @@ -72,7 +72,7 @@ def test_generate_only_with_slot_mapping_cuda(conv_env): w, b, # STANDARD METADATA - batch_info, + batch_info_host, cu_seqlen, slot_idx, use_initial_states, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index 33d0c82bf37..503a780abed 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -88,8 +88,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, ), BATCH_SIZE * SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor( + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor( [BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device ) flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( @@ -98,7 +98,7 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, k, v, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -224,16 +224,16 @@ def test_flashinfer_attention_op_decode( ), BATCH_SIZE * SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For decode phase: num_decode = BATCH_SIZE, num_prefill = 0 - batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) + batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, v, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -347,8 +347,8 @@ def test_flashinfer_attention_context_and_generate( ), BATCH_SIZE * PREFILL_SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor( + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor( [BATCH_SIZE, BATCH_SIZE * PREFILL_SEQ_LEN, 0], dtype=torch.int32, device=device ) flashinfer_output_1 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( @@ -357,7 +357,7 @@ def test_flashinfer_attention_context_and_generate( k_1, v_1, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -430,15 +430,15 @@ def test_flashinfer_attention_context_and_generate( ), BATCH_SIZE * 1, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) 
+ # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) flashinfer_output_3 = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q_3, k_3, v_3, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -543,8 +543,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty ), BATCH_SIZE * SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor( + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor( [BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device ) flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( @@ -553,7 +553,7 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty k, v, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -696,8 +696,8 @@ def test_flashinfer_attention_with_fp8_cache( ), BATCH_SIZE * SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor( + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor( [BATCH_SIZE, BATCH_SIZE * SEQ_LEN, 0], dtype=torch.int32, device=device ) flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( @@ -706,7 +706,7 @@ def test_flashinfer_attention_with_fp8_cache( k, v, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -798,15 +798,15 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ), SEQ_LEN, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device) + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([BATCH_SIZE, SEQ_LEN, 0], dtype=torch.int32, device=device) flashinfer_output = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q, k, v, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr, paged_kv_indptr, paged_kv_indices, @@ -885,15 +885,15 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ), BATCH_SIZE * 1, ) - # Create batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) + # Create batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([0, 0, BATCH_SIZE], dtype=torch.int32, device=device) flashinfer_output_gen = torch.ops.auto_deploy.flashinfer_attention_mha_with_cache( # Q, K, V q_gen, k_gen, v_gen, # STANDARD METADATA - batch_info, + batch_info_host, qo_indptr2, paged_kv_indptr2, paged_kv_indices2, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py index 130e7ce651a..655dd7b53a1 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_attention_op.py @@ -246,14 +246,16 @@ def _create_test_data( if seq_len == 1: # Generate phase: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor([0, 0, 
batch_size], device=self.device, dtype=torch.int32) + batch_info_host = torch.tensor( + [0, 0, batch_size], device=self.device, dtype=torch.int32 + ) seq_start = torch.arange(batch_size, device=self.device, dtype=torch.int32) q_flat = q.view(batch_size, seq_len, -1) k_flat = k.view(batch_size, seq_len, -1) v_flat = v.view(batch_size, seq_len, -1) else: # Context phase: [num_prefill, num_prefill_tokens, num_decode] - batch_info = torch.tensor( + batch_info_host = torch.tensor( [batch_size, batch_size * seq_len, 0], device=self.device, dtype=torch.int32 ) seq_start = torch.arange( @@ -267,7 +269,7 @@ def _create_test_data( "q": q_flat, "k": k_flat, "v": v_flat, - "batch_info": batch_info, + "batch_info_host": batch_info_host, "seq_len": seq_len_tensor, "input_pos": input_positions, "cache_loc": cache_loc, @@ -286,7 +288,7 @@ def _run_attention( data["k"], data["v"], # STANDARD METADATA - data["batch_info"], + data["batch_info_host"], data["seq_len"], data["input_pos"], data["cache_loc"], diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py index 035c3c463ca..ee927c6353d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_causal_conv_cached_op.py @@ -59,9 +59,9 @@ def test_generate_only_with_slot_mapping(conv_env): # Snapshot caches for reference before running op (op mutates caches) gathered_before = conv_state_cache.clone().index_select(0, slot_idx) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For generate-only: num_decode = batch, num_prefill = 0 - batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) + batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) # Run cached op y = torch.ops.auto_deploy.torch_cached_causal_conv1d( # INPUTS @@ -69,7 +69,7 @@ def test_generate_only_with_slot_mapping(conv_env): w, b, # STANDARD METADATA - batch_info, + batch_info_host, seq_len, cu_seqlen, slot_idx, @@ -124,18 +124,20 @@ def test_context_flattened_and_state_writeback(conv_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For context/prefill phase: num_prefill = len(lens), num_decode = 0 num_seqs = len(lens) num_prefill_tokens = sum(lens) - batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32) + batch_info_host = torch.tensor( + [num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32 + ) y = torch.ops.auto_deploy.torch_cached_causal_conv1d( # INPUTS x, w, b, # STANDARD METADATA - batch_info, + batch_info_host, seq_len, cu_seqlen, slot_idx, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py index 39e1a4c1f54..edaf9db0c78 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py +++ 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_torch_mamba_cached_op.py @@ -66,9 +66,9 @@ def test_generate_only_with_slot_mapping(mamba_env): seq_len = torch.ones(batch, device=device, dtype=torch.int32) cu_seqlen = torch.zeros(batch, device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For generate-only: num_decode = batch, num_prefill = 0 - batch_info = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) + batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) # Snapshot caches for reference before running op (op mutates caches) gathered_before = ssm_state_cache.clone().index_select(0, slot_idx) @@ -83,7 +83,7 @@ def test_generate_only_with_slot_mapping(mamba_env): dt, dt_bias, # STANDARD METADATA - batch_info, + batch_info_host, seq_len, cu_seqlen, slot_idx, @@ -141,11 +141,13 @@ def test_context_flattened_and_state_writeback(mamba_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] # For context/prefill phase: num_prefill = len(lens), num_decode = 0 num_seqs = len(lens) num_prefill_tokens = sum(lens) - batch_info = torch.tensor([num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32) + batch_info_host = torch.tensor( + [num_seqs, num_prefill_tokens, 0], device=device, dtype=torch.int32 + ) y = torch.ops.auto_deploy.torch_cached_ssm( # INPUTS hidden_states, @@ -156,7 +158,7 @@ def test_context_flattened_and_state_writeback(mamba_env): dt, dt_bias, # STANDARD METADATA - batch_info, + batch_info_host, seq_len, cu_seqlen, slot_idx, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py index add5cd76be1..d08995017bd 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_triton_mamba_cached_op.py @@ -134,8 +134,8 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): torch.arange(len(lens), device=device, dtype=torch.int32), seq_len, ).view(1, -1) - # batch_info: [num_prefill, num_prefill_tokens, num_decode] - batch_info_tensor = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device) + # batch_info_host: [num_prefill, num_prefill_tokens, num_decode] + batch_info_host = torch.tensor([len(lens), sum(lens), 0], dtype=torch.int32, device=device) # Torch reference y_torch = torch.ops.auto_deploy.torch_cached_ssm( hidden_states, @@ -146,7 +146,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): dt, dt_bias, # STANDARD METADATA - batch_info_tensor, + batch_info_host, seq_len, cu_seqlen, slot_idx, @@ -168,7 +168,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): dt, dt_bias, # STANDARD METADATA - batch_info_tensor, + batch_info_host, cu_seqlens, slot_idx, use_initial_states, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py index 2f1e7425023..42cd57752c0 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gather_logits_before_lm_head.py @@ -56,10 +56,10 @@ def test_generate_format(self, batch_size): # Create gather info: num_tokens_to_gather=batch_size, gather_required=0 (False) logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda") - logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu") output = torch.ops.auto_deploy.gather_logits_before_lm_head.default( - hidden_states, logits_gather_indices, logits_gather_info + hidden_states, logits_gather_indices, logits_gather_info_host ) # Should return [batch, 1, hidden] for generate format (3D shape preserved) @@ -82,10 +82,10 @@ def test_packed_format(self, total_tokens): gather_indices = torch.arange(0, num_gather, dtype=torch.long, device="cuda") # Create gather info: num_tokens_to_gather=num_gather, gather_required=1 (True) - logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu") output = torch.ops.auto_deploy.gather_logits_before_lm_head.default( - hidden_states, gather_indices, logits_gather_info + hidden_states, gather_indices, logits_gather_info_host ) # Should return [1, num_gather, hidden] for packed format (3D shape preserved) @@ -105,15 +105,15 @@ def test_fake_implementation_generate_format(self): # Create gather info logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda") - logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu") # Use fake implementation directly with FakeTensorMode() as mode: hidden_states_fake = mode.from_tensor(hidden_states) logits_gather_indices_fake = mode.from_tensor(logits_gather_indices) - logits_gather_info_fake = mode.from_tensor(logits_gather_info) + logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host) output = torch.ops.auto_deploy.gather_logits_before_lm_head.default( - hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake + hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake ) # Should return [batch, 1, hidden_size] (fake returns empty_like which preserves 3D shape) @@ -132,15 +132,15 @@ def test_fake_implementation_packed_format(self): # Create gather info logits_gather_indices = torch.arange(num_gather, dtype=torch.long, device="cuda") - logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu") # Use fake implementation directly with FakeTensorMode() as mode: hidden_states_fake = mode.from_tensor(hidden_states) logits_gather_indices_fake = mode.from_tensor(logits_gather_indices) - logits_gather_info_fake = mode.from_tensor(logits_gather_info) + logits_gather_info_host_fake = mode.from_tensor(logits_gather_info_host) output = torch.ops.auto_deploy.gather_logits_before_lm_head.default( - hidden_states_fake, logits_gather_indices_fake, logits_gather_info_fake + 
hidden_states_fake, logits_gather_indices_fake, logits_gather_info_host_fake ) # The fake implementation returns empty_like which preserves input shape [1, total_tokens, hidden] @@ -217,13 +217,13 @@ def test_transform_generate_format(self, batch_size): # Test forward pass # We must pass the new graph inputs manually since we are running the graph directly logits_gather_indices = torch.arange(batch_size, dtype=torch.long, device="cuda") - logits_gather_info = torch.tensor([batch_size, 0], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([batch_size, 0], dtype=torch.int32, device="cpu") output = gm_transformed( hidden_states, logit_gather_ids, seq_len, logits_gather_indices=logits_gather_indices, - logits_gather_info=logits_gather_info, + logits_gather_info_host=logits_gather_info_host, ) # Output should be [batch_size, 1, vocab_size] since gather now returns 3D assert output.shape == (batch_size, 1, vocab_size) @@ -278,13 +278,13 @@ def test_transform_packed_format(self, total_tokens): # We must pass the new graph inputs manually since we are running the graph directly num_gather = len(logit_gather_ids) logits_gather_indices = logit_gather_ids - logits_gather_info = torch.tensor([num_gather, 1], dtype=torch.int32, device="cuda") + logits_gather_info_host = torch.tensor([num_gather, 1], dtype=torch.int32, device="cpu") output = gm_transformed( hidden_states, logit_gather_ids_padded, seq_len, logits_gather_indices=logits_gather_indices, - logits_gather_info=logits_gather_info, + logits_gather_info_host=logits_gather_info_host, ) # Output should be [1, num_gather, vocab_size] since gather now returns 3D assert output.shape == (1, num_gather, vocab_size) From 80257fd861fc246c335533a612026f994b77c32b Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:03:50 -0800 Subject: [PATCH 2/4] separate prefill/decode in flashinfer Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/flashinfer_attention.py | 222 ++++++++++++------ .../_torch/auto_deploy/shim/ad_executor.py | 2 +- .../test_flashinfer_attention_op.py | 104 +++++++- 3 files changed, 250 insertions(+), 78 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index b8281c2caca..d4dfb20871b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -205,14 +205,16 @@ class _FlashInferPlanner: cached_cuda_graph_decode_wrappers: Dict[ PlanParams, flashinfer.BatchDecodeWithPagedKVCacheWrapper ] - plan_params: Optional[PlanParams] + plan_params_prefill: Optional[PlanParams] + plan_params_decode: Optional[PlanParams] def __init__(self): self.workspace_buffer = None self.prefill_wrapper = None self.decode_wrapper = None self.cached_cuda_graph_decode_wrappers = {} - self.plan_params = None + self.plan_params_prefill = None + self.plan_params_decode = None def _init_decode_wrapper( self, @@ -253,7 +255,8 @@ def init_workspace(self, workspace_buffer: torch.Tensor): self.decode_wrapper = self._init_decode_wrapper() def reset(self) -> None: - self.plan_params = None + self.plan_params_prefill = None + self.plan_params_decode = None def plan_generate_only( self, @@ -279,9 +282,42 @@ def plan_generate_only( sm_scale=plan_params.sm_scale, ) - def plan( + def plan_prefill( + self, + qo_indptr_host: torch.Tensor, + 
kv_page_indptr_host: torch.Tensor, + kv_page_indices: torch.Tensor, + kv_last_page_len_host: torch.Tensor, + kv_lens_arr_host: torch.Tensor, + seq_len_host: torch.Tensor, + plan_params: PlanParams, + ) -> None: + # check for re-planning + if plan_params != self.plan_params_prefill: + # plan prefill + self.prefill_wrapper.plan( + qo_indptr_host, + kv_page_indptr_host, + kv_page_indices, + kv_last_page_len_host, + plan_params.n_heads, # Q heads + plan_params.n_kv_heads, # KV heads + plan_params.head_dim, + plan_params.page_size, + causal=plan_params.causal, + q_data_type=plan_params.q_dtype, + kv_data_type=plan_params.kv_dtype, + sm_scale=plan_params.sm_scale, + # max_token_per_sequence=max(seq_len_host).item(), + seq_lens=kv_lens_arr_host, + ) + self.plan_params_prefill = plan_params + + # return prefill wrapper + return self.prefill_wrapper + + def plan_decode( self, - qo_indptr: torch.Tensor, kv_page_indptr: torch.Tensor, kv_page_indices: torch.Tensor, kv_last_page_len: torch.Tensor, @@ -328,29 +364,12 @@ def _plan_decode( return wrapper # check for re-planning - if plan_params != self.plan_params: - if plan_params.is_generate: - _plan_decode(self.decode_wrapper) - else: - # plan prefill - self.prefill_wrapper.plan( - qo_indptr, - kv_page_indptr, - kv_page_indices, - kv_last_page_len, - plan_params.n_heads, # Q heads - plan_params.n_kv_heads, # KV heads - plan_params.head_dim, - plan_params.page_size, - causal=plan_params.causal, - q_data_type=plan_params.q_dtype, - kv_data_type=plan_params.kv_dtype, - sm_scale=plan_params.sm_scale, - ) - self.plan_params = plan_params + if plan_params != self.plan_params_decode: + _plan_decode(self.decode_wrapper) + self.plan_params_decode = plan_params - # return desired wrapper - return self.decode_wrapper if plan_params.is_generate else self.prefill_wrapper + # return decode wrapper + return self.decode_wrapper _GlobalFlashInferPlanner = _FlashInferPlanner() @@ -412,10 +431,14 @@ def flashinfer_mha_with_cache( v: torch.Tensor, # STANDARD METADATA batch_info_host: torch.Tensor, - cu_seqlen: torch.Tensor, + cu_seqlen_host: torch.Tensor, cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, cache_loc: torch.Tensor, last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + seq_len_host: torch.Tensor, # EXTRA METADATA flashinfer_batch_indices: torch.Tensor, flashinfer_positions: torch.Tensor, @@ -441,30 +464,11 @@ def flashinfer_mha_with_cache( # convert to flashinfer-style metadata num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode - - qo_indptr = cu_seqlen[: num_seq + 1] - paged_kv_indptr = cu_num_pages[: num_seq + 1] - - # NOTE: it is okay to have cache_loc here without truncation. paged_kv_indptr will be - # truncated and will point to the correct sub range of cache_loc. 
- paged_kv_indices = cache_loc - paged_kv_last_page_len = last_page_len[:num_seq] + num_total_tokens = num_prefill_tokens + num_decode n_heads = q.shape[1] n_kv_heads = k.shape[1] - pp = PlanParams( - n_heads=n_heads, - n_kv_heads=n_kv_heads, - head_dim=head_dim, - num_seq=len(qo_indptr) - 1, - is_generate=(s == 1), - page_size=k_cache.shape[1], - q_dtype=q.dtype, - kv_dtype=k_cache.dtype, - sm_scale=scale, - ) - # Assuming k_scale = v_scale = 1.0 k_scale, v_scale = 1.0, 1.0 # k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v @@ -473,28 +477,94 @@ def flashinfer_mha_with_cache( v = v.to(torch.float8_e4m3fn) flashinfer.page.append_paged_kv_cache( - k, - v, - flashinfer_batch_indices, - flashinfer_positions, - (k_cache, v_cache), - paged_kv_indices, - paged_kv_indptr, - paged_kv_last_page_len, + append_key=k, + append_value=v, + batch_indices=flashinfer_batch_indices, + positions=flashinfer_positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=cache_loc, + kv_indptr=cu_num_pages[: num_seq + 1], + kv_last_page_len=last_page_len[:num_seq], ) - # run the flashinfer planner and obtain the correct wrapper - wrapper = _GlobalFlashInferPlanner.plan( - qo_indptr, - paged_kv_indptr, - paged_kv_indices, - paged_kv_last_page_len, - pp, - ) + # check if we need to re-combine outputs + if num_prefill > 0 and num_decode > 0: + y = torch.empty_like(q) + else: + y = None + + # now run split prefill, decode + if num_prefill > 0: + q_prefill = q[:num_prefill_tokens] + + pp_prefill = PlanParams( + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + num_seq=num_prefill, + is_generate=False, + page_size=k_cache.shape[1], + q_dtype=q_prefill.dtype, + kv_dtype=k_cache.dtype, + sm_scale=scale, + ) - y = wrapper.run( - q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl() - ) + wrapper_prefill = _GlobalFlashInferPlanner.plan_prefill( + qo_indptr_host=cu_seqlen_host[: num_prefill + 1], + kv_page_indptr_host=cu_num_pages_host[: num_prefill + 1], + kv_page_indices=cache_loc, + kv_last_page_len_host=last_page_len_host[:num_prefill], + kv_lens_arr_host=seq_len_with_cache_host[:num_prefill], + seq_len_host=seq_len_host[:num_prefill], + plan_params=pp_prefill, + ) + + y_prefill = wrapper_prefill.run( + q_prefill, + (k_cache, v_cache), + k_scale=k_scale, + v_scale=v_scale, + enable_pdl=get_env_enable_pdl(), + ) + if y is not None: + y[:num_prefill_tokens] = y_prefill + else: + y = y_prefill + + if num_decode > 0: + q_decode = q[num_prefill_tokens:num_total_tokens] + + pp_decode = PlanParams( + n_heads=n_heads, + n_kv_heads=n_kv_heads, + head_dim=head_dim, + num_seq=num_decode, + is_generate=True, + page_size=k_cache.shape[1], + q_dtype=q_decode.dtype, + kv_dtype=k_cache.dtype, + sm_scale=scale, + ) + + # run the flashinfer planner and obtain the correct wrapper + wrapper_decode = _GlobalFlashInferPlanner.plan_decode( + kv_page_indptr=cu_num_pages[num_prefill : num_seq + 1], + kv_page_indices=cache_loc, + kv_last_page_len=last_page_len[num_prefill:num_seq], + plan_params=pp_decode, + ) + + y_decode = wrapper_decode.run( + q_decode, + (k_cache, v_cache), + k_scale=k_scale, + v_scale=v_scale, + enable_pdl=get_env_enable_pdl(), + ) + if y is not None: + y[num_prefill_tokens:num_total_tokens] = y_decode + else: + y = y_decode return y.view(q_shape_og) # [b,s,n*h_d] or [b,s, n, h_d] @@ -507,10 +577,14 @@ def flashinfer_mha_with_cache_fake( v: torch.Tensor, # STANDARD METADATA batch_info_host: torch.Tensor, - cu_seqlen: torch.Tensor, + cu_seqlen_host: 
torch.Tensor, cu_num_pages: torch.Tensor, + cu_num_pages_host: torch.Tensor, cache_loc: torch.Tensor, last_page_len: torch.Tensor, + last_page_len_host: torch.Tensor, + seq_len_with_cache_host: torch.Tensor, + seq_len_host: torch.Tensor, # EXTRA METADATA flashinfer_batch_indices: torch.Tensor, flashinfer_positions: torch.Tensor, @@ -559,7 +633,17 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"] + return [ + "batch_info_host", + "cu_seqlen_host", + "cu_num_pages", + "cu_num_pages_host", + "cache_loc", + "last_page_len", + "last_page_len_host", + "seq_len_with_cache_host", + "seq_len_host", + ] @classmethod def get_prepare_extra_metadata_info( diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index cf66991e9f4..ffd39896ac7 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -379,7 +379,7 @@ def _call_func(): # check if we have a dummy request to use if self.padding_dummy_request is None: - ad_logger.error("No CUDA graph padding possible due to missing dummy request.") + ad_logger.info("No CUDA graph padding possible due to missing dummy request.") return _call_func() # pad the scheduled requests with the dummy request diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index 503a780abed..9ff4c2e5d6c 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -64,6 +64,13 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V are computed using GEMM. q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -99,10 +106,14 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, v, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -168,6 +179,13 @@ def test_flashinfer_attention_op_decode( paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V are computed using GEMM. 
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k = torch.ones(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -234,10 +252,14 @@ def test_flashinfer_attention_op_decode( v, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -323,6 +345,13 @@ def test_flashinfer_attention_context_and_generate( paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V for prefill phase q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -358,10 +387,14 @@ def test_flashinfer_attention_context_and_generate( v_1, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -415,6 +448,13 @@ def test_flashinfer_attention_context_and_generate( paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V are computed using GEMM. q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -439,10 +479,14 @@ def test_flashinfer_attention_context_and_generate( v_3, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -519,6 +563,13 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V are computed using GEMM. 
q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -554,10 +605,14 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty v, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -642,6 +697,13 @@ def test_flashinfer_attention_with_fp8_cache( paged_kv_indices = torch.arange(BATCH_SIZE).int().to(device) paged_kv_last_page_len = offsets + seq_len_tensor + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # Q,K,V are computed using GEMM, in fp16 q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) k = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -707,10 +769,14 @@ def test_flashinfer_attention_with_fp8_cache( v, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -787,6 +853,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ) paged_kv_last_page_len = ((offsets + seq_len_tensor - 1) % PAGE_SIZE) + 1 + # Host copies of metadata + qo_indptr_host = qo_indptr.cpu() + paged_kv_indptr_host = paged_kv_indptr.cpu() + paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() + seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() + seq_len_host = seq_len_tensor.cpu() + # make sure planner is initialized workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) _GlobalFlashInferPlanner.init_workspace(workspace) @@ -807,10 +880,14 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de v, # STANDARD METADATA batch_info_host, - qo_indptr, + qo_indptr_host, paged_kv_indptr, + paged_kv_indptr_host, paged_kv_indices, paged_kv_last_page_len, + paged_kv_last_page_len_host, + seq_len_with_cache_host, + seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -875,6 +952,13 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de ) paged_kv_last_page_len2 = ((offsets2 + seq_len_tensor2 - 1) % PAGE_SIZE) + 1 + # Host copies of metadata + qo_indptr2_host = qo_indptr2.cpu() + paged_kv_indptr2_host = paged_kv_indptr2.cpu() + paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu() + seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu() + seq_len2_host = seq_len_tensor2.cpu() + # Create FlashInferAttention class before calling the custom op _GlobalFlashInferPlanner.reset() @@ -894,10 +978,14 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de v_gen, # STANDARD METADATA batch_info_host, - qo_indptr2, + qo_indptr2_host, paged_kv_indptr2, + paged_kv_indptr2_host, paged_kv_indices2, paged_kv_last_page_len2, + paged_kv_last_page_len2_host, + seq_len_with_cache2_host, + seq_len2_host, # EXTRA METADATA batch_indices, positions, From 500515de0d2f6bba9ef6aeaf48a55beb9a912f3e Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein 
<11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Dec 2025 09:52:18 -0800 Subject: [PATCH 3/4] feedback and updates Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/attention_interface.py | 49 +++++++------------ .../custom_ops/flashinfer_attention.py | 48 +++++++++--------- .../auto_deploy/transform/library/kvcache.py | 27 ++++++++-- 3 files changed, 67 insertions(+), 57 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 750c9085674..ea583d84aae 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -10,19 +10,7 @@ """ from abc import ABC, abstractmethod -from typing import ( - Callable, - Dict, - List, - Literal, - Optional, - Protocol, - Sequence, - Set, - Tuple, - Type, - Union, -) +from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union import torch from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -36,6 +24,10 @@ Constant = Union[int, float, str, None] +class PrepareMetadataHostCallable(Protocol): + def __call__(self, **sequence_info_args: torch.Tensor) -> None: ... + + class InputBuffer: """Manages contiguous memory buffers for efficient host-to-device transfers. @@ -388,6 +380,9 @@ class SequenceInfo: - _mask_scatter_indices: [m_0, m_1, ..., m_{s_total-1}] Mask scatter indices used by the overlap scheduler to scatter results back. + NOTE: all tensors are also accessible as host tensors with the suffix "_host". For example, + the tensor "batch_info" is accessible as "batch_info_host" on the host. + ################################################################################################ Here are a couple of notes to emphasize this notation: @@ -526,7 +521,7 @@ def __init__( ############################################################################################ # HOST PREPARE FOR ATTENTION FORWARD ####################################################### - self._host_prepare_functions: set[Callable[[SequenceInfo], None]] = set() + self._host_prepare_functions: List[Tuple[PrepareMetadataHostCallable, List[str]]] = [] # call reset once to set a consistent initial state self.reset() @@ -1043,13 +1038,13 @@ def unnest_sequences(self, t_nested: torch.Tensor) -> List[torch.Tensor]: return list(torch.split(t_squeezed, self.seq_len)) def register_host_prepare_for_attention_forward( - self, host_function: Callable[["SequenceInfo"], None] + self, host_function: PrepareMetadataHostCallable, args: List[str] ): - self._host_prepare_functions.add(host_function) + self._host_prepare_functions.append((host_function, args)) def run_host_prepare_for_attention_forward(self) -> None: - for host_function in self._host_prepare_functions: - host_function(self) + for host_function, args in self._host_prepare_functions: + host_function(**{arg: self._get_arg(arg) for arg in args}) class MHACallable(Protocol): @@ -1061,14 +1056,7 @@ def __call__( class PrepareMetadataCallable(Protocol): def __call__( - self, - position_ids: torch.Tensor, - seq_len: torch.Tensor, - input_pos: torch.Tensor, - cache_loc: torch.Tensor, - pages_per_seq: torch.Tensor, - slot_idx: torch.Tensor, - page_size: int, + self, *sequence_info_args_and_constants: Union[torch.Tensor, Constant] ) -> List[torch.Tensor]: ... 
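For context on how the pieces above fit together: an attention backend's host-prepare function is registered together with the names of the buffers it consumes (PATCH 3 below derives those names from the function signature via inspect in the kvcache transform), and before each forward the runner resolves every name to either the device view or, when the name carries the "_host" suffix, the host view. The following is a minimal, self-contained sketch of that pattern; ToyHostState, the buffer contents, and my_host_prepare are illustrative stand-ins rather than code from this patch, and both the "device" buffers and their host mirrors live on CPU here so the sketch stays runnable.

import inspect
from typing import Callable, Dict, List, Tuple

import torch


class ToyHostState:
    """Toy stand-in for the host-prepare bookkeeping in the sequence interface."""

    def __init__(self) -> None:
        # "device" buffers and their host mirrors (both CPU here to keep the sketch runnable);
        # "<name>_host" resolves to the host copy, plain "<name>" to the device buffer.
        self._buffers: Dict[str, torch.Tensor] = {
            "batch_info": torch.tensor([1, 4, 2], dtype=torch.int32),
            "cu_num_pages": torch.arange(4, dtype=torch.int32),
        }
        self._host: Dict[str, torch.Tensor] = {k: v.clone() for k, v in self._buffers.items()}
        self._host_prepare: List[Tuple[Callable[..., None], List[str]]] = []

    @property
    def available_args(self) -> set:
        return set(self._buffers) | {f"{n}_host" for n in self._buffers}

    def _get_arg(self, name: str) -> torch.Tensor:
        if name.endswith("_host"):
            return self._host[name[: -len("_host")]]
        return self._buffers[name]

    def register_host_prepare(self, fn: Callable[..., None]) -> None:
        # mirror the kvcache transform: derive the required args from the signature
        # and check that the sequence interface actually provides them
        args = list(inspect.signature(fn).parameters.keys())
        missing = set(args) - self.available_args
        assert not missing, f"Missing args: {missing}"
        self._host_prepare.append((fn, args))

    def run_host_prepare(self) -> None:
        # called once per forward, before graph invocation; not cuda-graph captured
        for fn, args in self._host_prepare:
            fn(**{a: self._get_arg(a) for a in args})


def my_host_prepare(batch_info_host: torch.Tensor, cu_num_pages_host: torch.Tensor) -> None:
    # a backend would run its host-side planning here (e.g., a decode-only plan)
    num_prefill, _num_prefill_tokens, num_decode = batch_info_host.tolist()
    if num_prefill == 0:
        _ = cu_num_pages_host[: num_decode + 1]


state = ToyHostState()
state.register_host_prepare(my_host_prepare)
state.run_host_prepare()

In the real interface the host mirrors are kept in pinned CPU memory, which is what allows the planner to consume them with a non-blocking copy to device instead of a blocking copy to host.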
@@ -1229,13 +1217,14 @@ def get_constants(cls, source_attn_node: Node) -> List[Constant]: return [] @classmethod - def host_prepare_for_forward(cls, sequence_info: SequenceInfo): - """Perform host-side preparation for the forward pass for the attention op. + def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]: + """Get function that performs host-side prep for the forward pass for the attention op. This method is responsible for preparing the attention op for the forward pass. - This function is not expected to be graph capturable or compatible with cuda graphs. + This function is not expected to be graph capturable or compatible with cuda graphs. It can + use any argument from the SequenceInfo interface as input argument to its function. """ - return + return None class AttentionRegistry: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index d4dfb20871b..c8c2f835156 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -21,6 +21,7 @@ Constant, MHACallable, PrepareMetadataCallable, + PrepareMetadataHostCallable, SequenceInfo, ) @@ -183,7 +184,6 @@ class PlanParams: n_kv_heads: int head_dim: int num_seq: int - is_generate: bool page_size: int q_dtype: torch.dtype kv_dtype: torch.dtype @@ -289,12 +289,17 @@ def plan_prefill( kv_page_indices: torch.Tensor, kv_last_page_len_host: torch.Tensor, kv_lens_arr_host: torch.Tensor, - seq_len_host: torch.Tensor, plan_params: PlanParams, ) -> None: # check for re-planning if plan_params != self.plan_params_prefill: # plan prefill + # NOTE (lucaslie): we use host versions here. the plan actually needs both (host+device) + # version. Unfortunately, there is no good way to access the plan API and provide both + # although we have both available. I have decided to use the host versions here to + # ensure non-blocking invocation of plan, whereas the other way around would trigger a + # blocking copy to cpu. This way we trigger a non-blocking copy to device (note that + # this is safe since we do have pinned CPU memory for all our host-side arguments). self.prefill_wrapper.plan( qo_indptr_host, kv_page_indptr_host, @@ -308,7 +313,6 @@ def plan_prefill( q_data_type=plan_params.q_dtype, kv_data_type=plan_params.kv_dtype, sm_scale=plan_params.sm_scale, - # max_token_per_sequence=max(seq_len_host).item(), seq_lens=kv_lens_arr_host, ) self.plan_params_prefill = plan_params @@ -359,7 +363,6 @@ def _plan_decode( _plan_decode(self.cached_cuda_graph_decode_wrappers[plan_params]) # check if we are in cuda graph capture and just return the pre-cached decode wrapper if torch.cuda.is_current_stream_capturing() or cuda_graph_state.in_warm_up(): - assert plan_params.is_generate, "Only generate is supported during cuda graph capture." 
wrapper = self.cached_cuda_graph_decode_wrappers[plan_params] return wrapper @@ -423,6 +426,23 @@ def prepare_flashinfer_metadata_fake( ) +def prepare_flashinfer_metadata_host( + batch_info_host: torch.Tensor, + cu_num_pages_host: torch.Tensor, + cache_loc_host: torch.Tensor, + last_page_len_host: torch.Tensor, +) -> None: + num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() + + if num_prefill == 0: + _GlobalFlashInferPlanner.plan_generate_only( + num_decode, + cu_num_pages_host[: num_decode + 1], + cache_loc_host, + last_page_len_host[:num_decode], + ) + + @torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=()) def flashinfer_mha_with_cache( # Q, K, V @@ -438,7 +458,6 @@ def flashinfer_mha_with_cache( last_page_len: torch.Tensor, last_page_len_host: torch.Tensor, seq_len_with_cache_host: torch.Tensor, - seq_len_host: torch.Tensor, # EXTRA METADATA flashinfer_batch_indices: torch.Tensor, flashinfer_positions: torch.Tensor, @@ -502,7 +521,6 @@ def flashinfer_mha_with_cache( n_kv_heads=n_kv_heads, head_dim=head_dim, num_seq=num_prefill, - is_generate=False, page_size=k_cache.shape[1], q_dtype=q_prefill.dtype, kv_dtype=k_cache.dtype, @@ -515,7 +533,6 @@ def flashinfer_mha_with_cache( kv_page_indices=cache_loc, kv_last_page_len_host=last_page_len_host[:num_prefill], kv_lens_arr_host=seq_len_with_cache_host[:num_prefill], - seq_len_host=seq_len_host[:num_prefill], plan_params=pp_prefill, ) @@ -539,7 +556,6 @@ def flashinfer_mha_with_cache( n_kv_heads=n_kv_heads, head_dim=head_dim, num_seq=num_decode, - is_generate=True, page_size=k_cache.shape[1], q_dtype=q_decode.dtype, kv_dtype=k_cache.dtype, @@ -584,7 +600,6 @@ def flashinfer_mha_with_cache_fake( last_page_len: torch.Tensor, last_page_len_host: torch.Tensor, seq_len_with_cache_host: torch.Tensor, - seq_len_host: torch.Tensor, # EXTRA METADATA flashinfer_batch_indices: torch.Tensor, flashinfer_positions: torch.Tensor, @@ -642,7 +657,6 @@ def get_standard_metadata_args(cls) -> List[str]: "last_page_len", "last_page_len_host", "seq_len_with_cache_host", - "seq_len_host", ] @classmethod @@ -684,18 +698,8 @@ def _init_workspace(si: SequenceInfo) -> torch.Tensor: return {"workspace_buffer": _init_workspace} @classmethod - def host_prepare_for_forward(cls, sequence_info: SequenceInfo): - batch_info = sequence_info._input_buffer.get_host_view("batch_info") - num_prefill, num_prefill_tokens, num_decode = batch_info.tolist() - # Call plan for generate-only batches. 
- if num_prefill == 0: - _GlobalFlashInferPlanner.plan_generate_only( - num_decode, - sequence_info._input_buffer.get_host_view("cu_num_pages")[: num_decode + 1], - sequence_info._input_buffer.get_host_view("cache_loc"), - sequence_info._input_buffer.get_host_view("last_page_len")[:num_decode], - ) - return + def get_host_prepare_metadata_function(cls) -> Optional[PrepareMetadataHostCallable]: + return prepare_flashinfer_metadata_host @classmethod def get_constants(cls, source_attn_node: Node) -> List[Constant]: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py index 60cd26f8778..1207bda245a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py @@ -1,5 +1,6 @@ """Graph transformation to automatically add kv cache into fused MHA op.""" +import inspect import operator from typing import Dict, List, Optional, Tuple, Type @@ -106,6 +107,23 @@ def _process_metadata_extra( gm, prep_meta_op, inputs_for_prep_meta, const_args, num_meta_out ) + def _process_metadata_host(self, cm: CachedSequenceInterface): + """Process the host-side prepare metadata function.""" + prep_meta_host_op = self.attn_descriptor.get_host_prepare_metadata_function() + if prep_meta_host_op is None: + return + + # analyze the args of the host-side prepare metadata function using inspect + sig = inspect.signature(prep_meta_host_op) + args = sig.parameters.keys() + + # check if all args are available in the cached sequence interface + unavailable_args = args - cm.info.available_args + assert not unavailable_args, f"Missing args in SequenceInfo: {unavailable_args=}" + + # add the host-side prepare metadata function to the graph + cm.info.register_host_prepare_for_attention_forward(prep_meta_host_op, list(args)) + def _process_cache_node(self, gm: GraphModule, cache_name: str) -> Node: """Process the cache nodes by inserting a cached attention replacement op.""" return add_graph_input(gm, cache_name) @@ -173,6 +191,9 @@ def _apply( # insert metadata computation and extract each argument as a node meta_nodes_extra = self._process_metadata_extra(gm, cm, source_attn_nodes[0]) + # Register host-side prepare_metadata function for attention descriptor. + self._process_metadata_host(cm) + buffer_in_lookup: Dict[str, Node] = {} # replace fused attention node with attention node that has kv cache @@ -213,11 +234,7 @@ def _apply( buffer_in_nodes, constants, ) - # Attention descriptor should register its host function with SequenceInfo. - # This function will be called before graph invocation. 
- cm.info.register_host_prepare_for_attention_forward( - attn_descriptor.host_prepare_for_forward - ) + num_cached_attn_replacements += 1 info = TransformInfo( From 0dc78ead98302ee91de568ac0e4b6976dad14807 Mon Sep 17 00:00:00 2001 From: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> Date: Tue, 30 Dec 2025 10:17:56 -0800 Subject: [PATCH 4/4] update unit test Signed-off-by: Lucas Liebenwein <11156568+lucaslie@users.noreply.github.com> --- .../custom_ops/test_flashinfer_attention_op.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py index 9ff4c2e5d6c..e27477112fa 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py @@ -69,7 +69,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V are computed using GEMM. q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -113,7 +112,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype, paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -184,7 +182,6 @@ def test_flashinfer_attention_op_decode( paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V are computed using GEMM. q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -259,7 +256,6 @@ def test_flashinfer_attention_op_decode( paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -350,7 +346,6 @@ def test_flashinfer_attention_context_and_generate( paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V for prefill phase q_1 = torch.randn(BATCH_SIZE, PREFILL_SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -394,7 +389,6 @@ def test_flashinfer_attention_context_and_generate( paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -453,7 +447,6 @@ def test_flashinfer_attention_context_and_generate( paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V are computed using GEMM. 
q_3 = torch.randn(BATCH_SIZE, 1, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -486,7 +479,6 @@ def test_flashinfer_attention_context_and_generate( paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -568,7 +560,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V are computed using GEMM. q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -612,7 +603,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -702,7 +692,6 @@ def test_flashinfer_attention_with_fp8_cache( paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # Q,K,V are computed using GEMM, in fp16 q = torch.randn(BATCH_SIZE, SEQ_LEN, N_HEADS * D_HEAD, dtype=DTYPE).to(device) @@ -776,7 +765,6 @@ def test_flashinfer_attention_with_fp8_cache( paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -858,7 +846,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de paged_kv_indptr_host = paged_kv_indptr.cpu() paged_kv_last_page_len_host = paged_kv_last_page_len.cpu() seq_len_with_cache_host = (offsets + seq_len_tensor).cpu() - seq_len_host = seq_len_tensor.cpu() # make sure planner is initialized workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device=device) @@ -887,7 +874,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de paged_kv_last_page_len, paged_kv_last_page_len_host, seq_len_with_cache_host, - seq_len_host, # EXTRA METADATA batch_indices, positions, @@ -957,7 +943,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de paged_kv_indptr2_host = paged_kv_indptr2.cpu() paged_kv_last_page_len2_host = paged_kv_last_page_len2.cpu() seq_len_with_cache2_host = (offsets2 + seq_len_tensor2).cpu() - seq_len2_host = seq_len_tensor2.cpu() # Create FlashInferAttention class before calling the custom op _GlobalFlashInferPlanner.reset() @@ -985,7 +970,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de paged_kv_last_page_len2, paged_kv_last_page_len2_host, seq_len_with_cache2_host, - seq_len2_host, # EXTRA METADATA batch_indices, positions,
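For reference, a sketch of the metadata a caller assembles after these changes: each device-side planner input is mirrored by a CPU copy carrying the "_host" suffix, and the separate seq_len_host argument is gone, so only the with-cache sequence lengths remain on the host. The tensor names below mirror the test code above, while PAGE_SIZE and the concrete sequence lengths are made-up values chosen only for illustration.

import torch

PAGE_SIZE = 16
device = "cuda" if torch.cuda.is_available() else "cpu"

# per-sequence token counts: new tokens this step and tokens already in the cache
seq_len_tensor = torch.tensor([3, 5], dtype=torch.int32, device=device)
offsets = torch.tensor([0, 4], dtype=torch.int32, device=device)

# device-side metadata (consumed by the attention kernel and the KV-cache append)
qo_indptr = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device=device),
        torch.cumsum(seq_len_tensor, 0, dtype=torch.int32),
    ]
)
pages_per_seq = (offsets + seq_len_tensor + PAGE_SIZE - 1) // PAGE_SIZE
paged_kv_indptr = torch.cat(
    [
        torch.zeros(1, dtype=torch.int32, device=device),
        torch.cumsum(pages_per_seq, 0, dtype=torch.int32),
    ]
)
paged_kv_last_page_len = ((offsets + seq_len_tensor - 1) % PAGE_SIZE) + 1

# host-side copies used by the planner (pinned memory in the real runtime)
qo_indptr_host = qo_indptr.cpu()
paged_kv_indptr_host = paged_kv_indptr.cpu()
paged_kv_last_page_len_host = paged_kv_last_page_len.cpu()
seq_len_with_cache_host = (offsets + seq_len_tensor).cpu()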