
Commit dd29063

[feat] Add llm args to tune python gc threshold (NVIDIA#5141)
Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
1 parent 03f1a6a commit dd29063

File tree: 10 files changed (+95, -33 lines)

tensorrt_llm/_torch/pyexecutor/_util.py (17 additions & 14 deletions)

```diff
@@ -384,7 +384,8 @@ def create_py_executor_instance(
         draft_model_engine,
         start_worker,
         sampler,
-        lora_config: Optional[LoraConfig] = None) -> PyExecutor:
+        lora_config: Optional[LoraConfig] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     kv_cache_manager = resources.get(KV_CACHE_MANAGER_KEY, None)
 
     spec_config = model_engine.spec_config
@@ -496,19 +497,21 @@ def create_py_executor_instance(
     kv_cache_transceiver = create_kv_cache_transceiver(
         mapping, kv_cache_manager, attention_type, cache_transceiver_config)
 
-    return PyExecutor(resource_manager,
-                      scheduler,
-                      model_engine=model_engine,
-                      sampler=sampler,
-                      dist=dist,
-                      disable_overlap_scheduler=pytorch_backend_config.
-                      disable_overlap_scheduler,
-                      max_batch_size=executor_config.max_batch_size,
-                      max_draft_tokens=spec_config.max_draft_tokens
-                      if spec_config is not None else 0,
-                      kv_cache_transceiver=kv_cache_transceiver,
-                      draft_model_engine=draft_model_engine,
-                      start_worker=start_worker)
+    return PyExecutor(
+        resource_manager,
+        scheduler,
+        model_engine=model_engine,
+        sampler=sampler,
+        dist=dist,
+        disable_overlap_scheduler=pytorch_backend_config.
+        disable_overlap_scheduler,
+        max_batch_size=executor_config.max_batch_size,
+        max_draft_tokens=spec_config.max_draft_tokens
+        if spec_config is not None else 0,
+        kv_cache_transceiver=kv_cache_transceiver,
+        draft_model_engine=draft_model_engine,
+        start_worker=start_worker,
+        garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
 
 
 def instantiate_sampler(model_engine: PyTorchModelEngine,
```

tensorrt_llm/_torch/pyexecutor/py_executor.py (8 additions & 3 deletions)

```diff
@@ -16,8 +16,8 @@
 
 import torch
 
-from tensorrt_llm._utils import (global_mpi_rank, is_trace_enabled, nvtx_range,
-                                 trace_func)
+from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
+                                 is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
                                             FinishReason, InflightBatchingStats,
                                             IterationStats, KvCacheStats,
@@ -171,6 +171,7 @@ def __init__(self,
                  max_draft_tokens: int = 0,
                  kv_cache_transceiver: KvCacheTransceiver = None,
                  draft_model_engine: Optional[ModelEngine] = None,
+                 garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
@@ -268,14 +269,18 @@ def __init__(self,
                 "Drafting is not supported for selected executor loop. "
                 "Please disable disagg/pipeline parallelism/overlap scheduler.")
 
+        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+
         self.worker_started = False
         self.worker_lock = threading.Lock()
         if start_worker:
             self.start_worker()
 
     def _event_loop_wrapper(self):
         try:
-            self.event_loop()
+            with customized_gc_thresholds(
+                    self.garbage_collection_gen0_threshold):
+                self.event_loop()
         except Exception as e:
             logger.error(f"Error in event loop: {e}")
             logger.error(traceback.format_exc())
```
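A note on scope: CPython GC thresholds are process-wide, so the context manager used above bounds the change in time (while the event loop runs), not per thread. A minimal standalone sketch of the pattern; `gen0_threshold`, `event_loop`, and `loop_wrapper` are illustrative names, not from the codebase:

```python
import gc
import threading
from contextlib import contextmanager
from typing import Optional

DEFAULT_THRESHOLDS = gc.get_threshold()  # typically (700, 10, 10)


@contextmanager
def gen0_threshold(threshold: Optional[int] = None):
    """Temporarily raise the gen0 GC threshold; restore on exit."""
    try:
        if threshold:
            # One argument changes only gen0; gen1/gen2 keep their values.
            gc.set_threshold(threshold)
        yield
    finally:
        if threshold:
            gc.set_threshold(*DEFAULT_THRESHOLDS)


def event_loop():
    pass  # stand-in for the long-running scheduling loop


def loop_wrapper(threshold: Optional[int]):
    # Mirrors _event_loop_wrapper: the custom threshold is active only
    # while the loop runs, and is restored even if the loop raises.
    with gen0_threshold(threshold):
        event_loop()


worker = threading.Thread(target=loop_wrapper, args=(20000, ), daemon=True)
worker.start()
worker.join()
```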

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (9 additions & 6 deletions)

```diff
@@ -176,10 +176,12 @@ def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
     return mapping
 
 
-def create_py_executor(executor_config: ExecutorConfig,
-                       checkpoint_dir: str = None,
-                       engine_dir: str = None,
-                       lora_config: Optional[LoraConfig] = None) -> PyExecutor:
+def create_py_executor(
+        executor_config: ExecutorConfig,
+        checkpoint_dir: str = None,
+        engine_dir: str = None,
+        lora_config: Optional[LoraConfig] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
     _mangle_executor_config(executor_config)
     pytorch_backend_config = executor_config.pytorch_backend_config
 
@@ -334,7 +336,7 @@ def create_py_executor(executor_config: ExecutorConfig,
     py_executor = create_py_executor_instance(
         dist, resources, mapping, pytorch_backend_config, executor_config,
         ctx_chunk_config, model_engine, draft_model_engine, False, sampler,
-        lora_config)
+        lora_config, garbage_collection_gen0_threshold)
 
     if estimating_kv_cache:
         assert kv_cache_creator is not None
@@ -365,7 +367,8 @@ def create_py_executor(executor_config: ExecutorConfig,
         py_executor = create_py_executor_instance(
             dist, resources, mapping, pytorch_backend_config,
             executor_config, ctx_chunk_config, model_engine,
-            draft_model_engine, False, sampler, lora_config)
+            draft_model_engine, False, sampler, lora_config,
+            garbage_collection_gen0_threshold)
 
     py_executor.start_worker()
     return py_executor
```

tensorrt_llm/_utils.py (20 additions & 0 deletions)

```diff
@@ -781,6 +781,26 @@ def __getitem__(self, index):
         return self.objs[index]
 
 
+PYTHON_DEFAULT_GC_THRESHOLDS = gc.get_threshold()
+
+
+@contextmanager
+def customized_gc_thresholds(gen0_threshold: Optional[int] = None):
+    try:
+        if gen0_threshold:
+            gc.set_threshold(gen0_threshold)
+            logger.debug(
+                f'Set Python GC threshold to customized value: {gen0_threshold}'
+            )
+        yield
+    finally:
+        if gen0_threshold:
+            gc.set_threshold(*PYTHON_DEFAULT_GC_THRESHOLDS)
+            logger.debug(
+                f'Reset Python GC thresholds to default value: {PYTHON_DEFAULT_GC_THRESHOLDS}'
+            )
+
+
 @contextmanager
 def _null_context_manager():
     yield
```
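Two details worth noting: a falsy `gen0_threshold` makes the context manager a no-op, and the `finally` block restores all three thresholds to the values captured at import time (not the values in effect on entry, which only matters if calls nest). A quick sketch of the semantics, assuming stock CPython defaults:

```python
import gc
from tensorrt_llm._utils import customized_gc_thresholds

print(gc.get_threshold())        # (700, 10, 10) on stock CPython
with customized_gc_thresholds(20000):
    # gc.set_threshold with a single argument changes only gen0;
    # the gen1/gen2 thresholds keep their previous values.
    print(gc.get_threshold())    # (20000, 10, 10)
print(gc.get_threshold())        # restored to the import-time defaults
```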

tensorrt_llm/executor/executor.py (13 additions & 4 deletions)

```diff
@@ -350,6 +350,7 @@ def create(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> Union["GenerationExecutorProxy", "GenerationExecutorWorker"]:
         # local imports to avoid cyclic importing
         from .proxy import GenerationExecutorProxy
@@ -393,7 +394,9 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor)
+                is_llm_executor=is_llm_executor,
+                garbage_collection_gen0_threshold=
+                garbage_collection_gen0_threshold)
 
         # WAR: For the performance of gathering logits, we use single process worker
         # for TP1 to avoid the large overhead of IPC.
@@ -404,7 +407,9 @@ def create(
                 "Using single process worker for TP1, this may hurt streaming generation performance."
             )
             return GenerationExecutorWorker(**worker_kwargs,
-                                            is_llm_executor=is_llm_executor)
+                                            is_llm_executor=is_llm_executor,
+                                            garbage_collection_gen0_threshold=
+                                            garbage_collection_gen0_threshold)
 
         # For single-gpu case:
         # Partition the workload to multiple process for streaming performance.
@@ -416,7 +421,9 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=None,  # use mpi4py
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor)
+                is_llm_executor=is_llm_executor,
+                garbage_collection_gen0_threshold=
+                garbage_collection_gen0_threshold)
         else:
             ctx = multiprocessing.get_context("spawn")
             # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -427,7 +434,9 @@ def create(
                 model_world_size=model_world_size,
                 mpi_session=mpi_session,
                 postproc_worker_config=postproc_worker_config,
-                is_llm_executor=is_llm_executor)
+                is_llm_executor=is_llm_executor,
+                garbage_collection_gen0_threshold=
+                garbage_collection_gen0_threshold)
 
     def wait_first_completed(
             self, futures: List[GenerationResult]
```

tensorrt_llm/executor/proxy.py (10 additions & 4 deletions)

```diff
@@ -11,7 +11,7 @@
 
 from tensorrt_llm.logger import logger
 
-from .._utils import mpi_rank, nvtx_range_debug
+from .._utils import customized_gc_thresholds, mpi_rank, nvtx_range_debug
 from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
                                   RemoteMpiCommSessionClient)
 from ..llmapi.tracer import enable_llm_tracer, get_tracer, global_tracer
@@ -44,6 +44,7 @@ def __init__(
         worker_cls: type = GenerationExecutorWorker,
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_worker_config = postproc_worker_config or PostprocWorkerConfig(
         )
@@ -86,10 +87,14 @@ def __init__(
 
         self.model_world_size = model_world_size
 
+        self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold
+
         worker_kwargs = dict(**worker_kwargs,
                              worker_queues=self._setup_queues(),
                              postproc_worker_config=postproc_worker_config,
-                             is_llm_executor=False)
+                             is_llm_executor=False,
+                             garbage_collection_gen0_threshold=self.
+                             garbage_collection_gen0_threshold)
 
         if "log_level" not in worker_kwargs:
             worker_kwargs["log_level"] = logger.level
@@ -152,8 +157,9 @@ def abort_request(self, request_id: int) -> None:
     def dispatch_result_task(self) -> bool:
         # TODO[chunweiy]: convert the dispatch_result_task to async, that should
         # benefit from zmq.asyncio.Context
-        if (res := self.result_queue.get()) is None:
-            return False  # shutdown the thread
+        with customized_gc_thresholds(self.garbage_collection_gen0_threshold):
+            if (res := self.result_queue.get()) is None:
+                return False  # shutdown the thread
 
         async_queues = []
         event_loop = None
```

tensorrt_llm/executor/worker.py (6 additions & 1 deletion)

```diff
@@ -58,6 +58,7 @@ def __init__(
         postproc_worker_config: Optional[PostprocWorkerConfig] = None,
         is_llm_executor: Optional[bool] = None,
         lora_config: Optional[LoraConfig] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None,
     ) -> None:
         postproc_config = postproc_worker_config or PostprocWorkerConfig()
         super().__init__(
@@ -125,6 +126,8 @@ def _create_engine():
                 create_py_executor
             create_executor = create_py_executor
             args["lora_config"] = lora_config
+            args[
+                "garbage_collection_gen0_threshold"] = garbage_collection_gen0_threshold
         elif executor_config.backend == "_autodeploy":
             from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                 create_autodeploy_executor
@@ -595,6 +598,7 @@ def worker_main(
         is_llm_executor: Optional[
             bool] = True,  # whether it's the main executor instance
         lora_config: Optional[LoraConfig] = None,
+        garbage_collection_gen0_threshold: Optional[int] = None,
 ) -> None:
     mpi_comm().barrier()
     print_colored_debug(f"Worker {mpi_rank()} entering worker_main...\n",
@@ -720,7 +724,8 @@ def notify_proxy_threads_to_quit():
             batched_logits_processor,
             postproc_worker_config=postproc_worker_config,
             is_llm_executor=is_llm_executor,
-            lora_config=lora_config)
+            lora_config=lora_config,
+            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
     except Exception as e:
         logger.error(f"Failed to initialize executor on rank {mpi_rank()}: {e}")
         logger.error(traceback.format_exc())
```

tensorrt_llm/llmapi/llm.py (3 additions & 1 deletion)

```diff
@@ -708,7 +708,9 @@ def _build_model(self):
                 postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir,
             ),
             is_llm_executor=True,
-            lora_config=self.args.lora_config)
+            lora_config=self.args.lora_config,
+            garbage_collection_gen0_threshold=self.args.
+            garbage_collection_gen0_threshold)
 
     @property
     def _on_trt_backend(self) -> bool:
```

tensorrt_llm/llmapi/llm_args.py (6 additions & 0 deletions)

```diff
@@ -941,6 +941,12 @@ class BaseLlmArgs(BaseModel):
         default=None,
         description="The parser to separate reasoning content from output.")
 
+    garbage_collection_gen0_threshold: int = Field(
+        default=20000,
+        description=
+        "Threshold for Python garbage collection of generation 0 objects. "
+        "Lower values trigger more frequent garbage collection.")
+
     # TODO[Superjomn]: To deprecate this config.
     decoding_config: Optional[object] = Field(
         default=None,
```
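The default of 20000 sits far above CPython's usual gen0 threshold of 700, so allocation-heavy request paths trigger far fewer gen0 collections, at the cost of letting short-lived garbage accumulate a little longer. A hypothetical micro-benchmark sketch of that trade-off (not from the PR; numbers vary by workload and interpreter):

```python
import gc
import time


def churn(n: int = 500_000) -> None:
    # Allocate short-lived container objects that feed generation 0.
    for _ in range(n):
        _ = [object()]


for gen0 in (700, 20000):
    gc.set_threshold(gen0)
    gc.collect()
    before = gc.get_stats()[0]["collections"]
    start = time.perf_counter()
    churn()
    elapsed = time.perf_counter() - start
    runs = gc.get_stats()[0]["collections"] - before
    print(f"gen0={gen0}: {runs} gen0 collections in {elapsed:.3f}s")
```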

tests/unittest/api_stability/references_committed/llm.yaml (3 additions & 0 deletions)

```diff
@@ -105,6 +105,9 @@ methods:
     kv_cache_config:
       annotation: tensorrt_llm.llmapi.llm_args.KvCacheConfig
       default: null
+    garbage_collection_gen0_threshold:
+      annotation: int
+      default: 20000
     return_annotation: None
   generate:
     parameters:
```
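End to end, the argument flows from `BaseLlmArgs` through `GenerationExecutor.create` into both the proxy's result-dispatch thread and the `PyExecutor` event loop. A hypothetical usage sketch (the model path is a placeholder, and this assumes the `LLM` constructor forwards `BaseLlmArgs` fields as keyword arguments, which the `llm.py` change above relies on):

```python
from tensorrt_llm import LLM

# Raise the gen0 threshold on the hot paths; lower values collect more often.
llm = LLM(model="/path/to/model",  # placeholder checkpoint path
          garbage_collection_gen0_threshold=20000)
```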
