Skip to content

Commit 51493c1

Browse files
JunyiXu-nv and mikeiovine
authored and committed
[https://nvbugs/5606268][fix] Separate cuda graph workspace to prevent IMA (#8685)
Signed-off-by: Junyi Xu <[email protected]> Signed-off-by: Mike Iovine <[email protected]>
1 parent 5c42706 commit 51493c1

File tree

1 file changed

+12
-2
lines changed
  • tensorrt_llm/_torch/attention_backend

1 file changed

+12
-2
lines changed

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def is_nvfp4_output_kernel_available(
565565
@dataclass(kw_only=True)
566566
class TrtllmAttentionMetadata(AttentionMetadata):
567567
workspace: Optional[torch.Tensor] = None
568+
cuda_graph_workspace: Optional[torch.Tensor] = None
568569

569570
# TrtllmAttention needs to know the beam width to access to the cache indirection buffer,
570571
# when beam search is enabled.
@@ -680,6 +681,14 @@ def _post_init_with_buffers(self, buffers) -> None:
680681
device='cuda',
681682
dtype=torch.int8,
682683
)
684+
685+
if self.cuda_graph_workspace is None:
686+
self.cuda_graph_workspace = torch.empty(
687+
(0, ),
688+
device='cuda',
689+
dtype=torch.int8,
690+
)
691+
683692
if self.kv_cache_manager is not None:
684693
self.kv_cache_block_offsets = self.get_empty(
685694
buffers,
@@ -1317,8 +1326,9 @@ def forward(
13171326
host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
13181327
host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
13191328
block_ids_per_seq=metadata.block_ids_per_seq,
1320-
workspace=metadata.
1321-
workspace, # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
1329+
# re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
1330+
workspace=metadata.workspace
1331+
if not metadata.is_cuda_graph else metadata.cuda_graph_workspace,
13221332
cache_indirection=metadata.cache_indirection,
13231333
kv_scale_orig_quant=self.kv_scale_orig_quant,
13241334
kv_scale_quant_orig=self.kv_scale_quant_orig,

0 commit comments

Comments
 (0)