11import torch
22import torch .nn .functional as F
33
4- TOPK = 10 # topk for sparse tree
4+ TOPK = 10 # topk for sparse tree (10 is a placeholder and it is sufficient)
55
66def pad_path (path , length , pad_value = - 2 ):
77 """
@@ -168,7 +168,7 @@ def reset_medusa_mode(
168168 - past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.
169169
170170 Returns:
171- - past_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
171+ - None
172172 """
173173 model .base_model .model .medusa_mask = None
174174 model .base_model .model .medusa_mode = None
@@ -194,8 +194,25 @@ def reset_past_key_values(passed_key_values):
194194 return passed_key_values
195195
196196def get_nucleus_one_token (logit , temperature , top_p ):
197- # input [nxC]
198- logit = logit [:, :- 1 ] / temperature
197+ """
198+ Performs token sampling based on the nucleus (top-p) sampling method.
199+
200+ This function selects a token from a given logit distribution using the nucleus sampling strategy.
201+ It allows for more controlled and diverse generation compared to traditional top-k sampling.
202+
203+ Args:
204+ logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
205+ temperature (float): A temperature parameter to control the randomness in sampling.
206+ Higher values increase diversity, lower values make selections more deterministic.
207+ top_p (float): The cumulative probability threshold for nucleus sampling.
208+ It controls the size of the set of high-probability tokens to consider for sampling.
209+
210+ Returns:
211+ torch.Tensor: A tensor containing the indices of the sampled tokens.
212+ """
213+ if top_p >= 1 :
214+ return torch .multinomial (F .softmax (logit / temperature , dim = - 1 ), 1 )
215+ logit = logit / temperature
199216 probs = torch .softmax (logit , dim = - 1 )
200217 sorted_logits , sorted_indices = torch .sort (probs , descending = True )
201218 cum_probs = torch .cumsum (sorted_logits , dim = - 1 )
@@ -208,8 +225,23 @@ def get_nucleus_one_token(logit, temperature, top_p):
208225 return sampled_tokens
209226
210227def get_typical_one_token (logit , temperature , posterior_threshold , posterior_alpha ):
211- # input [nxC]
212- logit = logit [:, :- 1 ] / temperature
228+ """
229+ Implements token sampling based on the typical sampling method.
230+
231+ This function selects a token from a given logit distribution using the typical sampling strategy,
232+ aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
233+
234+ Args:
235+ logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
236+ temperature (float): A parameter to control the randomness in sampling.
237+ Higher values increase diversity, lower values make selections more deterministic.
238+ posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
239+ posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
240+
241+ Returns:
242+ torch.Tensor: A tensor containing the indices of the sampled tokens.
243+ """
244+ logit = logit / temperature
213245 probs = torch .softmax (logit , dim = - 1 )
214246 entropy = - torch .sum (
215247 probs * torch .log (probs + 1e-5 ), dim = - 1
@@ -228,15 +260,22 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t
228260 Generate candidates based on provided logits and indices.
229261
230262 Parameters:
231- - medusa_logits (torch.Tensor): Logits associated with the Medusa structure.
232- - logits (torch.Tensor): Original logits.
233- - tree_indices (list or torch.Tensor): Indices associated with a tree structure.
234- - retrieve_indices (list or torch.Tensor): Indices for retrieving candidates.
235-
263+ - medusa_logits (torch.Tensor): Logits from a specialized Medusa structure, aiding in candidate selection.
264+ - logits (torch.Tensor): Standard logits from a language model.
265+ - tree_indices (list or torch.Tensor): Indices representing a tree structure, used for mapping candidates.
266+ - retrieve_indices (list or torch.Tensor): Indices for extracting specific candidate tokens.
267+ - temperature (float, optional): Controls the diversity of the sampling process. Defaults to 0.
268+ - posterior_threshold (float, optional): Threshold for typical sampling. Defaults to 0.3.
269+ - posterior_alpha (float, optional): Scaling factor for the entropy-based threshold in typical sampling. Defaults to 0.09.
270+ - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
271+ - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
272+ - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
273+
236274 Returns:
237- - tuple: Returns cartesian candidates and tree candidates.
275+ - tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates:
276+ 1. Cartesian candidates derived from the combined original and Medusa logits.
277+ 2. Tree candidates mapped from the Cartesian candidates using tree indices.
238278 """
239-
240279 # Greedy decoding: Select the most probable candidate from the original logits.
241280 if temperature == 0 or fast :
242281 candidates_logit = torch .argmax (logits [:, - 1 ]).unsqueeze (0 )
@@ -309,16 +348,33 @@ def tree_decoding(
309348 return medusa_logits , logits , outputs
310349
311350def get_nucleus_posterior_mask (logits , candidates , temperature , top_p ):
351+ """
352+ Generates a posterior mask for token candidates using nucleus (top-p) sampling.
353+
354+ This function applies nucleus sampling to a set of logits, and then generates a mask indicating
355+ which candidate tokens are selected. It adapts the sampling strategy to accommodate for
356+ temperature scaling and cumulative probability thresholding.
357+
358+ Args:
359+ logits (torch.Tensor): A tensor of logits from a language model output.
360+ candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
361+ temperature (float): A parameter to scale the logits, controlling randomness in sampling.
362+ top_p (float): The cumulative probability threshold for nucleus sampling.
312363
364+ Returns:
365+ torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
366+ """
313367 # adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
314368
315369 # Apply temperature
316-
317370 logits = logits [:, :- 1 ] / temperature
318-
319371 n_samples , n_tokens = logits .shape [0 ], logits .shape [1 ]
320372 logits = logits .view (n_samples * n_tokens , - 1 )
321-
373+ if top_p >= 1 :
374+ sampled_tokens = torch .multinomial (F .softmax (logits , dim = - 1 ), 1 )
375+ sampled_tokens = sampled_tokens .view (n_samples , n_tokens )
376+ posterior_mask = (candidates [:, 1 :] == sampled_tokens ).int ()
377+ return posterior_mask
322378 # Convert to probabilities (softmax)
323379 probs = F .softmax (logits , dim = - 1 )
324380 # Sort the probabilities
@@ -346,6 +402,17 @@ def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
346402 return posterior_mask
347403
348404def get_typical_posterior_mask (logits , candidates , temperature , posterior_threshold , posterior_alpha ):
405+ """
406+ Args:
407+ logits (torch.Tensor): A tensor of logits from a language model output.
408+ candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
409+ temperature (float): A parameter to scale the logits, controlling randomness in sampling.
410+ posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
411+ posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
412+
413+ Returns:
414+ torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
415+ """
349416 logits = logits [:, :- 1 ] / temperature
350417 n_samples , n_tokens = logits .shape [0 ], logits .shape [1 ]
351418 logits = logits .view (n_samples * n_tokens , - 1 )
@@ -381,7 +448,9 @@ def evaluate_posterior(
381448 - temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
382449 - posterior_threshold (float): Threshold for posterior probability.
383450 - posterior_alpha (float): Scaling factor for the threshold.
384-
451+ - top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
452+ - sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
453+ - fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
385454 Returns:
386455 - best_candidate (torch.Tensor): Index of the chosen best candidate.
387456 - accept_length (int): Length of the accepted candidate sequence.
0 commit comments