[TRTLLM-9676][fix] Fix mamba_cache_manager when enabling cuda_graph_padding and let test cover this case #9873
base: main
```diff
@@ -110,6 +110,14 @@ def __init__(
                                                         device=device,
                                                         dtype=torch.int32)
 
+        # only for `reorder_state_indices_when_padding_requests`
+        self.request_mask = torch.ones(max_batch_size,
+                                       dtype=torch.bool,
+                                       device=device)
+        self.state_indices_arange = torch.arange(max_batch_size,
+                                                 dtype=torch.int32,
+                                                 device=device)
+
     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
```
```diff
@@ -127,6 +135,18 @@ def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
                 state_indices, dtype=torch.int32, pin_memory=True),
             non_blocking=True)
 
+    # When there are padded requests, the state indices must not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        self.request_mask[:] = True
+        self.request_mask[self.state_indices[:request_size]] = False
+
+        self.state_indices[request_size:request_size +
+                           padding_size] = self.state_indices_arange[
+                               self.request_mask][:padding_size]
+
     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
            i.py_request_id for i in scheduled_batch.context_requests
```
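The idea behind the fix: when CUDA graph padding adds dummy requests to reach a captured batch size, those padded slots previously reused state indices already owned by real requests, so two entries could write the same Mamba cache slot. The new method instead hands each padded request a currently unused slot. Below is a minimal standalone sketch of that masking logic, separate from the `MambaCacheManager` class; the values of `max_batch_size`, `device`, and the example slot assignments are illustrative assumptions, not taken from the PR.

```python
import torch

# Illustrative sizes (assumptions for this sketch).
max_batch_size = 8
device = "cpu"  # the real manager allocates these on the GPU

# Buffers mirroring the ones added in __init__.
state_indices = torch.zeros(max_batch_size, dtype=torch.int32, device=device)
request_mask = torch.ones(max_batch_size, dtype=torch.bool, device=device)
state_indices_arange = torch.arange(max_batch_size, dtype=torch.int32, device=device)

# Suppose 3 real requests own cache slots 2, 5 and 0, and CUDA graph padding
# appends 2 dummy requests.
request_size, padding_size = 3, 2
state_indices[:request_size] = torch.tensor([2, 5, 0], dtype=torch.int32)

if padding_size > 0:
    # Mark slots taken by real requests as unavailable ...
    request_mask[:] = True
    request_mask[state_indices[:request_size]] = False
    # ... and give the first `padding_size` free slots to the padded requests,
    # so no cache slot is referenced twice in the same step.
    state_indices[request_size:request_size + padding_size] = \
        state_indices_arange[request_mask][:padding_size]

print(state_indices[:request_size + padding_size])
# tensor([2, 5, 0, 1, 3], dtype=torch.int32) -- padded entries get unused slots 1 and 3
```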


Uh oh!
There was an error while loading. Please reload this page.