@@ -750,16 +750,16 @@ def get_pythonized_sample_results(
         if sampling_type not in sample_metadata:
             continue
         (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-        if sampling_type == SamplingType.FORCED:
-            sample_results = _forced_sample(seq_groups, forced_samples)
-        elif sampling_type == SamplingType.GREEDY:
+        if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
+        elif sampling_type == SamplingType.FORCED:
+            sample_results = _forced_sample(seq_groups, forced_samples)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [
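The dispatch above calls _forced_sample, whose body is not shown in this diff. Assuming it mirrors the shape of _greedy_sample (one precomputed token per sequence group, parent id 0), a minimal sketch could look like the following; the structure is an assumption for illustration, not the patch's actual implementation:

import torch

# Hypothetical sketch of _forced_sample, assumed analogous to _greedy_sample:
# each FORCED sequence group receives one precomputed token from the
# forced_samples tensor built in _sample_with_torch.
def _forced_sample(selected_seq_groups, samples: torch.Tensor):
    samples = samples.tolist()
    results = []
    for i, seq_group in enumerate(selected_seq_groups):
        # One parent sequence per group is assumed, so the parent id is 0.
        next_token_ids = [samples[i]]
        parent_ids = [0]
        results.append((next_token_ids, parent_ids))
    return results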
@@ -825,19 +825,8 @@ def _sample_with_torch(
         seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
         sample_metadata[sampling_type] = (seq_group_id, seq_groups)
         long_sample_indices = sample_indices.long()
-        if sampling_type == SamplingType.FORCED:
-            if (seq_groups[0].sampling_params.future_context is not None):
-                forced_samples = torch.tensor([
-                    seq_groups[0].sampling_params.future_context[0][min(
-                        len(sampling_metadata.seq_groups[0].seq_data[
-                            sampling_params.cntr].output_token_ids),
-                        len(seq_groups[0].sampling_params.future_context[0]) -
-                        1)]
-                ])
-            else:
-                forced_samples = torch.argmax(logprobs[long_sample_indices],
-                                              dim=-1)
-        elif sampling_type == SamplingType.GREEDY:
+
+        if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
 
@@ -886,6 +875,18 @@ def _sample_with_torch(
 
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
+        elif sampling_type == SamplingType.FORCED:
+            if (seq_groups[0].sampling_params.future_context is not None):
+                forced_samples = torch.tensor([
+                    seq_groups[0].sampling_params.future_context[0][min(
+                        len(sampling_metadata.seq_groups[0].seq_data[
+                            sampling_params.cntr].output_token_ids),
+                        len(seq_groups[0].sampling_params.future_context[0]) -
+                        1)]
+                ])
+            else:
+                forced_samples = torch.argmax(logprobs[long_sample_indices],
+                                              dim=-1)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
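The index arithmetic in the FORCED branch is the subtle part: the next forced token is selected by how many tokens the sequence has already emitted, clamped so the read never runs past the end of the forced sequence. A standalone illustration of that clamp, with made-up token ids (the real code pulls output_token_ids from sampling_metadata and future_context from the request's SamplingParams):

# Standalone illustration of the FORCED-branch index clamp; the token ids
# below are made up for demonstration.
future_context = [101, 102, 103]  # tokens to force, in generation order
output_token_ids = [7, 8]         # tokens already generated for this sequence

# Index by generation progress, clamped to the last forced token so a
# sequence that outruns future_context keeps receiving its final token
# instead of raising an IndexError.
idx = min(len(output_token_ids), len(future_context) - 1)
forced_token = future_context[idx]
print(forced_token)  # 103 -> two tokens emitted, so index 2 is selected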