
Commit c5b34f1

Skip unnecessary sampling and fix the random offset (#4068)
* optimize multinomial sampling kernel
* remove
* add comments
* optimize
* remove sync
* recovery
* remove print
* fix
* optimize output pipeline
* skip unnecessary sampling
* add rand offsets
* add more comment
1 parent a9a24fb

File tree

4 files changed: +29 −3 lines changed

lmdeploy/pytorch/engine/logits_process.py

Lines changed: 13 additions & 2 deletions

@@ -126,6 +126,17 @@ def _apply_custom_logits_processors(batched_logits_processors, all_ids, logits):
     return logits
 
 
+def _torch_topk(x: torch.Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True):
+    if k == 1:
+        # torch.topk would not fallback to torch.max/torch.min automatically
+        if largest:
+            return torch.max(x, dim=dim, keepdim=True)
+        else:
+            return torch.min(x, dim=dim, keepdim=True)
+    else:
+        return torch.topk(x, k, dim=dim, largest=largest, sorted=sorted)
+
+
 class FusedLogitsProcessor:
     """Custom logits processor."""
 
@@ -266,7 +277,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
         if max_topk <= 0:
             scores, indices = logits.sort(1, descending=True)
         else:
-            scores, indices = logits.topk(max_topk, dim=1)
+            scores, indices = _torch_topk(logits, max_topk, dim=1)
         result = __random_sampling(scores, indices)
 
         if self.guided_decoding_manager and self.guided_processors:
@@ -285,7 +296,7 @@ def compute_logprobs(self, raw_logprobs: torch.Tensor, token_ids: torch.LongTens
         logprobs = raw_logprobs.gather(-1, indices)
         num_logprobs = self.sampling_inputs.max_num_logprobs
         if num_logprobs > 0:
-            topk_logprobs, topk_indices = raw_logprobs.topk(num_logprobs, dim=-1)
+            topk_logprobs, topk_indices = _torch_topk(raw_logprobs, num_logprobs, dim=-1)
            logprobs = torch.cat([logprobs, topk_logprobs], dim=-1)
            indices = torch.cat([indices, topk_indices], dim=-1)

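The `_torch_topk` helper exists because, as its inline comment notes, `torch.topk` does not automatically fall back to the cheaper single-pass `torch.max`/`torch.min` reductions when k == 1. Since `torch.max(..., keepdim=True)` also returns a `(values, indices)` pair with the same shape as a k=1 top-k, the fast path is a drop-in replacement. A minimal sanity check (not part of the commit) illustrating the equivalence:

import torch

logits = torch.randn(4, 32000)  # ties are practically impossible with random floats

topk_vals, topk_idx = torch.topk(logits, k=1, dim=-1)        # general top-k kernel
max_vals, max_idx = torch.max(logits, dim=-1, keepdim=True)  # single-pass reduction

assert torch.equal(topk_vals, max_vals)
assert torch.equal(topk_idx, max_idx)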
lmdeploy/pytorch/strategies/ar/model_agent.py

Lines changed: 4 additions & 0 deletions

@@ -70,6 +70,10 @@ def slice_extra_inputs(self, extra_inputs: ARExtraInputs, seq_length: torch.Long
     def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids: torch.Tensor):
         """step."""
         sampling_inputs.num_ignore_eos = sampling_inputs.num_ignore_eos - 1
+        if sampling_inputs.random_offsets is not None:
+            # random offset is used to generate random numbers for multinomial sampling
+            # so we need to increase it by 1 at each step
+            sampling_inputs.random_offsets += 1
 
         all_ids = sampling_inputs.all_ids
         if all_ids is not None:

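`random_offsets` gives each sequence a position in its random stream; advancing it by one per decoding step keeps multinomial sampling reproducible under a fixed seed while still drawing fresh numbers each step. If the offset never advanced, every step could replay the same draw, which is plausibly the bug the "fix the random offset" part of this commit addresses. A rough sketch of the seed-plus-offset idea, with a hypothetical helper rather than the engine's actual sampling kernel:

import torch

def sample_step(probs: torch.Tensor, seed: int, offset: int) -> torch.Tensor:
    # Hypothetical sketch: deriving the generator state from (seed, offset)
    # makes each step's draw deterministic but distinct, which is why the
    # engine bumps random_offsets by 1 after every step.
    gen = torch.Generator(device=probs.device)
    gen.manual_seed(seed + offset)
    return torch.multinomial(probs, num_samples=1, generator=gen)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
step0 = sample_step(probs, seed=42, offset=0)
step1 = sample_step(probs, seed=42, offset=1)  # same seed, new offset -> new draw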
lmdeploy/pytorch/strategies/ar/sampling.py

Lines changed: 8 additions & 1 deletion

@@ -65,7 +65,7 @@ def __gather_params():
             param = seq.sampling_param
             temperature[idx] = param.temperature
             repetition_penalty[idx] = param.repetition_penalty
-            top_k[idx] = param.top_k
+            top_k[idx] = max(0, param.top_k)
             top_p[idx] = param.top_p
             min_p[idx] = param.min_p
             random_offsets[idx] = seq.num_valid_ids
@@ -129,6 +129,9 @@ def __get_bad_words(bad_words):
         repetition_penalty = torch.tensor(repetition_penalty)
 
         temperature = torch.tensor(temperature)
+        if (temperature == 1.0).all():
+            # skip temperature processing if all temperature are 1.0
+            temperature = None
 
         bad_words, bad_mask = __get_bad_words(bad_words)
         stop_words, stop_mask = __get_bad_words(stop_words)
@@ -144,6 +147,10 @@ def __get_bad_words(bad_words):
             random_offsets = None
         else:
             top_k = torch.tensor(top_k)
+            if (top_k == max_top_k).all():
+                # we would perform max_top_k before top_k
+                # if all top_k are same, we do not need to filter topk again
+                top_k = None
             top_p, min_top_p = __get_topp(top_p)
             min_p = __get_minp(min_p)
             random_seeds = torch.tensor(random_seeds)

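Both new branches replace a tensor with `None` as a sentinel meaning "this stage is a no-op": dividing logits by a temperature of 1.0 changes nothing, and re-filtering by per-sequence `top_k` is redundant once every sequence shares the `max_top_k` cutoff that is applied first. A downstream consumer can then skip the kernel entirely. An illustrative sketch of the skip pattern, with hypothetical names rather than the actual consumer code in logits_process.py:

from typing import Optional

import torch

def apply_temperature(logits: torch.Tensor, temperature: Optional[torch.Tensor]) -> torch.Tensor:
    # None signals "all temperatures were 1.0": the division would be an
    # identity, so we skip the elementwise kernel entirely.
    if temperature is None:
        return logits
    return logits / temperature[:, None]

logits = torch.randn(4, 128)
assert torch.equal(apply_temperature(logits, None), logits)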
lmdeploy/pytorch/strategies/dllm/model_agent.py

Lines changed: 4 additions & 0 deletions

@@ -169,6 +169,10 @@ def _step_sampling_inputs(self, sampling_inputs: SamplingInputs, next_token_ids:
         num_ignore_eos = sampling_inputs.num_ignore_eos.view(-1, dllm_block_size)
         num_ignore_eos = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
         sampling_inputs.num_ignore_eos = num_ignore_eos.flatten()
+        if sampling_inputs.random_offsets is not None:
+            # random offset is used to generate random numbers for multinomial sampling
+            # so we need to increase it by 1 at each step
+            sampling_inputs.random_offsets += 1
         return sampling_inputs
 
     def make_stopping_criteria(self, seqs: SeqList) -> DLLMStoppingCriteria:

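In the dllm path a step operates on whole blocks of `dllm_block_size` tokens, so `num_ignore_eos` is reshaped per block and decremented only for blocks unmasked this step; the random-offset bump is the same fix as in the autoregressive agent. A toy illustration of the `torch.where` masked-decrement pattern above (made-up values, assuming a per-block flag that broadcasts across each row):

import torch

dllm_block_size = 4
num_ignore_eos = torch.tensor([8, 8, 8, 8, 3, 3, 3, 3]).view(-1, dllm_block_size)
is_unmasked = torch.tensor([[True], [False]])  # assumed shape: one flag per block

# Unmasked blocks consumed dllm_block_size tokens this step.
stepped = torch.where(is_unmasked, num_ignore_eos - dllm_block_size, num_ignore_eos)
print(stepped.flatten())  # tensor([4, 4, 4, 4, 3, 3, 3, 3])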