
Commit 16e148d

attention i/f providing device and host argument
Signed-off-by: Lucas Liebenwein <[email protected]>
1 parent: 74832a1

23 files changed (+149 −204 lines)

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 10 additions & 72 deletions
@@ -508,16 +508,17 @@ def __init__(
         # Create the InputBuffer that manages contiguous host and device memory
         # Starts on default device; use to() to move to target device
         self._input_buffer = InputBuffer(tensor_specs)
+        self._available_args = set(self._input_buffer.tensor_names) | {
+            f"{name}_host" for name in self._input_buffer.tensor_names
+        }

         # Initialize args_list from tensor specs
         self._args_list: Dict[str, List[int]] = {
             name: [0] * numel for name, numel, _ in tensor_specs
         }

         self._active_args = ("input_ids", "position_ids")
-        self._shapeable_args = ("input_ids", "position_ids")
-        # Args that should be returned from host (pinned memory) instead of device in _named_args
-        self._host_return_args = ("batch_info", "logits_gather_info")
+        self._shapeable_args = ("input_ids", "position_ids", "input_ids_host", "position_ids_host")
         ############################################################################################

         # EXTRA TENSOR FIELDS ######################################################################
@@ -558,14 +559,13 @@ def _shape_for_forward(self, tnsr: torch.Tensor) -> torch.Tensor:

     def _get_arg(self, name: str) -> torch.Tensor:
         """Get the argument from the input buffer either on device or host."""
-        if name in self._host_return_args:
-            arg = self._input_buffer.get_host_view(name)
+        if name.endswith("_host"):
+            arg = self._input_buffer.get_host_view(name.replace("_host", ""))
         else:
             arg = self._input_buffer.get_view(name)
         return self._shape_for_forward(arg) if name in self._shapeable_args else arg

     def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor]:
-        # Build args dict, using host views for _host_return_args, device views otherwise
         args = {k: self._get_arg(k) for k in self._active_args}

         # check other args to include
@@ -577,7 +577,7 @@ def _named_args(self, include_extra_args: bool = True) -> Dict[str, torch.Tensor
     @property
     def available_args(self) -> Set[str]:
         """Return a list of available arguments."""
-        return set(self._input_buffer.tensor_names)
+        return self._available_args

     @property
     def named_args(self) -> Dict[str, torch.Tensor]:
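
The three hunks above replace the hard-coded _host_return_args tuple with a naming convention: every buffered tensor is addressable both as "<name>" (device view) and "<name>_host" (pinned host view). Below is a minimal, self-contained sketch of that lookup. _ToyInputBuffer and the free function get_arg are hypothetical stand-ins for the real InputBuffer/SequenceInfo machinery; only the get_view/get_host_view calls and the suffix handling mirror the diff.

from typing import Dict

import torch


class _ToyInputBuffer:
    """Hypothetical stand-in: one host tensor plus one device copy per registered name."""

    def __init__(self, specs: Dict[str, int], device: str = "cpu"):
        # The real InputBuffer keeps the host side in pinned memory and mirrors it on the device.
        self._host = {name: torch.zeros(numel, dtype=torch.int32) for name, numel in specs.items()}
        self._device = {name: t.to(device, copy=True) for name, t in self._host.items()}
        self.tensor_names = list(specs)

    def get_view(self, name: str) -> torch.Tensor:
        return self._device[name]

    def get_host_view(self, name: str) -> torch.Tensor:
        return self._host[name]


def get_arg(buffer: _ToyInputBuffer, name: str) -> torch.Tensor:
    # "<name>_host" resolves to the host view, plain "<name>" to the device view.
    if name.endswith("_host"):
        return buffer.get_host_view(name.replace("_host", ""))
    return buffer.get_view(name)


buf = _ToyInputBuffer({"batch_info": 3, "input_ids": 8})
available = set(buf.tensor_names) | {f"{name}_host" for name in buf.tensor_names}
# -> {"batch_info", "batch_info_host", "input_ids", "input_ids_host"}
host_view = get_arg(buf, "batch_info_host")  # backed by (pinned) host memory
device_view = get_arg(buf, "batch_info")     # backed by device memory
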
@@ -697,68 +697,6 @@ def _get_cache_locations_and_pages_per_sequence(
         pages_per_seq = [len(p) for p in page_assignments]
         return cache_loc_flat, pages_per_seq

-    # TODO: remove after updating all cached backends
-    @classmethod
-    def _get_sanitized_seq_len(
-        cls, input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> torch.Tensor:
-        """Sanitize sequence lengths.
-
-        We want to cover the following scenarios with this function:
-
-        1. Pre-fill:
-            input_ids: [1, s_total, ...]
-            seq_len: [s_0, s_1, ..., s_{b-1}, 0, 0, ..., 0]
-            ---> returns [s_0, s_1, ..., s_{b-1}]
-        2. Decode:
-            input_ids: [b, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-            --> returns [1,] * b
-        3. Decode in Cudagraph:
-            input_ids: [b_cudagraph, 1, ...]
-            seq_len: [1, 1, ..., 1, 0, 0, ..., ..., ..., ..., 0]
-                     |---- b ----|--- (max_batch_size - b) ---|
-
-            --> returns [1,] * b_cudagraph
-        Here b <= b_cudagraph. We want to make sure that the seq_len is one-padded to
-        b_cudagraph.
-
-        # TODO: I could see one possible issue with this approach in the future.
-        # If we have b < b_cudagraph we now one-pad. However, we don't pad the cache location
-        # information. What could happen is that the for the padded sequences the cache location
-        # tensors point to allocated pages. This could lead to a situation where we write into
-        # allocated cache pages polluting the cache of other sequences. Now this is not an issue
-        # if we write the dummy sequences into unallocated cache pages... One fix could be to
-        # pad not only the seq len but also pad the cache locations by just repeating the last
-        # valid cache location in the batch. This would ensure that the dummy sequences just
-        # repeats valid computation...
-        """
-        _, s = input_or_position_ids.shape[:2]
-        num_seq = cls._get_sanitized_num_sequences(input_or_position_ids, seq_len)
-        if s > 1:
-            return seq_len[:num_seq].clone()
-        else:
-            return torch.ones(num_seq, dtype=seq_len.dtype, device=seq_len.device)
-
-    @staticmethod
-    def _get_sanitized_num_sequences(
-        input_or_position_ids: torch.Tensor, seq_len: torch.Tensor
-    ) -> int:
-        """Get number of sequences.
-
-        We makes sure that this function is compatible with both torch graph capture and cudagraph.
-        Both can be a bit temparamental when trying to extract the number of sequences from a tensor
-        with max_batch_size or max_batch_size*max_seq_len.
-        """
-        b, s = input_or_position_ids.shape[:2]
-        if s > 1:
-            num_seq = torch.sum(seq_len > 0)
-            assert seq_len[num_seq:].sum() == 0, "seq_len should be zero-padded"
-        else:
-            num_seq = b
-        return num_seq
-
     def activate_arg(self, arg_name: str) -> bool:
         """Activate a desired argument.

@@ -869,7 +807,7 @@ def _store_arg(
         self._args_list[name] = tnsr_like.copy()

         # Only store to buffer when the argument is active or force_copy is True
-        if not (name in self._active_args or force_copy):
+        if not (name in self._active_args or f"{name}_host" in self._active_args or force_copy):
             return

         # Store to the InputBuffer's pinned host memory
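
The condition above now treats activation of either view as a reason to populate the pinned buffer. A tiny illustration of just that check; the helper should_store and the argument names are hypothetical.

def should_store(name: str, active_args: set, force_copy: bool = False) -> bool:
    # Copy into the pinned host buffer if the device view, its "_host" twin,
    # or an explicit force_copy requests it.
    return name in active_args or f"{name}_host" in active_args or force_copy


active = {"input_ids", "position_ids", "batch_info_host"}
assert should_store("batch_info", active)                 # only the host view is active
assert not should_store("seq_len", active)                # neither view is active
assert should_store("seq_len", active, force_copy=True)   # explicit override
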
@@ -1090,12 +1028,12 @@ def rescatter_input_ids(self, ungathered_input_ids: torch.Tensor):
     def maybe_gather_and_squeeze_logits(self, logits: torch.Tensor) -> torch.Tensor:
         """Maybe gather the logits if logits have not been gathered yet."""
         num_tokens = logits.shape[0] * logits.shape[1]
-        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info").tolist()
+        num_tokens_to_gather, gather_required = self._get_arg("logits_gather_info_host").tolist()
         if gather_required and num_tokens_to_gather < num_tokens:
             logits = torch.ops.auto_deploy.gather_logits_before_lm_head(
                 logits,
                 self._get_arg("logits_gather_indices"),
-                self._get_arg("logits_gather_info"),
+                self._get_arg("logits_gather_info_host"),
             )
         return logits.squeeze(int(self.is_generate))
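
In the logits path, logits_gather_info_host is deliberately read from pinned host memory so the .tolist() call does not stall on a device-to-host copy before the gather decision is made. A hedged sketch of that control flow, using a plain index_select in place of the auto_deploy.gather_logits_before_lm_head op; the index semantics here are an assumption for illustration only.

import torch


def maybe_gather(logits: torch.Tensor, gather_indices: torch.Tensor,
                 gather_info_host: torch.Tensor) -> torch.Tensor:
    # gather_info_host = [num_tokens_to_gather, gather_required], kept on the host.
    num_tokens = logits.shape[0] * logits.shape[1]
    num_tokens_to_gather, gather_required = gather_info_host.tolist()
    if gather_required and num_tokens_to_gather < num_tokens:
        flat = logits.flatten(0, 1)  # [num_tokens, vocab]
        logits = flat.index_select(0, gather_indices[:num_tokens_to_gather]).unsqueeze(0)
    return logits


logits = torch.randn(1, 6, 128)
gather_indices = torch.tensor([1, 4])                       # keep only two token positions
gather_info_host = torch.tensor([2, 1], dtype=torch.int32)  # gather 2 tokens, gather required
out = maybe_gather(logits, gather_indices, gather_info_host)  # shape [1, 2, 128]
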

tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py

Lines changed: 4 additions & 4 deletions
@@ -35,7 +35,7 @@ def fla_cached_delta_rule(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -58,7 +58,7 @@ def fla_cached_delta_rule(
     y = torch.empty_like(v, memory_format=torch.contiguous_format)
     y_flat = y.view(b * s, num_heads, -1)

-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode

     # clean up metadata
@@ -120,7 +120,7 @@ def fla_cached_delta_rule_fake(
     v: torch.Tensor,
     beta: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -160,7 +160,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(
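
The same three-integer unpack shows up in every backend touched by this commit: batch_info_host carries [num_prefill, num_prefill_tokens, num_decode] on the host, so sizing the prefill and decode slices never requires a GPU sync. A short sketch follows, under the assumption (consistent with these diffs but not spelled out in them) that the flattened token batch lays out all prefill tokens first, followed by one token per decode sequence; the helper name is hypothetical.

import torch


def split_prefill_decode(tokens_flat: torch.Tensor, batch_info_host: torch.Tensor):
    # batch_info_host is a small host-side tensor: [num_prefill, num_prefill_tokens, num_decode].
    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
    prefill_tokens = tokens_flat[:num_prefill_tokens]
    decode_tokens = tokens_flat[num_prefill_tokens:num_prefill_tokens + num_decode]
    return prefill_tokens, decode_tokens


# e.g. 2 prefill sequences with 5 tokens in total, plus 3 decode sequences (1 token each)
batch_info_host = torch.tensor([2, 5, 3], dtype=torch.int32)  # pinned in the real runtime
tokens = torch.arange(8)
prefill, decode = split_prefill_decode(tokens, batch_info_host)
assert prefill.numel() == 5 and decode.numel() == 3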

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 7 additions & 7 deletions
@@ -359,7 +359,7 @@ def _plan_decode(
 @torch.library.custom_op("auto_deploy::flashinfer_attention_prepare_metadata", mutates_args=())
 def prepare_flashinfer_metadata(
     position_ids: torch.Tensor,
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     seq_len_with_cache: torch.Tensor,
 ) -> List[torch.Tensor]:
@@ -370,7 +370,7 @@ def prepare_flashinfer_metadata(
     to understand the convention.
     """
     # retrieve host-side metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     num_tokens = num_prefill_tokens + num_decode

@@ -393,7 +393,7 @@
 @prepare_flashinfer_metadata.register_fake
 def prepare_flashinfer_metadata_fake(
     position_ids: torch.Tensor,
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     seq_len_with_cache: torch.Tensor,
 ):
@@ -411,7 +411,7 @@ def flashinfer_mha_with_cache(
     k: torch.Tensor,
     v: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     cu_num_pages: torch.Tensor,
     cache_loc: torch.Tensor,
@@ -439,7 +439,7 @@ def flashinfer_mha_with_cache(
     v = v.reshape(b * s, -1, head_dim)

     # convert to flashinfer-style metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode

     qo_indptr = cu_seqlen[: num_seq + 1]
@@ -506,7 +506,7 @@ def flashinfer_mha_with_cache_fake(
     k: torch.Tensor,
     v: torch.Tensor,
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     cu_num_pages: torch.Tensor,
     cache_loc: torch.Tensor,
@@ -559,7 +559,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]
+        return ["batch_info_host", "cu_seqlen", "cu_num_pages", "cache_loc", "last_page_len"]

     @classmethod
     def get_prepare_extra_metadata_info(

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/cuda_backend_causal_conv.py

Lines changed: 4 additions & 4 deletions
@@ -53,7 +53,7 @@ def _cuda_cached_causal_conv1d(
     weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -80,7 +80,7 @@ def _cuda_cached_causal_conv1d(
     """
     b, s = input.shape[:2]

-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     num_total_tokens = num_prefill_tokens + num_decode

@@ -138,7 +138,7 @@ def _cuda_cached_causal_conv1d_fake(
     weight: torch.Tensor,  # [c_out, c_in/groups, k] but we expect depthwise use: [c_in, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
     use_initial_states: torch.Tensor,
@@ -189,7 +189,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_causal_conv.py

Lines changed: 4 additions & 4 deletions
@@ -147,7 +147,7 @@ def _torch_cached_causal_conv1d(
     weight: torch.Tensor,  # [c_out, c_in/groups, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -174,7 +174,7 @@ def _torch_cached_causal_conv1d(
     num_seq = seq_len.shape[0]

     # get cleaned up metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     seq_len = seq_len[:num_seq]
     seq_start = cu_seqlen[:num_seq]
@@ -247,7 +247,7 @@ def _torch_cached_causal_conv1d_fake(
     weight: torch.Tensor,  # [c_out, c_in/groups, k]
     bias: Optional[torch.Tensor],
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -296,7 +296,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(

tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py

Lines changed: 4 additions & 4 deletions
@@ -121,7 +121,7 @@ def _torch_cached_ssm(
     dt: torch.Tensor,  # [b, s, num_heads]
     dt_bias: torch.Tensor,  # [num_heads]
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -145,7 +145,7 @@ def _torch_cached_ssm(
     num_seq = seq_len.shape[0]

     # get cleaned up metadata
-    num_prefill, num_prefill_tokens, num_decode = batch_info.tolist()
+    num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist()
     num_seq = num_prefill + num_decode
     seq_len = seq_len[:num_seq]
     seq_start = cu_seqlen[:num_seq]
@@ -246,7 +246,7 @@ def _torch_cached_ssm_fake(
     dt: torch.Tensor,  # [b, s, num_heads]
     dt_bias: torch.Tensor,  # [num_heads]
     # STANDARD METADATA
-    batch_info: torch.Tensor,
+    batch_info_host: torch.Tensor,
     seq_len: torch.Tensor,
     cu_seqlen: torch.Tensor,
     slot_idx: torch.Tensor,
@@ -293,7 +293,7 @@ def get_cached_attention_op(cls) -> MHACallable:

     @classmethod
     def get_standard_metadata_args(cls) -> List[str]:
-        return ["batch_info", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]
+        return ["batch_info_host", "seq_len", "cu_seqlen", "slot_idx", "use_initial_states"]

     @classmethod
     def get_cache_initializers(
