@@ -26,8 +26,7 @@
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
                                PeftCacheManager, ResourceManager,
                                ResourceManagerType)
-from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
-                      TRTLLMSampler)
+from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
                         SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager
@@ -506,6 +505,7 @@ def create_py_executor_instance(
         model_engine=model_engine,
         sampler=sampler,
         dist=dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
         max_batch_size=executor_config.max_batch_size,
@@ -517,31 +517,44 @@ def create_py_executor_instance(
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
 
 
-def instantiate_sampler(model_engine: PyTorchModelEngine,
+def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
+                              *, max_seq_len: int, mixed_sampler: bool):
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_draft_tokens = (0 if executor_config.speculative_config is None else
+                        executor_config.speculative_config.max_draft_tokens)
+    return TorchSampler.Args(
+        max_seq_len=max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
+        mixed_sampler=mixed_sampler,
+    )
+
+
+def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
                         mapping: Mapping):
+    sampler_args = create_torch_sampler_args(
+        executor_config,
+        mapping,
+        max_seq_len=engine.max_seq_len,
+        mixed_sampler=pytorch_backend_config.mixed_sampler)
     if mapping.cp_config.get('cp_type') == 'star_attention':
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
-        sampler = TorchStarAttentionSampler(
-            max_seq_len=model_engine.max_seq_len)
-    elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder(
+        return TorchSampler(sampler_args)
+    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
     ):
-        sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len,
-                                   spec_config=model_engine.spec_config)
-    elif pytorch_backend_config.enable_trtllm_sampler:
+        return get_spec_decoder(sampler_args, engine.spec_config)
+    if pytorch_backend_config.enable_trtllm_sampler:
         decoding_mode = get_decoding_mode(executor_config)
-        sampler = TRTLLMSampler(
-            executor_config, model_engine.model, model_engine.dtype, mapping,
-            decoding_mode, pytorch_backend_config.disable_overlap_scheduler)
-    elif not model_engine.model.model_config.is_generation:
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
+                             pytorch_backend_config.disable_overlap_scheduler)
+    if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
-        sampler = EarlyStopSampler()
-    else:
-        sampler = TorchSampler(
-            max_seq_len=model_engine.max_seq_len,
-            mixed_sampler=pytorch_backend_config.mixed_sampler)
-    return sampler
+        return EarlyStopSampler()
+    return TorchSampler(sampler_args)
 
 
 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
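For context, a minimal standalone sketch of the argument consolidation this diff introduces: instead of each sampler branch pulling fields off the engine and configs, one helper packs them into a single bundle that every TorchSampler-based branch shares. The dataclass and helper names below are illustrative stand-ins, not the real TRT-LLM TorchSampler.Args API; only the max_num_sequences arithmetic (max_batch_size * pp_size) and the zero-draft-token fallback mirror the diff:

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class SamplerArgs:
    # Stand-in for the TorchSampler.Args bundle added in this diff.
    max_seq_len: int
    max_draft_tokens: int
    max_num_sequences: int
    max_beam_width: int
    mixed_sampler: bool


def build_sampler_args(max_batch_size: int,
                       pp_size: int,
                       max_beam_width: int,
                       max_seq_len: int,
                       mixed_sampler: bool,
                       max_draft_tokens: Optional[int] = None) -> SamplerArgs:
    # Mirrors create_torch_sampler_args: one sequence slot per request per
    # pipeline stage, and zero draft tokens when speculative decoding is
    # not configured.
    return SamplerArgs(
        max_seq_len=max_seq_len,
        max_draft_tokens=max_draft_tokens or 0,
        max_num_sequences=max_batch_size * pp_size,
        max_beam_width=max_beam_width,
        mixed_sampler=mixed_sampler,
    )


args = build_sampler_args(max_batch_size=8, pp_size=2, max_beam_width=1,
                          max_seq_len=4096, mixed_sampler=False)
assert args.max_num_sequences == 16  # 8 requests x 2 pipeline stages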