This repository was archived by the owner on Sep 4, 2025. It is now read-only.

Commit 9d8035b

Swapping the order of sampling operations in the conditional selector. (#199)
Adding P3L measurement to the benchmarks collection tools; a cleaner version of the code from "Swapping the order of sampling operations in the conditional selector. (#199)".
1 parent 7094103 commit 9d8035b

File tree

1 file changed (+17 -16 lines)

vllm/model_executor/layers/sampler.py

Lines changed: 17 additions & 16 deletions
@@ -750,16 +750,16 @@ def get_pythonized_sample_results(
         if sampling_type not in sample_metadata:
             continue
         (seq_group_id, seq_groups) = sample_metadata[sampling_type]
-        if sampling_type == SamplingType.FORCED:
-            sample_results = _forced_sample(seq_groups, forced_samples)
-        elif sampling_type == SamplingType.GREEDY:
+        if sampling_type == SamplingType.GREEDY:
             sample_results = _greedy_sample(seq_groups, greedy_samples)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
+        elif sampling_type == SamplingType.FORCED:
+            sample_results = _forced_sample(seq_groups, forced_samples)
         sample_results_dict.update(zip(seq_group_id, sample_results))
 
     return [
@@ -825,19 +825,8 @@ def _sample_with_torch(
         seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
         sample_metadata[sampling_type] = (seq_group_id, seq_groups)
         long_sample_indices = sample_indices.long()
-        if sampling_type == SamplingType.FORCED:
-            if (seq_groups[0].sampling_params.future_context is not None):
-                forced_samples = torch.tensor([
-                    seq_groups[0].sampling_params.future_context[0][min(
-                        len(sampling_metadata.seq_groups[0].seq_data[
-                            sampling_params.cntr].output_token_ids),
-                        len(seq_groups[0].sampling_params.future_context[0]) -
-                        1)]
-                ])
-            else:
-                forced_samples = torch.argmax(logprobs[long_sample_indices],
-                                              dim=-1)
-        elif sampling_type == SamplingType.GREEDY:
+
+        if sampling_type == SamplingType.GREEDY:
             greedy_samples = torch.argmax(logprobs[long_sample_indices],
                                           dim=-1)
 
@@ -886,6 +875,18 @@ def _sample_with_torch(
 
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
+        elif sampling_type == SamplingType.FORCED:
+            if (seq_groups[0].sampling_params.future_context is not None):
+                forced_samples = torch.tensor([
+                    seq_groups[0].sampling_params.future_context[0][min(
+                        len(sampling_metadata.seq_groups[0].seq_data[
+                            sampling_params.cntr].output_token_ids),
+                        len(seq_groups[0].sampling_params.future_context[0]) -
+                        1)]
+                ])
+            else:
+                forced_samples = torch.argmax(logprobs[long_sample_indices],
+                                              dim=-1)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
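For context, the relocated FORCED branch either reads the next token directly from the request's future_context (indexed by how many tokens the sequence has already produced, clamped to the last available position) or falls back to a greedy argmax over the logprobs. A minimal, self-contained sketch of that selection logic, using a hypothetical helper name and simplified stand-ins for the vLLM metadata objects, could look like:

import torch

def pick_forced_token(future_context, output_token_ids, logprobs):
    # Hypothetical helper mirroring the FORCED branch above.
    # future_context: list of pre-decided token ids, or None.
    # output_token_ids: tokens the sequence has generated so far.
    # logprobs: 1-D tensor of log-probabilities over the vocabulary.
    if future_context is not None:
        # Pick the forced token at the current decoding position,
        # clamped to the last entry of the forced continuation.
        idx = min(len(output_token_ids), len(future_context) - 1)
        return torch.tensor([future_context[idx]])
    # No forced continuation supplied: fall back to greedy selection.
    return torch.argmax(logprobs, dim=-1, keepdim=True)

Here future_context and output_token_ids are simplified stand-ins for sampling_params.future_context[0] and seq_data[...].output_token_ids in the diff, which the actual code reaches through the seq_groups and sampling_metadata structures.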

0 commit comments
