Skip to content

Commit 876dd68

Browse files
committed
Reuse tensor
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent e56397d commit 876dd68

File tree

4 files changed

+54
-28
lines changed

4 files changed

+54
-28
lines changed

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
class Backend:
2525

2626
_custom_pass_instances: List[PatternMatcherPass] = None
27+
_graph_pool: torch.cuda.MemPool = None
2728
_graph_pool_handle: tuple[int, int] = None
2829

2930
# Following classes are used to let weakref ref the stream and eventlist objects.
@@ -60,7 +61,8 @@ def __init__(
6061
inductor_config.enable_auto_functionalized_v2 = False
6162

6263
if Backend._graph_pool_handle is None:
63-
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()
64+
Backend._graph_pool = torch.cuda.MemPool()
65+
Backend._graph_pool_handle = Backend._graph_pool.id
6466

6567
self.match_count = []
6668

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -79,40 +79,55 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
7979
best_fit_block = block
8080
smallest_sufficient_size = block.buffer.numel()
8181

82+
for block in list(candidate_blocks):
83+
if not block.is_reserved:
84+
if best_fit_block is not None:
85+
if block is not best_fit_block:
86+
# Need to call del BufferBlock.buffer, otherwise memory isn't
87+
# released and OOM may happen.
88+
del block.buffer
89+
candidate_blocks.remove(block)
90+
else:
91+
del block.buffer
92+
candidate_blocks.remove(block)
93+
8294
if best_fit_block is not None:
8395
if reserve_buffer:
96+
# A suitable buffer was found, so reuse it.
8497
best_fit_block.is_reserved = True
85-
# A suitable buffer was found, so reuse it.
86-
return self._view_as(best_fit_block.buffer, tensor_shape, dtype)
87-
88-
for block in list(candidate_blocks):
89-
if not block.is_reserved:
90-
# Need to call del BufferBlock.buffer, otherwise memory isn't
91-
# released and OOM may happen.
92-
buffer_size = block.buffer.numel()
93-
del block.buffer
94-
if buffer_size >= 1024 * 1024 * 1024:
95-
torch.cuda.empty_cache()
96-
candidate_blocks.remove(block)
98+
return self._view_as(best_fit_block.buffer, tensor_shape, dtype)
99+
else:
100+
# TODO: reuse tensors from both the graph pool and the normal pool.
101+
if best_fit_block.is_reserved:
102+
return self._view_as(best_fit_block.buffer, tensor_shape,
103+
dtype)
104+
else:
105+
del best_fit_block.buffer
106+
candidate_blocks.remove(best_fit_block)
107+
108+
def _create_buffer():
109+
return torch.zeros((required_memory_size, ),
110+
device='cuda',
111+
dtype=torch.uint8)
97112

98113
# No suitable buffer was found, so allocate a new one.
99114
# The new buffer is created with uint8 to represent raw bytes.
100115
new_buffer_tensor = None
101116
try:
102-
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
103-
new_buffer_tensor = torch.empty((required_memory_size, ),
104-
device='cuda',
105-
dtype=torch.uint8)
117+
new_buffer_tensor = _create_buffer()
106118
except Exception as ex:
107-
# Need to check if this is an OOM exception
119+
# Need to check if this is an OOM exception
108120
logger.debug(
109121
f"Exception happened to create tensor from given memory pool: {str(ex)}"
110122
)
111-
# if exception happens during allocating memory from shared pool, retry
112-
# to allocate from default pool
113-
new_buffer_tensor = torch.empty((required_memory_size, ),
114-
device='cuda',
115-
dtype=torch.uint8)
123+
# If an exception occurs while allocating memory from the default pool, retry
124+
# the allocation from the shared pool, to reduce fragmentation in the shared pool.
125+
mem_pool = get_shared_pool()
126+
if mem_pool is not None:
127+
with torch.cuda.memory.use_mem_pool(mem_pool):
128+
new_buffer_tensor = _create_buffer()
129+
else:
130+
raise ex
116131

117132
new_block = BufferBlock(buffer=new_buffer_tensor,
118133
is_reserved=reserve_buffer)

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class CUDAGraphRunnerConfig:
6868
max_beam_width: int
6969
max_num_tokens: int
7070
spec_config: Optional[DecodingBaseConfig]
71-
cuda_graph_mem_pool: Any
71+
cuda_graph_mem_pool: torch.cuda.MemPool
7272
use_mrope: bool
7373
original_max_draft_len: int
7474
original_max_total_draft_tokens: int
@@ -107,7 +107,9 @@ def __init__(self, config: CUDAGraphRunnerConfig):
107107
self.graph_outputs: Dict[KeyType,
108108
Callable[[], Optional[torch.Tensor]]] = {}
109109
self.graph_metadata: Dict[KeyType, Dict[str, Any]] = {}
110-
self.memory_pool = config.cuda_graph_mem_pool
110+
self.memory_pool = config.cuda_graph_mem_pool if config.cuda_graph_mem_pool else torch.cuda.MemPool(
111+
)
112+
self.memory_pool_handle = self.memory_pool.id
111113
self.padding_dummy_request: Optional["Request"] = None
112114

113115
self.shared_static_tensors: Dict[str, torch.Tensor] = {}
@@ -343,6 +345,10 @@ def _setup_spec_decoding_and_forward(key: KeyType, forward_fn: Callable,
343345
capture_inputs['attn_metadata'].use_spec_decoding = True
344346
return forward_fn(capture_inputs)
345347

348+
if self.memory_pool_handle is None or self.memory_pool is None:
349+
self.memory_pool = torch.cuda.MemPool()
350+
self.memory_pool_handle = self.memory_pool.id
351+
346352
# We have to do warm up runs to initialize PyTorch's
347353
# internal states according to the docs:
348354
# https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics
@@ -355,15 +361,14 @@ def _setup_spec_decoding_and_forward(key: KeyType, forward_fn: Callable,
355361
if postprocess_fn is not None:
356362
postprocess_fn(capture_inputs)
357363

358-
with torch.cuda.graph(graph, pool=self.memory_pool):
364+
with torch.cuda.graph(graph, pool=self.memory_pool_handle):
359365
output = _setup_spec_decoding_and_forward(
360366
key, forward_fn, capture_inputs)
361367
if postprocess_fn is not None:
362368
postprocess_fn(capture_inputs)
363369

364370
self.graphs[key] = graph
365371
self.graph_outputs[key] = make_weak_ref(output)
366-
self.memory_pool = graph.pool()
367372

368373
def replay(self, key: KeyType,
369374
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
@@ -503,6 +508,8 @@ def clear(self):
503508
self.graph_outputs.clear()
504509
self.graph_metadata.clear()
505510
self.padding_dummy_request = None
511+
del self.memory_pool_handle
512+
self.memory_pool_handle = None
506513
del self.memory_pool
507514
self.memory_pool = None
508515
torch.cuda.empty_cache()

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,9 @@ def __init__(
360360
# the model engine.
361361
self.attn_metadata = None
362362
self.iter_states = {}
363-
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
363+
364+
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool if self._torch_compile_enabled else None
365+
self._cuda_graph_mem_pool_handle = self._cuda_graph_mem_pool.id if self._cuda_graph_mem_pool else None
364366

365367
self._cuda_graph_padding_enabled = cuda_graph_padding_enabled
366368

0 commit comments

Comments
 (0)