@@ -1792,7 +1792,8 @@ def _write_finish_reasons(
17921792 if not single_token_stop_words_only
17931793 else self ._are_stop_words_single_token
17941794 )
1795- batched_finish_reasons [:, stop_word_indices ] = torch .where (
1795+ batched_finish_reasons_stop_words = batched_finish_reasons [:, stop_word_indices ]
1796+ _ = batched_finish_reasons_stop_words .masked_fill_ (
17961797 stop_words_func (
17971798 stop_seq_slots ,
17981799 stop_tokens ,
@@ -1801,18 +1802,17 @@ def _write_finish_reasons(
18011802 else num_accepted_tokens ,
18021803 ),
18031804 FinishReason .STOP_WORDS .value ,
1804- batched_finish_reasons [:, stop_word_indices ],
18051805 )
1806+ batched_finish_reasons [:, stop_word_indices ] = batched_finish_reasons_stop_words
18061807
1807- batched_finish_reasons = torch . where (
1808+ _ = batched_finish_reasons . masked_fill_ (
18081809 self ._are_max_length (seq_lens , store .max_lengths_cuda [seq_slots ]),
18091810 FinishReason .LENGTH .value ,
1810- batched_finish_reasons ,
18111811 )
1812- batched_finish_reasons = torch .where (
1812+
1813+ _ = batched_finish_reasons .masked_fill_ (
18131814 self ._are_end_id (store .end_ids_cuda [seq_slots ], tokens ),
18141815 FinishReason .END_ID .value ,
1815- batched_finish_reasons ,
18161816 )
18171817
18181818 finish_reasons [:, seq_slots ] = batched_finish_reasons
@@ -1916,7 +1916,7 @@ def _are_stop_words(
19161916 # Fill in the new tokens at the end of the past tokens buffer
19171917 full_tokens [- self ._max_tokens :] = tokens
19181918 # short words are padded with _PAD_STOP_WORD_TOKEN_ID, so we need to mask them
1919- mask = stop_words ! = self ._PAD_STOP_WORD_TOKEN_ID
1919+ mask = stop_words = = self ._PAD_STOP_WORD_TOKEN_ID
19201920 matches = torch .empty (
19211921 (
19221922 self ._max_tokens ,
@@ -1941,15 +1941,15 @@ def _are_stop_words(
19411941 stop_words_for_match = stop_words .unsqueeze (0 )
19421942 _ = torch .eq (full_tokens_for_match , stop_words_for_match , out = matches )
19431943 # Mask the padding tokens
1944- matches_after_mask = torch . where (
1945- mask .unsqueeze (0 ).expand (self ._max_tokens , - 1 , - 1 , - 1 , - 1 ), matches , True
1944+ _ = matches . masked_fill_ (
1945+ mask .unsqueeze (0 ).expand (self ._max_tokens , - 1 , - 1 , - 1 , - 1 ), True
19461946 )
19471947 # Update the past tokens storage for the next iteration
19481948 store .past_tokens_cuda [:, seq_slots ] = full_tokens
19491949 # Return the result
19501950 word_len_dim = 2
19511951 num_words_dim = 1
1952- return torch .any (matches_after_mask .all (dim = word_len_dim ), dim = num_words_dim )
1952+ return torch .any (matches .all (dim = word_len_dim ), dim = num_words_dim )
19531953
19541954 @nvtx_range ("_are_stop_words_single_token" )
19551955 def _are_stop_words_single_token (
@@ -3721,8 +3721,10 @@ def _sample_batched_by_strategy(
37213721 group_logits_indices_for_processed_logprobs_cuda
37223722 ]
37233723 current_softmax_cuda = group_softmax_cuda [logit_indices_for_processed_logprobs_cuda ]
3724- processed_logits_cuda = torch .where (
3725- current_softmax_cuda > 0 , current_logits_cuda , float ("-inf" )
3724+
3725+ # processed_logits_cuda is an alias to current_logits_cuda after this operation
3726+ processed_logits_cuda = current_logits_cuda .masked_fill_ (
3727+ current_softmax_cuda == 0 , float ("-inf" )
37263728 )
37273729 temperature_for_processed_logprobs = group_temperature_cuda
37283730 if isinstance (temperature_for_processed_logprobs , torch .Tensor ):
0 commit comments