4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -123,10 +123,10 @@ transforms:
     attn_backend: MultiHeadLatentAttention
   insert_cached_ssm_attention:
     stage: cache_init
-    attn_backend: torch_ssm
+    attn_backend: triton_ssm
   insert_cached_causal_conv:
     stage: cache_init
-    attn_backend: torch_causal_conv
+    attn_backend: cuda_causal_conv
   initialize_cache:
     stage: cache_init
   resize_kv_cache:
@@ -28,6 +28,7 @@
 class CacheConfig:
     """A dataclass to hold information how to configure the cache."""

+    # dtype of the cache
     dtype: Optional[torch.dtype] = None


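For readers skimming the diff: the new comment documents the only field on CacheConfig. Below is a minimal, hypothetical usage sketch; the @dataclass decorator (implied by the docstring but not visible in this hunk) and the bfloat16 choice are assumptions for illustration, not part of the change.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class CacheConfig:
    """A dataclass to hold information how to configure the cache."""

    # dtype of the cache
    dtype: Optional[torch.dtype] = None


# Leave dtype as None to defer to whatever default the pipeline applies,
# or pin the cache to a specific precision explicitly.
cfg = CacheConfig(dtype=torch.bfloat16)
print(cfg.dtype)  # torch.bfloat16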
@@ -522,6 +523,7 @@ def set_example_sequence(

         # vanilla slot indices
         slot_idx = list(range(len(input_ids)))
+        # breakpoint()

         self.nest_sequences(
             input_ids,
@@ -537,6 +539,9 @@ def set_max_num_tokens_sample(self) -> None:
         # TODO (lucaslie): understand what this implies for extra arguments
         seq_len = self.max_num_tokens // self.max_batch_size
         input_ids = torch.ones(self.max_batch_size, seq_len, dtype=torch.int).tolist()
+        print(
+            f"setting max_num_tokens_sample: {self.max_num_tokens=}, {self.max_batch_size=}, {seq_len=}"
+        )
         self.set_example_sequence(input_ids)

     def set_generate_only_batch(self) -> None:
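The quantities logged by the new print statement are tied together by a single integer division. A standalone sketch with hypothetical limits (8192 tokens, batch size 16; the real values come from self.max_num_tokens and self.max_batch_size on the instance):

import torch

# Hypothetical limits for illustration only.
max_num_tokens = 8192
max_batch_size = 16

# Same arithmetic as set_max_num_tokens_sample: split the token budget
# evenly across the maximum number of sequences.
seq_len = max_num_tokens // max_batch_size  # 512
input_ids = torch.ones(max_batch_size, seq_len, dtype=torch.int).tolist()

print(f"setting max_num_tokens_sample: {max_num_tokens=}, {max_batch_size=}, {seq_len=}")
assert len(input_ids) == max_batch_size and len(input_ids[0]) == seq_len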
@@ -581,6 +586,10 @@ def _store_arg(
         # pin the memory on the host
         tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=True)

+        if tnsr_device.numel() < tnsr_host.numel():
+            print("WARNING: tnsr_device.numel() < tnsr_like.numel()")
+            print(f"{name=}, {tnsr_device.numel()=}, {tnsr_host.numel()=}")
+            tnsr_device.resize_(tnsr_host.numel())
Comment on lines +589 to +592

Collaborator: Curious where this is necessary?

Author: @lucaslie: FYI, this is the WAR (workaround) I have in place to get the resize functionality working again on the feature branch. Without it, llama3.1 + cache_resize is broken.

         # reset/copy to the device in a non-blocking fashion
         if reset:
             tnsr_device.zero_()
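To make the workaround from the review thread concrete: the pattern grows a preallocated device buffer in place when the pinned host copy no longer fits, then proceeds with the usual non-blocking copy. The tensor names and sizes below are hypothetical; only the numel() comparison and resize_() call mirror the diff.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical: a device buffer sized for 8 elements, and a larger payload of 12.
tnsr_device = torch.zeros(8, dtype=torch.int32, device=device)
tnsr_like = list(range(12))

# Pin host memory (when CUDA is available) so the copy below can be non-blocking.
tnsr_host = torch.tensor(tnsr_like, dtype=tnsr_device.dtype, pin_memory=torch.cuda.is_available())

# The workaround: grow the device buffer in place if it is now too small,
# e.g. after the KV cache has been resized.
if tnsr_device.numel() < tnsr_host.numel():
    print(f"growing device buffer: {tnsr_device.numel()} -> {tnsr_host.numel()}")
    tnsr_device.resize_(tnsr_host.numel())

# Copy to the device in a non-blocking fashion, mirroring the
# "# reset/copy to the device" step shown in _store_arg above.
tnsr_device[: tnsr_host.numel()].copy_(tnsr_host, non_blocking=True)
print(tnsr_device.tolist())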