Skip to content

Commit 729bf8f

Browse files
committed
Use correct memory pool
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 83122bf commit 729bf8f

File tree

6 files changed

+45
-37
lines changed

6 files changed

+45
-37
lines changed

tensorrt_llm/_torch/compilation/backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
class Backend:
2424

2525
_custom_pass_instances: List[PatternMatcherPass] = None
26+
_graph_pool: torch.cuda.MemPool = None
2627
_graph_pool_handle: tuple[int, int] = None
2728

2829
# Following classes are used to let weakref ref the stream and eventlist objects.
@@ -58,7 +59,8 @@ def __init__(
5859
inductor_config.enable_auto_functionalized_v2 = False
5960

6061
if Backend._graph_pool_handle is None:
61-
Backend._graph_pool_handle = torch.cuda.graph_pool_handle()
62+
Backend._graph_pool = torch.cuda.MemPool()
63+
Backend._graph_pool_handle = Backend._graph_pool.id
6264

6365
self.match_count = []
6466

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,28 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
5757

5858
candidate_blocks = self.buffers.get(buffer_name, [])
5959

60-
# Find the best-fit available buffer.
61-
best_fit_block: Optional[BufferBlock] = None
62-
smallest_sufficient_size = float('inf')
63-
for block in candidate_blocks:
64-
# Skip buffers that are too small.
65-
if block.buffer.numel() < required_memory_size:
66-
continue
67-
68-
# Find the smallest buffer that is still large enough (best-fit).
69-
if block.buffer.numel() < smallest_sufficient_size:
70-
# Prefer a reserved block if one has already been found.
71-
if best_fit_block is not None and best_fit_block.is_reserved and not block.is_reserved:
60+
if reserve_buffer:
61+
# Find the best-fit available buffer.
62+
best_fit_block: Optional[BufferBlock] = None
63+
smallest_sufficient_size = float('inf')
64+
for block in candidate_blocks:
65+
# Skip buffers that are too small.
66+
if block.buffer.numel() < required_memory_size:
7267
continue
7368

74-
best_fit_block = block
75-
smallest_sufficient_size = block.buffer.numel()
69+
# Find the smallest buffer that is still large enough (best-fit).
70+
if block.buffer.numel() < smallest_sufficient_size:
71+
# Prefer a reserved block if one has already been found.
72+
if best_fit_block is not None and best_fit_block.is_reserved and not block.is_reserved:
73+
continue
7674

77-
if reserve_buffer and best_fit_block is not None:
78-
# A suitable buffer was found, so reuse it.
79-
best_fit_block.is_reserved = True
80-
return self._view_as(best_fit_block.buffer, tensor_shape, dtype)
75+
best_fit_block = block
76+
smallest_sufficient_size = block.buffer.numel()
77+
78+
if best_fit_block is not None:
79+
# A suitable buffer was found, so reuse it.
80+
best_fit_block.is_reserved = True
81+
return self._view_as(best_fit_block.buffer, tensor_shape, dtype)
8182

8283
for block in list(candidate_blocks):
8384
if not block.is_reserved:
@@ -88,22 +89,27 @@ def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
8889

8990
# No suitable buffer was found, so allocate a new one.
9091
# The new buffer is created with uint8 to represent raw bytes.
92+
def _create_buffer():
93+
return torch.zeros((required_memory_size, ),
94+
device='cuda',
95+
dtype=torch.uint8)
96+
9197
new_buffer_tensor = None
9298
try:
93-
with torch.cuda.memory.use_mem_pool(get_shared_pool()):
94-
new_buffer_tensor = torch.zeros((required_memory_size, ),
95-
device='cuda',
96-
dtype=torch.uint8)
99+
mem_pool = get_shared_pool()
100+
if mem_pool is not None:
101+
with torch.cuda.memory.use_mem_pool():
102+
new_buffer_tensor = _create_buffer()
103+
else:
104+
new_buffer_tensor = _create_buffer()
97105
except Exception as ex:
98106
# Need to check if this is an OOM exception
99107
logger.debug(
100108
f"Exception happened to create tensor from given memory pool: {str(ex)}"
101109
)
102110
# If an exception occurs while allocating from the shared pool, retry
103111
# the allocation from the default pool.
104-
new_buffer_tensor = torch.zeros((required_memory_size, ),
105-
device='cuda',
106-
dtype=torch.uint8)
112+
new_buffer_tensor = _create_buffer()
107113

108114
new_block = BufferBlock(buffer=new_buffer_tensor,
109115
is_reserved=reserve_buffer)

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class CUDAGraphRunnerConfig:
6161
max_beam_width: int
6262
max_num_tokens: int
6363
spec_config: Optional[DecodingBaseConfig]
64-
cuda_graph_mem_pool: Any
64+
cuda_graph_mem_pool: torch.cuda.MemPool
6565
use_mrope: bool
6666
original_max_draft_len: int
6767
original_max_total_draft_tokens: int
@@ -98,7 +98,9 @@ def __init__(self, config: CUDAGraphRunnerConfig):
9898
self.graph_outputs: Dict[Tuple[int, int, int],
9999
Callable[[], Optional[torch.Tensor]]] = {}
100100
self.graph_metadata: Dict[Tuple[int, int, int], Dict[str, Any]] = {}
101-
self.memory_pool = config.cuda_graph_mem_pool
101+
self.memory_pool = config.cuda_graph_mem_pool if config.cuda_graph_mem_pool else torch.cuda.MemPool(
102+
)
103+
self.memory_pool_handle = self.memory_pool.id
102104
self.padding_dummy_request: Optional["Request"] = None
103105

104106
self.shared_static_tensors: Dict[str, torch.Tensor] = {}
@@ -293,15 +295,14 @@ def _setup_spec_decoding_and_forward(key: Tuple[int, int, int],
293295
if postprocess_fn is not None:
294296
postprocess_fn(capture_inputs)
295297

296-
with torch.cuda.graph(graph, pool=self.memory_pool):
298+
with torch.cuda.graph(graph, pool=self.memory_pool_handle):
297299
output = _setup_spec_decoding_and_forward(
298300
key, forward_fn, capture_inputs)
299301
if postprocess_fn is not None:
300302
postprocess_fn(capture_inputs)
301303

302304
self.graphs[key] = graph
303305
self.graph_outputs[key] = make_weak_ref(output)
304-
self.memory_pool = graph.pool()
305306

306307
def replay(self, key: Tuple[int, int, int],
307308
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
@@ -427,6 +428,6 @@ def clear(self):
427428
self.graph_outputs.clear()
428429
self.graph_metadata.clear()
429430
self.padding_dummy_request = None
430-
del self.memory_pool
431-
self.memory_pool = None
431+
del self.memory_pool_handle
432+
self.memory_pool_handle = None
432433
torch.cuda.empty_cache()

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,8 @@ def __init__(
340340
# the model engine.
341341
self.attn_metadata = None
342342
self.iter_states = {}
343-
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None
343+
self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool if self._torch_compile_enabled else None
344+
self._cuda_graph_mem_pool_handle = self._cuda_graph_mem_pool.id if self._cuda_graph_mem_pool else None
344345

345346
self._cuda_graph_padding_enabled = cuda_graph_padding_enabled
346347

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ l0_dgx_b200:
4343
- accuracy/test_llm_api_pytorch.py::TestQwen3NextThinking::test_auto_dtype[tp4ep4]
4444
- disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
4545
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-trtllm-auto]
46-
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto]
46+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-cutlass-auto] ISOLATION
4747
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-triton-auto]
4848
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-auto]
4949
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-trtllm-fp8]
@@ -187,7 +187,7 @@ l0_dgx_b200:
187187
- accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_fp8_chunked_prefill[tp4ep4-cuda_graph=True]
188188
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=0]
189189
- accuracy/test_disaggregated_serving.py::TestQwen3_30B_A3B::test_mixed_ctx_gen_model[ctxpp2gentp2]
190-
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto]
190+
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-cutlass-auto] ISOLATION
191191
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-triton-auto]
192192
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-auto]
193193
- accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-trtllm-fp8]

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,6 @@ accuracy/test_llm_api_pytorch.py::TestNemotronH_47B_Base::test_auto_dtype[tp8ep4
374374
accuracy/test_llm_api_pytorch.py::TestNemotronH_47B_Base::test_reasoning_fp8_prequantized[tp8ep8-cuda_graph=True] SKIP (https://nvbugs/5640697)
375375
accuracy/test_llm_api_pytorch.py::TestQwQ_32B::test_auto_dtype_tp4 SKIP (https://nvbugs/5640697)
376376
accuracy/test_llm_api_pytorch_multimodal.py::TestLlava_V1_6_Mistral_7B::test_auto_dtype SKIP (https://nvbugs/5644187)
377-
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] SKIP (https://nvbugs/5644632)
378-
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] SKIP (https://nvbugs/5644632)
379377
test_e2e.py::test_ptp_quickstart_multimodal_kv_cache_reuse[phi4-multimodal-instruct-multimodals/Phi-4-multimodal-instruct-0.8-image] SKIP (https://nvbugs/5644190)
380378
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-True] SKIP (https://nvbugs/5648560)
381379
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False] SKIP (https://nvbugs/5648560)

0 commit comments

Comments
 (0)