                                          FullAttentionSpec, KVCacheConfig,
                                          KVCacheGroupSpec, KVCacheSpec,
                                          MambaSpec, SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, LogprobsTensors, ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
+                             LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -238,8 +238,6 @@ def __init__(
             is_pooling_model=self.is_pooling_model,
         )

-        self.use_async_scheduling = self.scheduler_config.async_scheduling
-
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
@@ -711,73 +709,6 @@ def _get_cumsum_and_arange(

         return cu_num_tokens, arange

-    def _prepare_input_ids(self, total_num_scheduled_tokens: int,
-                           cu_num_tokens: np.ndarray) -> None:
-        """Prepare the input IDs for the current batch.
-
-        Carefully handles the `prev_sampled_token_ids` which can be cached
-        from the previous engine iteration, in which case those tokens on the
-        GPU need to be copied into the corresponding slots into input_ids."""
-
-        if self.input_batch.prev_sampled_token_ids is not None:
-            # Async scheduling case, we need to copy the sampled token ids
-            # from the previous iteration.
-            prev_req_id_to_index = self.input_batch.prev_req_id_to_index
-            current_req_id_to_index = self.input_batch.req_id_to_index
-            assert prev_req_id_to_index is not None
-            common_req_ids = set(prev_req_id_to_index.keys()).intersection(
-                set(current_req_id_to_index.keys()))
-            if common_req_ids:
-                current_common_req_indices = [
-                    current_req_id_to_index[req_id]
-                    for req_id in common_req_ids
-                ]
-                prev_common_req_indices = [
-                    prev_req_id_to_index[req_id] for req_id in common_req_ids
-                ]
-                # We need to compute the flattened input_ids index of the
-                # last token in each common request.
-                flattened_indices = [
-                    int(cu_num_tokens[idx]) - 1
-                    for idx in current_common_req_indices
-                ]
-                if len(flattened_indices) < total_num_scheduled_tokens:
-                    # If not all requests are decodes from the last iteration,
-                    # We need to copy the input_ids_cpu to the GPU first.
-                    self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
-                if flattened_indices == prev_common_req_indices and \
-                    set(flattened_indices) == \
-                    set(range(len(flattened_indices))):
-                    # Common-case optimization: the batch is unchanged
-                    # and no reordering happened.
-                    # The indices are both the same permutation of 0..N-1
-                    self.input_ids.gpu[:len(flattened_indices)].copy_(
-                        self.input_batch.prev_sampled_token_ids[:len(
-                            flattened_indices)].squeeze(1),
-                        non_blocking=True)
-                else:
-                    # Upload the index tensors asynchronously
-                    # so the scatter can be non-blocking
-                    input_ids_index_tensor = torch.tensor(
-                        flattened_indices,
-                        dtype=torch.int64,
-                        pin_memory=self.pin_memory).to(self.device,
-                                                       non_blocking=True)
-                    prev_common_req_indices_tensor = torch.tensor(
-                        prev_common_req_indices,
-                        dtype=torch.int64,
-                        pin_memory=self.pin_memory).to(self.device,
-                                                       non_blocking=True)
-                    self.input_ids.gpu.scatter_(
-                        dim=0,
-                        index=input_ids_index_tensor,
-                        src=self.input_batch.prev_sampled_token_ids[
-                            prev_common_req_indices_tensor].squeeze(1))
-            else:
-                self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
-        else:
-            self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
-
     def _prepare_inputs(
         self, scheduler_output: "SchedulerOutput"
     ) -> tuple[PerLayerAttnMetadata, torch.Tensor,
@@ -869,8 +800,7 @@ def _prepare_inputs(
             max_seq_len = self.seq_lens.np[:num_reqs].max().item()

         # Copy the tensors to the GPU.
-        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
-
+        self.input_ids.copy_to_gpu(total_num_scheduled_tokens)
         if self.uses_mrope:
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
             self.mrope_positions.gpu[:, :total_num_scheduled_tokens].copy_(
@@ -986,10 +916,6 @@ def _prepare_inputs(
                 builder,
             )

-<<<<<<< HEAD
-
-=======
->>>>>>> nm/sage/dbo-full-cudagraphs
         if ubatch_slices is not None:
             common_attn_metadata_list = split_attn_metadata(
                 ubatch_slices, common_attn_metadata)
@@ -1637,7 +1563,6 @@ def get_dp_padding_ubatch(
             should_ubatch = False

         # Note that we compute the number of padded tokens per ubatch
-
         (should_ubatch,
          num_tokens_across_dp) = self.should_ubatch_with_num_tokens(
              should_ubatch, num_tokens_unpadded // 2, num_tokens_per_ubatch)
@@ -1724,7 +1649,7 @@ def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
             if not has_kv_transfer_group():
@@ -1927,12 +1852,6 @@ def execute_model(
                 # so that we could clear the sampled tokens before returning.
                 discard_sampled_tokens_req_indices.append(i)

-        # Copy some objects so they don't get modified after returning.
-        # This is important when using async scheduling.
-        req_ids_output_copy = self.input_batch.req_ids.copy()
-        req_id_to_index_output_copy = \
-            self.input_batch.req_id_to_index.copy()
-
         # NOTE: GPU -> CPU Sync happens here.
         # Move as many CPU operations as possible before this sync point.
         logprobs_tensors = sampler_output.logprobs_tensors
@@ -1945,54 +1864,29 @@ def execute_model(
             scheduler_output.num_scheduled_tokens,
         )

-        num_sampled_tokens = sampler_output.sampled_token_ids.shape[0]
+        # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
-        if not self.use_async_scheduling:
-            # Get the valid generated tokens.
-            max_gen_len = sampled_token_ids.shape[-1]
-            if max_gen_len == 1:
-                # No spec decode tokens.
-                valid_sampled_token_ids = self._to_list(sampled_token_ids)
-            else:
-                # Includes spec decode tokens.
-                valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                    sampled_token_ids,
-                    self.input_batch.vocab_size,
-                )
-            # Mask out the sampled tokens that should not be sampled.
-            for i in discard_sampled_tokens_req_indices:
-                valid_sampled_token_ids[i].clear()
+        max_gen_len = sampled_token_ids.shape[-1]
+        if max_gen_len == 1:
+            # No spec decode tokens.
+            valid_sampled_token_ids = self._to_list(sampled_token_ids)
         else:
-            valid_sampled_token_ids = []
-            invalid_req_indices = list(discard_sampled_tokens_req_indices)
-            invalid_req_indices_set = set(invalid_req_indices)
-            assert sampled_token_ids.shape[-1] == 1
-
-            # Cache the sampled tokens on the GPU and avoid CPU sync.
-            # These will be copied into input_ids in the next step
-            # when preparing inputs.
-            self.input_batch.prev_sampled_token_ids = \
-                sampled_token_ids
-            self.input_batch.prev_sampled_token_ids_invalid_indices = \
-                invalid_req_indices_set
-            self.input_batch.prev_req_id_to_index = {
-                req_id: i
-                for i, req_id in enumerate(self.input_batch.req_ids)
-                if i not in invalid_req_indices_set
-            }
+            # Includes spec decode tokens.
+            valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                sampled_token_ids,
+                self.input_batch.vocab_size,
+            )
+        # Mask out the sampled tokens that should not be sampled.
+        for i in discard_sampled_tokens_req_indices:
+            valid_sampled_token_ids[i].clear()

         # Cache the sampled tokens in the model runner, so that the scheduler
         # doesn't need to send them back.
         # NOTE(woosuk): As an exception, when using PP, the scheduler sends
         # the sampled tokens back, because there's no direct communication
         # between the first-stage worker and the last-stage worker.
         req_ids = self.input_batch.req_ids
-        for req_idx in range(num_sampled_tokens):
-            if self.use_async_scheduling:
-                sampled_ids = [-1] * 1 if \
-                    req_idx not in invalid_req_indices_set else None
-            else:
-                sampled_ids = valid_sampled_token_ids[req_idx]
+        for req_idx, sampled_ids in enumerate(valid_sampled_token_ids):
             if not sampled_ids:
                 continue

@@ -2007,7 +1901,6 @@ def execute_model(
                                            start_idx:end_idx] = sampled_ids
             self.input_batch.num_tokens_no_spec[req_idx] = end_idx
             self.input_batch.num_tokens[req_idx] = end_idx
-
             req_id = req_ids[req_idx]
             req_state = self.requests[req_id]
             req_state.output_token_ids.extend(sampled_ids)
@@ -2029,9 +1922,9 @@ def execute_model(

         self.eplb_step()

-        output = ModelRunnerOutput(
-            req_ids=req_ids_output_copy,
-            req_id_to_index=req_id_to_index_output_copy,
+        return ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
@@ -2040,15 +1933,6 @@ def execute_model(
             num_nans_in_logits=num_nans_in_logits,
         )

-        if self.use_async_scheduling:
-            return AsyncModelRunnerOutput(
-                model_runner_output=output,
-                sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=invalid_req_indices,
-            )
-
-        return output
-
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         if self._draft_token_ids is None:
             return None