Commit 6437756

fix: handle OOMs during KV cache estimation (NVIDIA#4690)
Signed-off-by: ixlmar <[email protected]>
Parent: 1c3091c

3 files changed: +33 −25 lines


tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 23 additions & 18 deletions
```diff
@@ -20,6 +20,7 @@
 from ..speculative import get_spec_decoder
 from .config_utils import is_mla, is_nemotron_hybrid
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
+from .llm_request import ExecutorResponse
 from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
                            PyTorchModelEngine)
 from .py_executor import PyExecutor
@@ -203,16 +204,29 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
         req_ids = py_executor.dist.broadcast(req_ids, root=0)
         py_executor.is_warmup = True
         py_executor.start_worker()
-        py_executor.await_responses(req_ids)
+        try:
+            responses = py_executor.await_responses(req_ids)
+            for response_or_list in responses:
+                response_list = [response_or_list] if isinstance(
+                    response_or_list, ExecutorResponse) else response_or_list
+                for response in response_list:
+                    if response.has_error():
+                        raise RuntimeError(response.error_msg)
+
+            torch_peak_memory = torch.cuda.memory_stats(
+            )["allocated_bytes.all.peak"]
+
+            # Clear the caching allocator before measuring the current memory usage
+            torch.cuda.empty_cache()
+            end, total_gpu_memory = torch.cuda.mem_get_info()
+            torch_used_bytes = torch.cuda.memory_stats(
+            )["allocated_bytes.all.current"]
+        finally:
+            py_executor.shutdown()
+            py_executor.is_warmup = False
+            py_executor.enable_iter_perf_stats = origin_iter_stats
+            py_executor.set_gather_responses(False)
 
-        torch_peak_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-
-        # Clear the caching allocator before measuring the current memory usage
-        torch.cuda.empty_cache()
-        end, total_gpu_memory = torch.cuda.mem_get_info()
-        torch_used_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
         total_used_bytes = total_gpu_memory - end
         activation_bytes = torch_peak_memory - model_bytes
         extra_cost = max(total_used_bytes - torch_used_bytes, 0)
@@ -235,15 +249,6 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
                                        self._max_kv_tokens_in)
 
         logger.info(f"Estimated max tokens in KV cache : {kv_cache_max_tokens}")
-
-        py_executor.resource_manager.resource_managers.get(
-            "kv_cache_manager").shutdown()
-
-        py_executor.shutdown()
-        py_executor.is_warmup = False
-        py_executor.set_gather_responses(False)
-        py_executor.enable_iter_perf_stats = origin_iter_stats
-
         executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
 
     def _create_kv_cache_manager(
```
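The change to estimate_max_tokens follows a simple pattern: the warm-up responses are checked for errors and the peak/current memory readings are taken inside a try block, while the finally block always shuts the temporary executor down and restores its warm-up state. Below is a minimal sketch of that pattern in isolation; the Executor and Response classes are hypothetical stand-ins (not the TensorRT-LLM types), assuming await_responses may return either a single response or a list per request id:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Response:
    error_msg: Optional[str] = None

    def has_error(self) -> bool:
        return self.error_msg is not None


class Executor:
    def await_responses(self, req_ids: List[int]) -> List[Union[Response, List[Response]]]:
        # Placeholder: a real executor may return one response or a list per request id.
        return [Response() for _ in req_ids]

    def shutdown(self) -> None:
        print("executor shut down")


def run_estimation(executor: Executor, req_ids: List[int]) -> None:
    try:
        responses = executor.await_responses(req_ids)
        for response_or_list in responses:
            response_list = (response_or_list if isinstance(response_or_list, list)
                             else [response_or_list])
            for response in response_list:
                if response.has_error():
                    # Surface warm-up failures (e.g. a CUDA OOM) instead of
                    # continuing with meaningless memory measurements.
                    raise RuntimeError(response.error_msg)
        # Peak/current memory would be measured here, while the warm-up state is live.
    finally:
        # Runs on success and on failure, so a failed estimation does not leave
        # the temporary executor (and its warm-up flags) behind.
        executor.shutdown()


run_estimation(Executor(), [0, 1, 2])
```

The measurements taken inside the try block feed the unchanged accounting shown in the diff context: total_used_bytes = total_gpu_memory - end, activation_bytes = torch_peak_memory - model_bytes, and extra_cost = max(total_used_bytes - torch_used_bytes, 0).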

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -280,6 +280,8 @@ def _event_loop_wrapper(self):
             logger.error(f"Error in event loop: {e}")
             logger.error(traceback.format_exc())
             raise e
+        finally:
+            self._executor_loop_cleanup()
 
     def start_worker(self):
         self.worker_lock.acquire()
@@ -833,7 +835,6 @@ def _executor_loop_pp(self):
                         self._process_iter_stats(finished_requests,
                                                  self.active_requests,
                                                  previous_batch)
-        self._executor_loop_cleanup()
 
     def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
@@ -959,8 +960,6 @@ def _executor_loop(self):
                             iter_stats=iter_stats,
                             iter_start_time=iter_start_time))
 
-        self._executor_loop_cleanup()
-
     def _prepare_draft_requests(self):
         try:
             # Set draft tokens here to make the KV cache manager
@@ -1108,8 +1107,6 @@ def _executor_loop_overlap(self):
             if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
                 self._terminate_ctx_finished_requests()
 
-        self._executor_loop_cleanup()
-
     def _process_previous_batch(self):
         self._update_requests(self.previous_batch.sample_state)
 
@@ -1634,7 +1631,8 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
                 self.active_requests.remove(request)
 
         for request in scheduled_requests.context_requests:
-            request.move_to_next_context_chunk()
+            if request.state != LlmRequestState.GENERATION_COMPLETE:  # skip failed requests
+                request.move_to_next_context_chunk()
             if request.get_context_remaining_length() == 0:
                 request.state = LlmRequestState.GENERATION_IN_PROGRESS
```
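The py_executor.py changes move _executor_loop_cleanup() out of the tail of each executor loop and into the finally clause of _event_loop_wrapper, so cleanup also runs when a loop exits via an exception (such as an OOM hit during KV cache estimation); the last hunk additionally guards move_to_next_context_chunk() so requests already in GENERATION_COMPLETE state (i.e. failed requests) are not advanced. A minimal sketch of the wrapper pattern, with simplified stand-in names and print() in place of the logger:

```python
import traceback


class ExecutorLike:
    """Stand-in for PyExecutor; only the wrapper/cleanup structure is shown."""

    def _executor_loop(self) -> None:
        # Stand-in for _executor_loop / _executor_loop_pp / _executor_loop_overlap.
        raise RuntimeError("simulated failure inside the event loop")

    def _executor_loop_cleanup(self) -> None:
        print("cleanup ran")

    def _event_loop_wrapper(self) -> None:
        try:
            self._executor_loop()
        except Exception as e:
            print(f"Error in event loop: {e}")
            print(traceback.format_exc())
            raise
        finally:
            # Cleanup used to be the last statement of each loop body, which an
            # exception would skip; the finally clause covers both exit paths.
            self._executor_loop_cleanup()


try:
    ExecutorLike()._event_loop_wrapper()
except RuntimeError:
    pass  # cleanup has already run by the time the error propagates
```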

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
```diff
@@ -31,6 +31,7 @@ class _ExecutorCreationStage(enum.Enum):
     SAMPLER = "Sampler"
     INIT_KV_CACHE = "Initial KV cache (temporary for KV cache size estimation)"
     INIT_EXTRA_RESOURCES = "Additional executor resources (temporary for KV cache size estimation)"
+    MODEL_EXTRA = "Model resources created during usage"
     EXTRA_RESOURCES = "Additional executor resources"
     KV_CACHE = "KV cache"
     MODEL_ENGINE_MAIN = "Model"
@@ -86,6 +87,8 @@ def _maybe_explain_if_oom(self, e: Exception, *,
                 "reduce max_num_tokens",
             _ExecutorCreationStage.EXTRA_RESOURCES:
                 "reduce max_num_tokens",
+            _ExecutorCreationStage.MODEL_EXTRA:
+                "reduce max_num_tokens",
         }
 
         msg = "\n".join([
@@ -334,7 +337,9 @@ def create_py_executor(executor_config: ExecutorConfig,
 
     if estimating_kv_cache:
         assert kv_cache_creator is not None
-        kv_cache_creator.estimate_max_tokens(py_executor)
+        with mem_monitor.observe_creation_stage(
+                _ExecutorCreationStage.MODEL_EXTRA):
+            kv_cache_creator.estimate_max_tokens(py_executor)
     kv_cache_creator.teardown_managers(resources)
     del py_executor  # free before constructing new
```
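The py_executor_creator.py change registers a new MODEL_EXTRA creation stage and runs estimate_max_tokens under mem_monitor.observe_creation_stage, so an OOM during estimation is reported against "Model resources created during usage" with the advice to reduce max_num_tokens. The sketch below shows stage-aware OOM reporting of that general shape; it is illustrative only, and the names and the MemoryError handling are assumptions rather than the actual memory-monitor implementation:

```python
import contextlib
import enum


class CreationStage(enum.Enum):
    # Mirrors the new MODEL_EXTRA member added above (other stages omitted).
    MODEL_EXTRA = "Model resources created during usage"


_OOM_ADVICE = {
    CreationStage.MODEL_EXTRA: "reduce max_num_tokens",
}


@contextlib.contextmanager
def observe_creation_stage(stage: CreationStage):
    """Tag a creation stage so an OOM raised inside it gets stage-specific advice."""
    try:
        yield
    except MemoryError as e:  # the real monitor handles CUDA OOMs
        advice = _OOM_ADVICE.get(stage, "reduce memory usage")
        raise MemoryError(
            f"Executor creation failed during stage '{stage.value}'; "
            f"consider: {advice}") from e


# Usage mirroring the patched call site:
with observe_creation_stage(CreationStage.MODEL_EXTRA):
    pass  # e.g. kv_cache_creator.estimate_max_tokens(py_executor)
```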
