Commit f8ff684

attention i/f providing device and host argument
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent aaa87ab commit f8ff684

23 files changed: +149 −204 lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 10 additions & 72 deletions
@@ -496,16 +496,17 @@ def __init__(
         # Create the InputBuffer that manages contiguous host and device memory
         # Starts on default device; use to() to move to target device
         self._input_buffer = InputBuffer(tensor_specs)
+        self._available_args = set(self._input_buffer.tensor_names) | {
+            f"{name}_host" for name in self._input_buffer.tensor_names
+        }

         # Initialize args_list from tensor specs
         self._args_list: Dict[str, List[int]] = {
             name: [0] * numel for name, numel, _ in tensor_specs
         }

         self._active_args = ("input_ids", "position_ids")
-        self._shapeable_args = ("input_ids", "position_ids")
-        # Args that should be returned from host (pinned memory) instead of device in _named_args
-        self._host_return_args = ("batch_info", "logits_gather_info")
+        self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
         ############################################################################################

         # EXTRA TENSOR FIELDS ######################################################################
@@ -543,14 +544,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:

     def _get_arg(self, name: str) -> torch.Tensor:
         """Get the argument from the input buffer either on device or host."""
-        if name in self._host_return_args:
-            arg = self._input_buffer.get_host_view(name)
+        if name.endswith("_host"):
+            arg = self._input_buffer.get_host_view(name.replace("_host", ""))
         else:
             arg = self._input_buffer.get_view(name)
         return self._shape_for_forward(arg) if name in self._shapeable_args else arg

     def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
-        # Build args dict, using host views for _host_return_args, device views otherwise
         args = {k: self._get_arg(k) for k in self._active_args}

         # check other args to include
@@ -562,7 +562,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor
     @property
     def available_args(self) -> Set[str]:
         """Return a list of available arguments."""
-        return set(self._input_buffer.tensor_names)
+        return self._available_args

     @property
     def named_args(self) -> Dict[str, torch.Tensor]:
@@ -682,68 +682,6 @@ def _get_cache_locations_and_pages_per_sequence(
         pages_per_seq = [len(p) for p in page_assignments]
         return cache_loc_flat, pages_per_seq

-    # TODO: remove after updating all cached backends
-    @classmethod
-    def _get_sanitized_seq_len(
-        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> torch.Tensor:
-        """Sanitize sequence lengths.
-
-        We want to cover the following scenarios with this function:
-
-        1. Pre-fill:
-            input_ids: [1, s_total, ...]
-            seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
-            ---> returns [s_0, s_1, ..., s_{b-1}]
-        2. Decode:
-            input_ids: [b, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-            --> returns [1,] * b
-        3. Decode in Cudagraph:
-            input_ids: [b_cudagraph, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-
-            --> returns [1,] * b_cudagraph
-        Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
-        b_cudagraph.
-
-        # TODO: I could see one possible issue with this approach in the future.
-        # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
-        # information. What could happen is that for the padded sequences the cache location
-        # tensors point to allocated pages. This could lead to a situation where we write into
-        # allocated cache pages polluting the cache of other sequences. Now this is not an issue
-        # if we write the dummy sequences into unallocated cache pages... One fix could be to
-        # pad not only the seq len but also pad the cache locations by just repeating the last
-        # valid cache location in the batch. This would ensure that the dummy sequences just
-        # repeat valid computation...
-        """
-        _, s = input_or_position_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
-        if s > 1:
-            return seq_len[:num_seq].clone()
-        else:
-            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
-
-    @staticmethod
-    def _get_sanitized_num_sequences(
-        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> int:
-        """Get number of sequences.
-
-        We make sure that this function is compatible with both torch graph capture and cudagraph.
-        Both can be a bit temperamental when trying to extract the number of sequences from a tensor
-        with max_batch_size or max_batch_size*max_seq_len.
-        """
-        b, s = input_or_position_ids.shape[:2]
-        if s > 1:
-            num_seq = torch.sum(seq_len > 0)
-            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
-        else:
-            num_seq = b
-        return num_seq
-
     def activate_arg(self, arg_name: str) -> bool:
         """Activate a desired argument.

@@ -854,7 +792,7 @@ def _store_arg(
             self._args_list[name] = tnsr_like.copy()

         # Only store to buffer when the argument is active or force_copy is True
-        if not (name in self._active_args or force_copy):
+        if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
             return

         # Store to the InputBuffer's pinned host memory
@@ -1075,12 +1013,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor):
     def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
         """Maybe gather the logits if logits have not been gathered yet."""
         num_tokens = logits.shape[0] * logits.shape[1]
-        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
+        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
        if gather_required and num_tokens_to_gather < num_tokens:
             logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
                 logits,
                 self._get_arg("logits_gather_indices"),
-                self._get_arg("logits_gather_info"),
+                self._get_arg("logits_gather_info_host"),
             )
         return logits.squeeze(int(self.is_generate))

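To make the new convention concrete, here is a minimal, self-contained sketch of the naming scheme introduced above. The _ToyInputBuffer class and the standalone get_arg helper are illustrative stand-ins, not code from this commit; only the suffix rule mirrors the updated _get_arg and available_args: a trailing "_host" selects the pinned host view of a buffer tensor, while the plain name returns the device view, and both names are advertised.

import torch


class _ToyInputBuffer:
    """Hypothetical stand-in for InputBuffer: a pinned host tensor per name plus a device copy."""

    def __init__(self, tensor_specs):
        use_cuda = torch.cuda.is_available()
        self._host = {
            name: torch.zeros(numel, dtype=dtype, pin_memory=use_cuda)
            for name, numel, dtype in tensor_specs
        }
        device = "cuda" if use_cuda else "cpu"
        self._device = {name: t.to(device) for name, t in self._host.items()}

    @property
    def tensor_names(self):
        return list(self._host)

    def get_host_view(self, name):
        return self._host[name]

    def get_view(self, name):
        return self._device[name]


def get_arg(buffer: _ToyInputBuffer, name: str) -> torch.Tensor:
    # Mirrors the updated _get_arg: "_host" suffix -> pinned host view, otherwise device view.
    if name.endswith("_host"):
        return buffer.get_host_view(name.replace("_host", ""))
    return buffer.get_view(name)


buf = _ToyInputBuffer([("batch_info", 3, torch.int32), ("input_ids", 8, torch.long)])
available_args = set(buf.tensor_names) | {f"{n}_host" for n in buf.tensor_names}
print(sorted(available_args))  # both plain and "_host" names are advertised
print(get_arg(buf, "batch_info").device, get_arg(buf, "batch_info_host").device)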

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ def fla_cached_delta_rule(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -58,7 +58,7 @@ def fla_cached_delta_rule(
     y = torch.empty_like(v, memory_format=torch.contiguous_format)
     y_flat = y.view(b * s, num_heads, -1)

-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode

     # clean up metadata
@@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -160,7 +160,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(
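The reason backends such as this one now take batch_info_host is that they immediately unpack the values into Python ints via .tolist(). A rough sketch of that cost difference follows, under the assumption that batch_info holds (num_prefill, num_prefill_tokens, num_decode); it is illustrative only and not part of the diff.

import torch

# .tolist() on a pinned CPU tensor is a plain host memory read...
use_cuda = torch.cuda.is_available()
batch_info_host = torch.tensor([2, 10, 3], dtype=torch.int32, pin_memory=use_cuda)
num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()  # no GPU sync

# ...whereas the same call on a CUDA tensor forces a device-to-host copy and
# synchronizes the CPU with the GPU stream before the Python ints are available.
if use_cuda:
    batch_info_device = batch_info_host.to("cuda", non_blocking=True)
    num_prefill, num_prefill_tokens, num_decode = batch_info_device.tolist()  # implicit sync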

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 7 additions & 7 deletions
@@ -157,7 +157,7 @@ def _plan_decode(wrapper: flashinfer.BatchDecodeWithPagedKVCacheWrapper):
 @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
 def prepare_flashinfer_metadata(
     position_ids: torch.Tensor,
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     seq_len_with_cache: torch.Tensor,
 ) -> List[torch.Tensor]:
@@ -171,7 +171,7 @@ def prepare_flashinfer_metadata(
     _GlobalFlashInferPlanner.reset()

     # retrieve host-side metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     num_tokens = num_prefill_tokens + num_decode

@@ -192,7 +192,7 @@
 @prepare_flashinfer_metadata.register_fake
 def prepare_flashinfer_metadata_fake(
     position_ids: torch.Tensor,
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     seq_len_with_cache: torch.Tensor,
 ):
@@ -210,7 +210,7 @@ def flashinfer_mha_with_cache(
     k: torch.Tensor,
     v: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     cu_num_pages: torch.Tensor,
     cache_loc: torch.Tensor,
@@ -238,7 +238,7 @@ def flashinfer_mha_with_cache(
     v = v.reshape(b * s, -1, head_dim)

     # convert to flashinfer-style metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode

     qo_indptr = cu_seqlen[: num_seq + 1]
@@ -305,7 +305,7 @@ def flashinfer_mha_with_cache_fake(
     k: torch.Tensor,
     v: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     cu_num_pages: torch.Tensor,
     cache_loc: torch.Tensor,
@@ -358,7 +358,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]

     @classmethod
     def get_prepare_extra_metadata_info(
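Putting the two sides together: a backend declares which standard metadata names it wants, now including "_host" variants, and the interface can check that request against available_args, which advertises both the device and the host name for every buffer tensor. The validate_standard_metadata helper below is hypothetical glue code for illustration; only the requested names and the available_args behavior come from this commit.

from typing import List, Set


def validate_standard_metadata(requested: List[str], available_args: Set[str]) -> None:
    # Hypothetical check: every requested metadata name must be advertised by the interface.
    missing = [name for name in requested if name not in available_args]
    if missing:
        raise ValueError(f"Backend requested unavailable metadata args: {missing}")


# The FlashInfer backend's declaration from this commit:
requested = ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]

# available_args now contains a "<name>_host" entry for every buffer tensor (subset shown):
available = {
    "batch_info", "batch_info_host",
    "cu_seqlen", "cu_seqlen_host",
    "cu_num_pages", "cu_num_pages_host",
    "cache_loc", "cache_loc_host",
    "last_page_len", "last_page_len_host",
}
validate_standard_metadata(requested, available)  # passes: all names are advertised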

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 4 additions & 4 deletions
@@ -53,7 +53,7 @@ def _cuda_cached_causal_conv1d(
     weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -80,7 +80,7 @@ def _cuda_cached_causal_conv1d(
     """
     b, s = input.shape[:2]

-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     num_total_tokens = num_prefill_tokens + num_decode

@@ -138,7 +138,7 @@ def _cuda_cached_causal_conv1d_fake(
     weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -189,7 +189,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py

Lines changed: 4 additions & 4 deletions
@@ -147,7 +147,7 @@ def _torch_cached_causal_conv1d(
     weight: torch.Tensor,  # [c_out, c_in/groups, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -174,7 +174,7 @@ def _torch_cached_causal_conv1d(
     num_seq = seq_len.shape[0]

     # get cleaned up metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     seq_len = seq_len[:num_seq]
     seq_start = cu_seqlen[:num_seq]
@@ -247,7 +247,7 @@ def _torch_cached_causal_conv1d_fake(
     weight: torch.Tensor,  # [c_out, c_in/groups, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -296,7 +296,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py

Lines changed: 4 additions & 4 deletions
@@ -121,7 +121,7 @@ def _torch_cached_ssm(
     dt: torch.Tensor,  # [b, s, num_heads]
     dt_bias: torch.Tensor,  # [num_heads]
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -145,7 +145,7 @@ def _torch_cached_ssm(
     num_seq = seq_len.shape[0]

     # get cleaned up metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     seq_len = seq_len[:num_seq]
     seq_start = cu_seqlen[:num_seq]
@@ -246,7 +246,7 @@ def _torch_cached_ssm_fake(
     dt: torch.Tensor,  # [b, s, num_heads]
     dt_bias: torch.Tensor,  # [num_heads]
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -293,7 +293,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(
