@@ -824,11 +824,16 @@ def create_py_executor_instance(
824824 virtual_memory_pools = virtual_memory_pools )
825825
826826
827- def create_torch_sampler_args (mapping : Mapping , * , max_seq_len : int ,
828- max_batch_size : int ,
829- speculative_config : SpeculativeConfig ,
830- max_beam_width : int ,
831- disable_flashinfer_sampling : bool ):
827+ def create_torch_sampler_args (
828+ mapping : Mapping ,
829+ * ,
830+ max_seq_len : int ,
831+ max_batch_size : int ,
832+ speculative_config : SpeculativeConfig ,
833+ max_beam_width : int ,
834+ disable_overlap_scheduler : bool ,
835+ disable_flashinfer_sampling : bool ,
836+ ):
832837 max_num_sequences = max_batch_size * mapping .pp_size
833838 max_draft_len = (0 if speculative_config is None else
834839 speculative_config .max_draft_len )
@@ -842,7 +847,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
842847 max_num_sequences = max_num_sequences ,
843848 max_beam_width = max_beam_width ,
844849 disable_flashinfer_sampling = disable_flashinfer_sampling ,
845- )
850+ disable_overlap_scheduler = disable_overlap_scheduler )
846851
847852
848853def instantiate_sampler (
@@ -865,6 +870,7 @@ def instantiate_sampler(
865870 max_batch_size = max_batch_size ,
866871 speculative_config = speculative_config ,
867872 max_beam_width = max_beam_width ,
873+ disable_overlap_scheduler = llm_args .disable_overlap_scheduler ,
868874 disable_flashinfer_sampling = disable_flashinfer_sampling ,
869875 )
870876 decoding_mode = get_decoding_mode (decoding_config = decoding_config ,
(end of diff)