Commit 0ac14da

add extra sampling strategies
1 parent 1facb55 commit 0ac14da

File tree

2 files changed: +557, -33 lines


medusa/model/utils.py

Lines changed: 155 additions & 33 deletions
@@ -1,4 +1,5 @@
 import torch
+import torch.nn.functional as F
 
 TOPK=10 # topk for sparse tree
 
@@ -192,8 +193,37 @@ def reset_past_key_values(passed_key_values):
             passed_key_values[i][j].current_length.fill_(0)
     return passed_key_values
 
+def get_nucleus_one_token(logit, temperature, top_p):
+    # input [nxC]
+    logit = logit[:, :-1] / temperature
+    probs = torch.softmax(logit, dim=-1)
+    sorted_logits, sorted_indices = torch.sort(probs, descending=True)
+    cum_probs = torch.cumsum(sorted_logits, dim=-1)
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = 0
+    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+    logit[indices_to_remove] = float('-inf')
+    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
+    return sampled_tokens
+
+def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
+    # input [nxC]
+    logit = logit[:, :-1] / temperature
+    probs = torch.softmax(logit, dim=-1)
+    entropy = -torch.sum(
+        probs * torch.log(probs + 1e-5), dim=-1
+    )
+    threshold = torch.minimum(
+        torch.ones_like(entropy) * posterior_threshold,
+        torch.exp(-entropy) * posterior_alpha,
+    )
+    indices_to_remove = probs < threshold.unsqueeze(-1)
+    logit[indices_to_remove] = float('-inf')
+    sampled_tokens = torch.multinomial(F.softmax(logit, dim=-1), 1)
+    return sampled_tokens
 
-def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices):
+def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, temperature=0, posterior_threshold=0.3, posterior_alpha=0.09, top_p=0.8, sampling='typical', fast=False):
     """
     Generate candidates based on provided logits and indices.
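For intuition, a minimal, hypothetical sketch of calling the two new one-token samplers on toy logits (shapes and values are illustrative; the helpers above must be in scope, and real logits come from the base model's final position):

import torch

# Toy logits: 2 rows over a 6-token vocabulary. Note that, as written,
# both helpers drop the last vocab column via logit[:, :-1].
logits = torch.tensor([[2.0, 1.5, 0.5, 0.1, -1.0, -2.0],
                       [0.3, 0.2, 0.1, 0.0, -0.1, -0.2]])

# Nucleus: keep the smallest prefix of probability-sorted tokens whose
# cumulative mass exceeds top_p, renormalize, sample one token per row.
tok_nucleus = get_nucleus_one_token(logits, temperature=0.7, top_p=0.8)

# Typical: mask tokens whose probability falls below
# min(posterior_threshold, posterior_alpha * exp(-entropy)).
tok_typical = get_typical_one_token(logits, temperature=0.7,
                                    posterior_threshold=0.3,
                                    posterior_alpha=0.09)

print(tok_nucleus.shape, tok_typical.shape)  # torch.Size([2, 1]) for both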
@@ -208,8 +238,15 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices):
     """
 
     # Greedy decoding: Select the most probable candidate from the original logits.
-    candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
-
+    if temperature == 0 or fast:
+        candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
+    else:
+        if sampling == 'typical':
+            candidates_logit = get_typical_one_token(logits[:, -1], temperature, posterior_threshold, posterior_alpha).squeeze(0)
+        elif sampling == 'nucleus':
+            candidates_logit = get_nucleus_one_token(logits[:, -1], temperature, top_p).squeeze(0)
+        else:
+            raise NotImplementedError
     # Extract the TOPK candidates from the medusa logits.
     candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim = -1).indices
 
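A hypothetical call site showing how the new keyword arguments select a strategy (the tensor arguments and the two-value return are assumed to come from the existing Medusa decoding loop):

# Greedy (original behavior): temperature == 0, or fast=True, takes the argmax path.
candidates, tree_candidates = generate_candidates(
    medusa_logits, logits, tree_indices, retrieve_indices, temperature=0)

# Typical sampling for the first token of each candidate continuation:
candidates, tree_candidates = generate_candidates(
    medusa_logits, logits, tree_indices, retrieve_indices,
    temperature=0.7, sampling='typical',
    posterior_threshold=0.3, posterior_alpha=0.09)

# Nucleus (top-p) sampling instead:
candidates, tree_candidates = generate_candidates(
    medusa_logits, logits, tree_indices, retrieve_indices,
    temperature=0.7, sampling='nucleus', top_p=0.8)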
@@ -271,9 +308,66 @@ def tree_decoding(
     medusa_logits = tree_medusa_logits[:, 0, retrieve_indices]
     return medusa_logits, logits, outputs
 
+def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
+    # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
+
+    # Apply temperature
+    logits = logits[:, :-1] / temperature
+
+    n_samples, n_tokens = logits.shape[0], logits.shape[1]
+    logits = logits.view(n_samples*n_tokens, -1)
+
+    # Convert to probabilities (softmax)
+    probs = F.softmax(logits, dim=-1)
+    # Sort the probabilities
+    sorted_logits, sorted_indices = torch.sort(probs, descending=True)
+
+    # Compute cumulative probabilities
+    cum_probs = torch.cumsum(sorted_logits, dim=-1)
+
+    # Create mask for the top-p nucleus
+    sorted_indices_to_remove = cum_probs > top_p
+    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+    sorted_indices_to_remove[..., 0] = 0
+
+    indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
+
+    # Remove low-probability tokens
+    logits[indices_to_remove] = float('-inf')
+    # Sample from the remaining tokens
+    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
+    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
+    # Create a mask for selected tokens
+    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
+
+    return posterior_mask
+
+def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
+    logits = logits[:, :-1] / temperature
+    n_samples, n_tokens = logits.shape[0], logits.shape[1]
+    logits = logits.view(n_samples*n_tokens, -1)
+    probs = F.softmax(logits, dim=-1)
+    entropy = -torch.sum(
+        probs * torch.log(probs + 1e-5), dim=-1
+    )
+    threshold = torch.minimum(
+        torch.ones_like(entropy) * posterior_threshold,
+        torch.exp(-entropy) * posterior_alpha,
+    )
+    indices_to_remove = probs < threshold.unsqueeze(-1)
+    logits[indices_to_remove] = float('-inf')
+    sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
+    sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
+    posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
+    return posterior_mask
+
 
 def evaluate_posterior(
-    logits, candidates, temperature, posterior_threshold, posterior_alpha
+    logits, candidates, temperature, posterior_threshold=0.3, posterior_alpha=0.09, top_p=0.8, sampling='typical', fast=True
 ):
     """
     Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
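Both posterior-mask helpers feed the same acceptance rule used below: torch.cumprod zeroes the mask after the first rejected token, so the row sum is the length of the longest accepted prefix. A tiny self-contained illustration:

import torch

posterior_mask = torch.tensor([[1, 1, 0, 1],
                               [1, 1, 1, 0],
                               [0, 1, 1, 1]])
# cumprod rows: [1,1,0,0], [1,1,1,0], [0,0,0,0]
candidates_accept_length = torch.cumprod(posterior_mask, dim=1).sum(dim=1)
print(candidates_accept_length)  # tensor([2, 3, 0])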
@@ -307,36 +401,64 @@ def evaluate_posterior(
         else:
             best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
         return best_candidate, accept_length
-    # Calculate posterior probabilities and thresholds for candidate selection
-    posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
-    candidates_prob = torch.gather(
-        posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
-    ).squeeze(-1)
-    posterior_entropy = -torch.sum(
-        posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
-    ) # torch.sum(torch.log(*)) is faster than torch.prod
-    threshold = torch.minimum(
-        torch.ones_like(posterior_entropy) * posterior_threshold,
-        torch.exp(-posterior_entropy) * posterior_alpha,
-    )
-    posterior_mask = candidates_prob > threshold
-    candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
-
-    # Choose the best candidate based on the evaluated posterior probabilities
-    accept_length = candidates_accept_length.max()
-    if accept_length == 0:
-        # If no candidates are accepted, just choose the first one
-        best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
+
+    if sampling == 'typical':
+        if fast:
+            posterior_prob = torch.softmax(logits[:, :-1] / temperature, dim=-1)
+            candidates_prob = torch.gather(
+                posterior_prob, dim=-1, index=candidates[:, 1:].unsqueeze(-1)
+            ).squeeze(-1)
+            posterior_entropy = -torch.sum(
+                posterior_prob * torch.log(posterior_prob + 1e-5), dim=-1
+            ) # torch.sum(torch.log(*)) is faster than torch.prod
+            threshold = torch.minimum(
+                torch.ones_like(posterior_entropy) * posterior_threshold,
+                torch.exp(-posterior_entropy) * posterior_alpha,
+            )
+            posterior_mask = candidates_prob > threshold
+            candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
+
+            # Choose the best candidate based on the evaluated posterior probabilities
+            accept_length = candidates_accept_length.max()
+            if accept_length == 0:
+                # If no candidates are accepted, just choose the first one
+                best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
+            else:
+                best_candidates = torch.where(candidates_accept_length == accept_length)[0]
+                # Accept the best one according to likelihood
+                likelihood = torch.sum(
+                    torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
+                )
+                best_candidate = best_candidates[torch.argmax(likelihood)]
+            return best_candidate, accept_length
+        # Calculate posterior probabilities and thresholds for candidate selection
+        posterior_mask = get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha)
+        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
+        # Choose the best candidate based on the evaluated posterior probabilities
+        accept_length = candidates_accept_length.max()
+
+        if accept_length == 0:
+            # If no candidates are accepted, just choose the first one
+            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
+        else:
+            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
+        return best_candidate, accept_length
+
+    if sampling == 'nucleus':
+        assert top_p < 1.0 + 1e-6, "top_p should be between 0 and 1"
+        posterior_mask = get_nucleus_posterior_mask(logits, candidates, temperature, top_p)
+        candidates_accept_length = (torch.cumprod(posterior_mask, dim=1)).sum(dim=1)
+        accept_length = candidates_accept_length.max()
+        # Choose the best candidate
+        if accept_length == 0:
+            # Default to the first candidate if none are accepted
+            best_candidate = torch.tensor(0, dtype=torch.long, device=candidates.device)
+        else:
+            best_candidate = torch.argmax(candidates_accept_length).to(torch.long)
+        return best_candidate, accept_length
     else:
-        best_candidates = torch.where(candidates_accept_length == accept_length)[0]
-        # Accept the best one according to likelihood
-        likelihood = torch.sum(
-            torch.log(candidates_prob[best_candidates, :accept_length]), dim=-1
-        )
-        best_candidate = best_candidates[torch.argmax(likelihood)]
-    return best_candidate, accept_length
-
-
+        raise NotImplementedError
 def update_inference_inputs(
     input_ids,
     candidates,