[TRTLLM-9676][fix] Fix mamba_cache_manager when enabling cuda_graph_padding and let test cover this case #9873
@@ -110,6 +110,14 @@ def __init__(
                                          device=device,
                                          dtype=torch.int32)

+        # only for `reorder_state_indices_when_padding_requests`
+        self.request_mask = torch.ones(max_batch_size,
+                                       dtype=torch.bool,
+                                       device=device)
+        self.state_indices_arange = torch.arange(max_batch_size,
+                                                 dtype=torch.int32,
+                                                 device=device)

     def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         state_indices = []
         for r in request_ids:
@@ -126,6 +134,18 @@ def _prepare_mamba_cache_blocks(self, request_ids: List[int]):
         self.state_indices[:len(state_indices)] = torch.as_tensor(
             state_indices, dtype=torch.int32, device=self.ssm_states.device)

+    # When there are padded requests, the state indices should not be repeated.
+    def reorder_state_indices_when_padding_requests(self, request_size,
+                                                    padding_size):
+        if padding_size == 0:
+            return
+
+        self.request_mask[:] = True
+        self.request_mask[self.state_indices[:request_size]] = False
+
+        self.state_indices[request_size:request_size +
+                           padding_size] = self.state_indices_arange[
+                               self.request_mask][:padding_size]

     def prepare_resources(self, scheduled_batch: ScheduledRequests):
         context_ids = [
             i.py_request_id for i in scheduled_batch.context_requests
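To make the reorder step concrete, here is a minimal standalone sketch of the same mask-and-arange trick. The buffer names mirror the ones added in this PR, but the values (3 live requests holding cache slots 5, 2 and 7, plus 2 padded requests) are made up for illustration. The point is that the dummy slots appended for CUDA graph padding end up pointing at cache blocks no live request owns, rather than repeating an index that is already in use, presumably so that graph replay over the padded slots cannot touch a live request's SSM state.

import torch

# Illustrative values only: the real manager sizes these buffers to max_batch_size.
max_batch_size = 8
request_mask = torch.ones(max_batch_size, dtype=torch.bool)
state_indices_arange = torch.arange(max_batch_size, dtype=torch.int32)

# Suppose 3 live requests own cache slots 5, 2 and 7, and CUDA graph padding
# appends 2 dummy requests to reach the next captured batch size.
state_indices = torch.zeros(max_batch_size, dtype=torch.int32)
state_indices[:3] = torch.tensor([5, 2, 7], dtype=torch.int32)
request_size, padding_size = 3, 2

if padding_size > 0:
    # Mark every slot as free, then clear the slots held by live requests ...
    request_mask[:] = True
    request_mask[state_indices[:request_size]] = False
    # ... and hand each dummy request a distinct unused slot.
    state_indices[request_size:request_size + padding_size] = \
        state_indices_arange[request_mask][:padding_size]

print(state_indices[:request_size + padding_size].tolist())  # [5, 2, 7, 0, 1]

Because request_mask and state_indices_arange are preallocated at max_batch_size in __init__, the reorder itself is a couple of in-place tensor operations with no per-step allocation.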
@@ -4629,10 +4629,12 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
         model_path = f"{self.MODEL_PATH}/Qwen3-Next-80B-A3B-Instruct"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6,
                                         enable_block_reuse=False)
-        pytorch_config = dict(disable_overlap_scheduler=not overlap_scheduler,
-                              cuda_graph_config=CudaGraphConfig(
-                                  max_batch_size=512, enable_padding=True)
-                              if cuda_graph else None)
+        pytorch_config = dict(
+            disable_overlap_scheduler=not overlap_scheduler,
+            cuda_graph_config=CudaGraphConfig(
+                enable_padding=True,
+                batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048])

Collaborator: Why not use max_batch_size, to reduce some of the CUDA Graph overhead?

Author: This bug relies on a specific CUDA graph batch-size setting; covering it here would make no sense if the default sizes were the same as the ones generated when setting max_batch_size.

+            if cuda_graph else None)

         with LLM(
             model_path,
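To see why the explicit batch_sizes list matters for coverage, here is a rough sketch of how padding can round a decode batch up to the next captured CUDA graph size; the rounding rule is an assumption for illustration, not the TensorRT-LLM implementation. The difference is filled with dummy requests, which is exactly the case reorder_state_indices_when_padding_requests has to handle.

from bisect import bisect_left

def padded_batch(runtime_batch_size: int, graph_batch_sizes: list[int]) -> tuple[int, int]:
    """Round a runtime batch up to the next captured CUDA graph size (illustrative)."""
    sizes = sorted(graph_batch_sizes)
    idx = bisect_left(sizes, runtime_batch_size)
    if idx == len(sizes):
        # Larger than every captured graph: assume no padding (eager fallback).
        return runtime_batch_size, 0
    target = sizes[idx]
    return target, target - runtime_batch_size

# With the batch sizes used in the test, 33 in-flight decode requests would be
# padded up to 64, i.e. 31 dummy requests whose state indices must not collide
# with the 33 live ones.
print(padded_batch(33, [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]))  # (64, 31)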
@@ -4647,6 +4649,7 @@ def test_bf16_4gpu(self, tp_size, pp_size, ep_size, cuda_graph,
             task.evaluate(llm)
             mocker.patch.object(GSM8K, "MAX_OUTPUT_LEN",
                                 self.GSM8K_MAX_OUTPUT_LEN)
+            mocker.patch.object(GSM8K, "NUM_SAMPLES", 1319)

Collaborator: The default NUM_SAMPLES of GSM8K is 1319; do we still need to specify it explicitly?

Author: I want to make sure it is 1319 here, even if someone later changes the default number of samples. The total number of samples run is important for asserting the correctness of CUDA graph padding.

Collaborator: Either we use GSM8K.NUM_SAMPLES directly or we add an assertion here.

             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
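As a hypothetical illustration of the reviewer's suggestion, the pinned value could instead be guarded by an assertion inside the same test body, so the run fails loudly if the dataset size ever changes. This fragment is not part of the PR and only reuses names already visible in the hunk above (GSM8K, self.MODEL_NAME, llm).

            # Hypothetical variant following the review suggestion (not in this PR):
            # assert the expected dataset size instead of patching it.
            assert GSM8K.NUM_SAMPLES == 1319, (
                "CUDA graph padding coverage assumes the full 1319-sample GSM8K run")
            task = GSM8K(self.MODEL_NAME)
            task.evaluate(llm)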