 from ..speculative import get_spec_decoder
 from .config_utils import is_mla, is_nemotron_hybrid
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
+from .llm_request import ExecutorResponse
 from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
                            PyTorchModelEngine)
 from .py_executor import PyExecutor
@@ -203,16 +204,29 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
         req_ids = py_executor.dist.broadcast(req_ids, root=0)
         py_executor.is_warmup = True
         py_executor.start_worker()
-        py_executor.await_responses(req_ids)
+        try:
+            responses = py_executor.await_responses(req_ids)
+            for response_or_list in responses:
+                response_list = [response_or_list] if isinstance(
+                    response_or_list, ExecutorResponse) else response_or_list
+                for response in response_list:
+                    if response.has_error():
+                        raise RuntimeError(response.error_msg)
+
+            torch_peak_memory = torch.cuda.memory_stats(
+            )["allocated_bytes.all.peak"]
+
+            # Clear the caching allocator before measuring the current memory usage
+            torch.cuda.empty_cache()
+            end, total_gpu_memory = torch.cuda.mem_get_info()
+            torch_used_bytes = torch.cuda.memory_stats(
+            )["allocated_bytes.all.current"]
+        finally:
+            py_executor.shutdown()
+            py_executor.is_warmup = False
+            py_executor.enable_iter_perf_stats = origin_iter_stats
+            py_executor.set_gather_responses(False)
 
-        torch_peak_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-
-        # Clear the caching allocator before measuring the current memory usage
-        torch.cuda.empty_cache()
-        end, total_gpu_memory = torch.cuda.mem_get_info()
-        torch_used_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
         total_used_bytes = total_gpu_memory - end
         activation_bytes = torch_peak_memory - model_bytes
         extra_cost = max(total_used_bytes - torch_used_bytes, 0)
@@ -235,15 +249,6 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
             self._max_kv_tokens_in)
 
         logger.info(f"Estimated max tokens in KV cache: {kv_cache_max_tokens}")
-
-        py_executor.resource_manager.resource_managers.get(
-            "kv_cache_manager").shutdown()
-
-        py_executor.shutdown()
-        py_executor.is_warmup = False
-        py_executor.set_gather_responses(False)
-        py_executor.enable_iter_perf_stats = origin_iter_stats
-
         executor_config.kv_cache_config.max_tokens = kv_cache_max_tokens
 
     def _create_kv_cache_manager(
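For readers following the memory math in the first hunk, below is a minimal standalone sketch of the accounting the reordered code performs once the warmup request has finished: peak allocator usage minus the weight footprint gives the activation cost, and any gap between total device usage and allocator usage is treated as non-torch overhead. The helper name and the model_bytes argument are illustrative assumptions, not part of this change; only the torch.cuda calls and the three derived quantities mirror the diff.

# Sketch only; _estimate_memory_budget and the model_bytes parameter are
# hypothetical names used for illustration.
import torch


def _estimate_memory_budget(model_bytes: int) -> dict:
    # Peak bytes held by the caching allocator while the warmup request ran.
    torch_peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

    # Release cached blocks so mem_get_info reflects memory that is truly in use.
    torch.cuda.empty_cache()
    free_bytes, total_gpu_memory = torch.cuda.mem_get_info()
    torch_used_bytes = torch.cuda.memory_stats()["allocated_bytes.all.current"]

    total_used_bytes = total_gpu_memory - free_bytes  # all GPU memory in use
    activation_bytes = torch_peak_memory - model_bytes  # transient activation cost
    # Memory consumed outside the torch allocator (e.g. NCCL buffers, CUDA context).
    extra_cost = max(total_used_bytes - torch_used_bytes, 0)
    return {
        "total_used_bytes": total_used_bytes,
        "activation_bytes": activation_bytes,
        "extra_cost": extra_cost,
    }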