@@ -565,6 +565,7 @@ def is_nvfp4_output_kernel_available(
 @dataclass(kw_only=True)
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
+    cuda_graph_workspace: Optional[torch.Tensor] = None
 
     # TrtllmAttention needs to know the beam width to access the cache indirection buffer
     # when beam search is enabled.
@@ -680,6 +681,14 @@ def _post_init_with_buffers(self, buffers) -> None:
                 device='cuda',
                 dtype=torch.int8,
             )
+
+        if self.cuda_graph_workspace is None:
+            self.cuda_graph_workspace = torch.empty(
+                (0, ),
+                device='cuda',
+                dtype=torch.int8,
+            )
+
         if self.kv_cache_manager is not None:
             self.kv_cache_block_offsets = self.get_empty(
                 buffers,
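
The hunk above lazily allocates `cuda_graph_workspace` as a zero-size int8 CUDA tensor, mirroring how `workspace` is initialized in the context lines just before it, so downstream code always receives a valid tensor object rather than None. A minimal sketch of this placeholder-then-grow pattern follows; it assumes a CUDA-capable machine, and the `grow_workspace` helper is purely illustrative, not a TensorRT-LLM API:

import torch

def grow_workspace(ws: torch.Tensor, required_bytes: int) -> torch.Tensor:
    # Hypothetical helper: reallocate only when the current placeholder is too
    # small. A zero-size tensor is a valid starting point and avoids having to
    # pass None around.
    if ws.numel() < required_bytes:
        ws = torch.empty((required_bytes, ), device='cuda', dtype=torch.int8)
    return ws

workspace = torch.empty((0, ), device='cuda', dtype=torch.int8)
workspace = grow_workspace(workspace, 1 << 20)  # grow to 1 MiB on first use
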
@@ -1317,8 +1326,9 @@ def forward(
             host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
             host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
             block_ids_per_seq=metadata.block_ids_per_seq,
-            workspace=metadata.
-            workspace,  # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
+            # Always pass a real workspace tensor: passing None makes FP8 MLA hit an invalid CUDA free.
+            workspace=metadata.workspace
+            if not metadata.is_cuda_graph else metadata.cuda_graph_workspace,
             cache_indirection=metadata.cache_indirection,
             kv_scale_orig_quant=self.kv_scale_orig_quant,
             kv_scale_quant_orig=self.kv_scale_quant_orig,
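
Read in isolation, the new selection logic amounts to the sketch below. It assumes only what the diff shows (an `is_cuda_graph` flag alongside the two workspace fields); `_Metadata` and `select_workspace` are hypothetical names for illustration, not TensorRT-LLM APIs, and the stated motivation is an inference from the change, not something the diff itself documents:

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass(kw_only=True)
class _Metadata:
    # Hypothetical stand-in for the relevant TrtllmAttentionMetadata fields.
    is_cuda_graph: bool = False
    workspace: Optional[torch.Tensor] = None
    cuda_graph_workspace: Optional[torch.Tensor] = None

def select_workspace(metadata: _Metadata) -> Optional[torch.Tensor]:
    # Same expression as in the diff: under CUDA graph capture/replay, use the
    # dedicated workspace, presumably so its address stays stable for the
    # lifetime of the captured graph, independent of the regular workspace.
    return (metadata.workspace
            if not metadata.is_cuda_graph else metadata.cuda_graph_workspace)
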