11import torch
2+ import torch .nn .functional as F
23
34TOPK = 10 # topk for sparse tree
45
@@ -192,8 +193,37 @@ def reset_past_key_values(passed_key_values):
192193 passed_key_values [i ][j ].current_length .fill_ (0 )
193194 return passed_key_values
194195
196+ def get_nucleus_one_token (logit , temperature , top_p ):
197+ # input [nxC]
198+ logit = logit [:, :- 1 ] / temperature
199+ probs = torch .softmax (logit , dim = - 1 )
200+ sorted_logits , sorted_indices = torch .sort (probs , descending = True )
201+ cum_probs = torch .cumsum (sorted_logits , dim = - 1 )
202+ sorted_indices_to_remove = cum_probs > top_p
203+ sorted_indices_to_remove [..., 1 :] = sorted_indices_to_remove [..., :- 1 ].clone ()
204+ sorted_indices_to_remove [..., 0 ] = 0
205+ indices_to_remove = sorted_indices_to_remove .scatter (dim = 1 , index = sorted_indices , src = sorted_indices_to_remove )
206+ logit [indices_to_remove ] = float ('-inf' )
207+ sampled_tokens = torch .multinomial (F .softmax (logit , dim = - 1 ), 1 )
208+ return sampled_tokens
209+
210+ def get_typical_one_token (logit , temperature , posterior_threshold , posterior_alpha ):
211+ # input [nxC]
212+ logit = logit [:, :- 1 ] / temperature
213+ probs = torch .softmax (logit , dim = - 1 )
214+ entropy = - torch .sum (
215+ probs * torch .log (probs + 1e-5 ), dim = - 1
216+ )
217+ threshold = torch .minimum (
218+ torch .ones_like (entropy ) * posterior_threshold ,
219+ torch .exp (- entropy ) * posterior_alpha ,
220+ )
221+ indices_to_remove = probs < threshold .unsqueeze (- 1 )
222+ logit [indices_to_remove ] = float ('-inf' )
223+ sampled_tokens = torch .multinomial (F .softmax (logit , dim = - 1 ), 1 )
224+ return sampled_tokens
195225
196- def generate_candidates (medusa_logits , logits , tree_indices , retrieve_indices ):
226+ def generate_candidates (medusa_logits , logits , tree_indices , retrieve_indices , temperature = 0 , posterior_threshold = 0.3 , posterior_alpha = 0.09 , top_p = 0.8 , sampling = 'typical' , fast = False ):
197227 """
198228 Generate candidates based on provided logits and indices.
199229
@@ -208,8 +238,15 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices):
208238 """
209239
210240 # Greedy decoding: Select the most probable candidate from the original logits.
211- candidates_logit = torch .argmax (logits [:, - 1 ]).unsqueeze (0 )
212-
241+ if temperature == 0 or fast :
242+ candidates_logit = torch .argmax (logits [:, - 1 ]).unsqueeze (0 )
243+ else :
244+ if sampling == 'typical' :
245+ candidates_logit = get_typical_one_token (logits [:, - 1 ], temperature , posterior_threshold , posterior_alpha ).squeeze (0 )
246+ elif sampling == 'nucleus' :
247+ candidates_logit = get_nucleus_one_token (logits [:, - 1 ], temperature , top_p ).squeeze (0 )
248+ else :
249+ raise NotImplementedError
213250 # Extract the TOPK candidates from the medusa logits.
214251 candidates_medusa_logits = torch .topk (medusa_logits [:, 0 , - 1 ], TOPK , dim = - 1 ).indices
215252
@@ -271,9 +308,66 @@ def tree_decoding(
271308 medusa_logits = tree_medusa_logits [:, 0 , retrieve_indices ]
272309 return medusa_logits , logits , outputs
273310
311+ def get_nucleus_posterior_mask (logits , candidates , temperature , top_p ):
312+
313+ # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
314+
315+ # Apply temperature
316+
317+ logits = logits [:, :- 1 ] / temperature
318+
319+ n_samples , n_tokens = logits .shape [0 ], logits .shape [1 ]
320+ logits = logits .view (n_samples * n_tokens , - 1 )
321+
322+ # Convert to probabilities (softmax)
323+ probs = F .softmax (logits , dim = - 1 )
324+ # Sort the probabilities
325+ sorted_logits , sorted_indices = torch .sort (probs , descending = True )
326+
327+ # Compute cumulative probabilities
328+ cum_probs = torch .cumsum (sorted_logits , dim = - 1 )
329+
330+ # Create mask for the top-p nucleus
331+ sorted_indices_to_remove = cum_probs > top_p
332+ sorted_indices_to_remove [..., 1 :] = sorted_indices_to_remove [..., :- 1 ].clone ()
333+ sorted_indices_to_remove [..., 0 ] = 0
334+
335+ indices_to_remove = sorted_indices_to_remove .scatter (dim = 1 , index = sorted_indices , src = sorted_indices_to_remove )
336+
337+
338+ # Remove low-probability tokens
339+ logits [indices_to_remove ] = float ('-inf' )
340+ # Sample from the remaining tokens
341+ sampled_tokens = torch .multinomial (F .softmax (logits , dim = - 1 ), 1 )
342+ sampled_tokens = sampled_tokens .view (n_samples , n_tokens )
343+ # Create a mask for selected tokens
344+ posterior_mask = (candidates [:, 1 :] == sampled_tokens ).int ()
345+
346+ return posterior_mask
347+
348+ def get_typical_posterior_mask (logits , candidates , temperature , posterior_threshold , posterior_alpha ):
349+ logits = logits [:, :- 1 ] / temperature
350+ n_samples , n_tokens = logits .shape [0 ], logits .shape [1 ]
351+ logits = logits .view (n_samples * n_tokens , - 1 )
352+ probs = F .softmax (logits , dim = - 1 )
353+ entropy = - torch .sum (
354+ probs * torch .log (probs + 1e-5 ), dim = - 1
355+ )
356+ threshold = torch .minimum (
357+ torch .ones_like (entropy ) * posterior_threshold ,
358+ torch .exp (- entropy ) * posterior_alpha ,
359+ )
360+ indices_to_remove = probs < threshold .unsqueeze (- 1 )
361+ logits [indices_to_remove ] = float ('-inf' )
362+ sampled_tokens = torch .multinomial (F .softmax (logits , dim = - 1 ), 1 )
363+ sampled_tokens = sampled_tokens .view (n_samples , n_tokens )
364+ posterior_mask = (candidates [:, 1 :] == sampled_tokens ).int ()
365+ return posterior_mask
366+
367+
274368
275369def evaluate_posterior (
276- logits , candidates , temperature , posterior_threshold , posterior_alpha
370+ logits , candidates , temperature , posterior_threshold = 0.3 , posterior_alpha = 0.09 , top_p = 0.8 , sampling = 'typical' , fast = True
277371):
278372 """
279373 Evaluate the posterior probabilities of the candidates based on the provided logits and choose the best candidate.
@@ -307,36 +401,64 @@ def evaluate_posterior(
307401 else :
308402 best_candidate = torch .argmax (candidates_accept_length ).to (torch .long )
309403 return best_candidate , accept_length
310- # Calculate posterior probabilities and thresholds for candidate selection
311- posterior_prob = torch .softmax (logits [:, :- 1 ] / temperature , dim = - 1 )
312- candidates_prob = torch .gather (
313- posterior_prob , dim = - 1 , index = candidates [:, 1 :].unsqueeze (- 1 )
314- ).squeeze (- 1 )
315- posterior_entropy = - torch .sum (
316- posterior_prob * torch .log (posterior_prob + 1e-5 ), dim = - 1
317- ) # torch.sum(torch.log(*)) is faster than torch.prod
318- threshold = torch .minimum (
319- torch .ones_like (posterior_entropy ) * posterior_threshold ,
320- torch .exp (- posterior_entropy ) * posterior_alpha ,
321- )
322- posterior_mask = candidates_prob > threshold
323- candidates_accept_length = (torch .cumprod (posterior_mask , dim = 1 )).sum (dim = 1 )
324-
325- # Choose the best candidate based on the evaluated posterior probabilities
326- accept_length = candidates_accept_length .max ()
327- if accept_length == 0 :
328- # If no candidates are accepted, just choose the first one
329- best_candidate = torch .tensor (0 , dtype = torch .long , device = candidates .device )
404+
405+ if sampling == 'typical' :
406+ if fast :
407+ posterior_prob = torch .softmax (logits [:, :- 1 ] / temperature , dim = - 1 )
408+ candidates_prob = torch .gather (
409+ posterior_prob , dim = - 1 , index = candidates [:, 1 :].unsqueeze (- 1 )
410+ ).squeeze (- 1 )
411+ posterior_entropy = - torch .sum (
412+ posterior_prob * torch .log (posterior_prob + 1e-5 ), dim = - 1
413+ ) # torch.sum(torch.log(*)) is faster than torch.prod
414+ threshold = torch .minimum (
415+ torch .ones_like (posterior_entropy ) * posterior_threshold ,
416+ torch .exp (- posterior_entropy ) * posterior_alpha ,
417+ )
418+ posterior_mask = candidates_prob > threshold
419+ candidates_accept_length = (torch .cumprod (posterior_mask , dim = 1 )).sum (dim = 1 )
420+
421+ # Choose the best candidate based on the evaluated posterior probabilities
422+ accept_length = candidates_accept_length .max ()
423+ if accept_length == 0 :
424+ # If no candidates are accepted, just choose the first one
425+ best_candidate = torch .tensor (0 , dtype = torch .long , device = candidates .device )
426+ else :
427+ best_candidates = torch .where (candidates_accept_length == accept_length )[0 ]
428+ # Accept the best one according to likelihood
429+ likelihood = torch .sum (
430+ torch .log (candidates_prob [best_candidates , :accept_length ]), dim = - 1
431+ )
432+ best_candidate = best_candidates [torch .argmax (likelihood )]
433+ return best_candidate , accept_length
434+ # Calculate posterior probabilities and thresholds for candidate selection
435+ posterior_mask = get_typical_posterior_mask (logits , candidates , temperature , posterior_threshold , posterior_alpha , fast )
436+ candidates_accept_length = (torch .cumprod (posterior_mask , dim = 1 )).sum (dim = 1 )
437+ # Choose the best candidate based on the evaluated posterior probabilities
438+ accept_length = candidates_accept_length .max ()
439+
440+ if accept_length == 0 :
441+ # If no candidates are accepted, just choose the first one
442+ best_candidate = torch .tensor (0 , dtype = torch .long , device = candidates .device )
443+ else :
444+ best_candidate = torch .argmax (candidates_accept_length ).to (torch .long )
445+ # Accept the best one according to likelihood
446+ return best_candidate , accept_length
447+
448+ if sampling == 'nucleus' :
449+ assert top_p < 1.0 + 1e-6 , "top_p should between 0 and 1"
450+ posterior_mask = get_nucleus_posterior_mask (logits , candidates , temperature , top_p )
451+ candidates_accept_length = (torch .cumprod (posterior_mask , dim = 1 )).sum (dim = 1 )
452+ accept_length = candidates_accept_length .max ()
453+ # Choose the best candidate
454+ if accept_length == 0 :
455+ # Default to the first candidate if none are accepted
456+ best_candidate = torch .tensor (0 , dtype = torch .long , device = candidates .device )
457+ else :
458+ best_candidate = torch .argmax (candidates_accept_length ).to (torch .long )
459+ return best_candidate , accept_length
330460 else :
331- best_candidates = torch .where (candidates_accept_length == accept_length )[0 ]
332- # Accept the best one according to likelihood
333- likelihood = torch .sum (
334- torch .log (candidates_prob [best_candidates , :accept_length ]), dim = - 1
335- )
336- best_candidate = best_candidates [torch .argmax (likelihood )]
337- return best_candidate , accept_length
338-
339-
461+ raise NotImplementedError
340462def update_inference_inputs (
341463 input_ids ,
342464 candidates ,
0 commit comments