Skip to content

Commit 81222c3

Browse files
authored
[None] Fix warning when capturing CUDA graph (#9746)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent c1d53ee commit 81222c3

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
from tensorrt_llm._torch.speculative.interface import SpecMetadata
2020
from tensorrt_llm._torch.speculative.spec_tree_manager import SpecTreeManager
2121

22+
# Enable capture_scalar_outputs to avoid graph breaks from Tensor.item() calls
23+
torch._dynamo.config.capture_scalar_outputs = True
24+
2225

2326
class BaseDraftingLoopWrapper(ABC, torch.nn.Module):
2427

0 commit comments

Comments
 (0)