
Commit 8ddb430

removing todo
Signed-off-by: Ludwig Schneider <[email protected]>
1 parent 6d2e145 commit 8ddb430

2 files changed: +127 -1 lines changed

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
"""
Workaround for PyTorch tensor lifetime issue during CUDA graph capture.

This module provides a mechanism to keep tensors alive during CUDA graph capture
by storing Python references to them. This prevents premature destruction of tensors
returned from custom C++ operators that have custom deleters.

Issue: During CUDA graph capture, tensors returned from custom C++ operators
may have use_count=1, causing them to be destroyed immediately, before the PyTorch
binding can increment the reference count. This causes custom deleters to be
called prematurely, releasing buffers while NCCL operations are still using them.

Workaround: Store references to output tensors during graph capture, ensuring
they stay alive until graph execution completes.
"""

import threading
from typing import List, Union

import torch


class TensorLifetimeRegistry:
    """
    Thread-safe registry to store tensor references during CUDA graph capture.

    This ensures tensors returned from custom operators stay alive during graph
    capture and execution, preventing premature deleter calls.
    """

    def __init__(self):
        self._lock = threading.Lock()
        # Store references per thread to handle multi-threaded scenarios
        self._thread_local = threading.local()

    def _get_storage(self) -> List[List[torch.Tensor]]:
        """Get thread-local storage for tensor references."""
        if not hasattr(self._thread_local, "tensor_refs"):
            self._thread_local.tensor_refs = []
        return self._thread_local.tensor_refs

    def register_tensors(self, tensors: Union[torch.Tensor, List[torch.Tensor], tuple]):
        """
        Register tensor(s) to keep them alive during graph capture.

        Args:
            tensors: Single tensor, list of tensors, or tuple of tensors to register
        """
        with self._lock:
            storage = self._get_storage()

            # Convert to list of tensors
            if isinstance(tensors, torch.Tensor):
                tensor_list = [tensors]
            elif isinstance(tensors, (list, tuple)):
                tensor_list = [t for t in tensors if isinstance(t, torch.Tensor)]
            else:
                return  # Not a tensor, ignore

            # Only register if we're in graph capture
            if self.is_capturing():
                storage.append(tensor_list)
                print(
                    f"[TensorLifetimeRegistry] Registered {len(tensor_list)} tensor(s) "
                    f"during graph capture (total batches: {len(storage)})"
                )

    def is_capturing(self) -> bool:
        """
        Check if we're currently in CUDA graph capture.

        Returns:
            True if currently capturing a CUDA graph, False otherwise
        """
        try:
            # torch.cuda.is_current_stream_capturing() reports whether the
            # current stream is capturing
            return torch.cuda.is_current_stream_capturing()
        except (AttributeError, RuntimeError):
            # Fallback: if the function doesn't exist or there's an error, assume not capturing
            return False

    def clear(self):
        """Clear all registered tensor references (call after graph execution completes)."""
        with self._lock:
            if hasattr(self._thread_local, "tensor_refs"):
                count = sum(len(batch) for batch in self._thread_local.tensor_refs)
                self._thread_local.tensor_refs.clear()
                print(f"[TensorLifetimeRegistry] Cleared {count} tensor reference(s)")

    def get_registered_count(self) -> int:
        """Get the number of registered tensor batches."""
        with self._lock:
            if hasattr(self._thread_local, "tensor_refs"):
                return len(self._thread_local.tensor_refs)
            return 0


# Global singleton instance
_tensor_registry = TensorLifetimeRegistry()


def register_tensor_references(tensors: Union[torch.Tensor, List[torch.Tensor], tuple]):
    """
    Register tensor(s) to keep them alive during CUDA graph capture.

    This is a convenience function that uses the global registry.

    Args:
        tensors: Single tensor, list of tensors, or tuple of tensors to register
    """
    _tensor_registry.register_tensors(tensors)


def clear_tensor_references():
    """Clear all registered tensor references."""
    _tensor_registry.clear()


def is_graph_capturing() -> bool:
    """Check if we're currently in CUDA graph capture."""
    return _tensor_registry.is_capturing()


def get_registered_count() -> int:
    """Get the number of registered tensor batches."""
    return _tensor_registry.get_registered_count()

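For reference, a minimal usage sketch of these helpers around CUDA graph capture. The import path below is a placeholder (the new file's path is not visible in this view), and run_custom_op stands in for any custom C++ operator whose outputs need to be kept alive:

import torch

# Placeholder module name; substitute the actual path of the new file.
from tensor_lifetime_workaround import (
    clear_tensor_references,
    get_registered_count,
    is_graph_capturing,
    register_tensor_references,
)


def capture_and_replay(run_custom_op, *args):
    """Capture run_custom_op into a CUDA graph while pinning its output tensors."""
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        out = run_custom_op(*args)         # tensor(s) from a custom C++ operator
        assert is_graph_capturing()
        register_tensor_references(out)    # hold a Python reference during capture
    graph.replay()                         # outputs stay alive while the graph runs
    print(f"held batches: {get_registered_count()}")
    clear_tensor_references()              # drop the references once done
    return graph

Since register_tensor_references only records tensors while a stream is capturing, the wrapper is a no-op on non-capturing code paths and can be left in place unconditionally.
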
tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 0 additions & 1 deletion
@@ -1718,7 +1718,6 @@ def forward(
     ) -> torch.Tensor:
         input, residual, norm_weight, scale, bias, workspace = inputs
         if tactic == -1:
-            # TODO: Use NCCL instead of NCCL_SYMMETRIC to avoid hanging during tuning process
             tactic = AllReduceStrategy.NCCL_SYMMETRIC.value
 
         return torch.ops.trtllm.allreduce(
