@@ -350,6 +350,7 @@ def create(
350350 postproc_worker_config : Optional [PostprocWorkerConfig ] = None ,
351351 is_llm_executor : Optional [bool ] = None ,
352352 lora_config : Optional [LoraConfig ] = None ,
353+ garbage_collection_gen0_threshold : Optional [int ] = None ,
353354 ) -> Union ["GenerationExecutorProxy" , "GenerationExecutorWorker" ]:
354355 # local imports to avoid cyclic importing
355356 from .proxy import GenerationExecutorProxy
@@ -393,7 +394,9 @@ def create(
393394 model_world_size = model_world_size ,
394395 mpi_session = mpi_session ,
395396 postproc_worker_config = postproc_worker_config ,
396- is_llm_executor = is_llm_executor )
397+ is_llm_executor = is_llm_executor ,
398+ garbage_collection_gen0_threshold =
399+ garbage_collection_gen0_threshold )
397400
398401 # WAR: For the performance of gathering logits, we use single process worker
399402 # for TP1 to avoid the large overhead of IPC.
@@ -404,7 +407,9 @@ def create(
404407 "Using single process worker for TP1, this may hurt streaming generation performance."
405408 )
406409 return GenerationExecutorWorker (** worker_kwargs ,
407- is_llm_executor = is_llm_executor )
410+ is_llm_executor = is_llm_executor ,
411+ garbage_collection_gen0_threshold =
412+ garbage_collection_gen0_threshold )
408413
409414 # For single-gpu case:
410415 # Partition the workload to multiple process for streaming performance.
@@ -416,7 +421,9 @@ def create(
416421 model_world_size = model_world_size ,
417422 mpi_session = None , # use mpi4py
418423 postproc_worker_config = postproc_worker_config ,
419- is_llm_executor = is_llm_executor )
424+ is_llm_executor = is_llm_executor ,
425+ garbage_collection_gen0_threshold =
426+ garbage_collection_gen0_threshold )
420427 else :
421428 ctx = multiprocessing .get_context ("spawn" )
422429 # The ProcessPoolExecutorSession is used to support Windows, as mpi4py cannot.
@@ -427,7 +434,9 @@ def create(
427434 model_world_size = model_world_size ,
428435 mpi_session = mpi_session ,
429436 postproc_worker_config = postproc_worker_config ,
430- is_llm_executor = is_llm_executor )
437+ is_llm_executor = is_llm_executor ,
438+ garbage_collection_gen0_threshold =
439+ garbage_collection_gen0_threshold )
431440
432441 def wait_first_completed (
433442 self , futures : List [GenerationResult ]
0 commit comments