Skip to content

Commit 2b89814

Browse files
committed
update cli
1 parent ac75468 commit 2b89814

File tree

3 files changed

+100
-22
lines changed

3 files changed

+100
-22
lines changed

medusa/inference/cli.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ def main(args):
3636
try:
3737
model = MedusaModel.from_pretrained(
3838
args.model,
39-
args.base_model,
40-
medusa_num_heads = 4,
4139
torch_dtype=torch.float16,
4240
low_cpu_mem_usage=True,
4341
device_map="auto",
@@ -48,7 +46,7 @@ def main(args):
4846
conv = None
4947

5048
def new_chat():
51-
return get_conversation_template("vicuna")
49+
return get_conversation_template(args.model)
5250

5351
def reload_conv(conv):
5452
"""
@@ -187,7 +185,6 @@ def reload_conv(conv):
187185
if __name__ == "__main__":
188186
parser = argparse.ArgumentParser()
189187
parser.add_argument("--model", type=str, required=True, help="Model name or path.")
190-
parser.add_argument("--base-model", type=str, default=None, help="Base model name or path.")
191188
parser.add_argument(
192189
"--load-in-8bit", action="store_true", help="Use 8-bit quantization"
193190
)

medusa/model/medusa_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def medusa_generate(
188188
posterior_threshold=0.09, # threshold validation of Medusa output
189189
# another threshold hyperparameter, recommended to be sqrt(posterior_threshold)
190190
posterior_alpha=0.3,
191+
top_p=0.8,
192+
sampling = 'typical',
193+
fast = True
191194
):
192195
"""
193196
Args:
@@ -197,6 +200,9 @@ def medusa_generate(
197200
medusa_choices (list, optional): A list of integers indicating the number of choices for each Medusa head.
198201
posterior_threshold (float, optional): Threshold for posterior validation.
199202
posterior_alpha (float, optional): Another threshold hyperparameter, recommended to be sqrt(posterior_threshold).
203+
top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
204+
sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
205+
fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
200206
Returns:
201207
torch.Tensor: Output token IDs.
202208
@@ -253,6 +259,12 @@ def medusa_generate(
253259
logits,
254260
medusa_buffers["tree_indices"],
255261
medusa_buffers["retrieve_indices"],
262+
temperature=temperature,
263+
posterior_alpha=posterior_alpha,
264+
posterior_threshold=posterior_threshold,
265+
top_p=top_p,
266+
sampling=sampling,
267+
fast=fast,
256268
)
257269

258270
# Use tree attention to verify the candidates and get predictions
@@ -267,7 +279,7 @@ def medusa_generate(
267279

268280
# Evaluate the posterior of the candidates to select the accepted candidate prefix
269281
best_candidate, accept_length = evaluate_posterior(
270-
logits, candidates, temperature, posterior_threshold, posterior_alpha
282+
logits, candidates, temperature, posterior_threshold, posterior_alpha, top_p=top_p, sampling=sampling, fast=fast
271283
)
272284

273285
# Update the input_ids and logits

medusa/model/utils.py

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4-
TOPK=10 # topk for sparse tree
4+
TOPK=10 # topk for sparse tree (10 is a placeholder and it is sufficient)
55

66
def pad_path(path, length, pad_value=-2):
77
"""
@@ -168,7 +168,7 @@ def reset_medusa_mode(
168168
- past_key_values (list of torch.Tensor): Contains past hidden states and past attention values.
169169
170170
Returns:
171-
- past_key_values (list of torch.Tensor): Updated past hidden states and past attention values with reset lengths.
171+
- None
172172
"""
173173
model.base_model.model.medusa_mask = None
174174
model.base_model.model.medusa_mode = None
@@ -194,8 +194,25 @@ def reset_past_key_values(passed_key_values):
194194
return passed_key_values
195195

196196
def get_nucleus_one_token(logit, temperature, top_p):
197-
# input [nxC]
198-
logit = logit[:, :-1] / temperature
197+
"""
198+
Performs token sampling based on the nucleus (top-p) sampling method.
199+
200+
This function selects a token from a given logit distribution using the nucleus sampling strategy.
201+
It allows for more controlled and diverse generation compared to traditional top-k sampling.
202+
203+
Args:
204+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor (BxC).
205+
temperature (float): A temperature parameter to control the randomness in sampling.
206+
Higher values increase diversity, lower values make selections more deterministic.
207+
top_p (float): The cumulative probability threshold for nucleus sampling.
208+
It controls the size of the set of high-probability tokens to consider for sampling.
209+
210+
Returns:
211+
torch.Tensor: A tensor containing the indices of the sampled tokens.
212+
"""
213+
if top_p >= 1:
214+
return torch.multinomial(F.softmax(logit / temperature, dim=-1), 1)
215+
logit = logit / temperature
199216
probs = torch.softmax(logit, dim=-1)
200217
sorted_logits, sorted_indices = torch.sort(probs, descending=True)
201218
cum_probs = torch.cumsum(sorted_logits, dim=-1)
@@ -208,8 +225,23 @@ def get_nucleus_one_token(logit, temperature, top_p):
208225
return sampled_tokens
209226

210227
def get_typical_one_token(logit, temperature, posterior_threshold, posterior_alpha):
211-
# input [nxC]
212-
logit = logit[:, :-1] / temperature
228+
"""
229+
Implements token sampling based on the typical sampling method.
230+
231+
This function selects a token from a given logit distribution using the typical sampling strategy,
232+
aiming to balance between diversity and likelihood in a more nuanced way compared to traditional methods.
233+
234+
Args:
235+
logit (torch.Tensor): The logits from a language model output, expected to be a 2D tensor.
236+
temperature (float): A parameter to control the randomness in sampling.
237+
Higher values increase diversity, lower values make selections more deterministic.
238+
posterior_threshold (float): A threshold to decide the lower bound of probabilities to be considered for sampling.
239+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
240+
241+
Returns:
242+
torch.Tensor: A tensor containing the indices of the sampled tokens.
243+
"""
244+
logit = logit / temperature
213245
probs = torch.softmax(logit, dim=-1)
214246
entropy = -torch.sum(
215247
probs * torch.log(probs + 1e-5), dim=-1
@@ -228,15 +260,22 @@ def generate_candidates(medusa_logits, logits, tree_indices, retrieve_indices, t
228260
Generate candidates based on provided logits and indices.
229261
230262
Parameters:
231-
- medusa_logits (torch.Tensor): Logits associated with the Medusa structure.
232-
- logits (torch.Tensor): Original logits.
233-
- tree_indices (list or torch.Tensor): Indices associated with a tree structure.
234-
- retrieve_indices (list or torch.Tensor): Indices for retrieving candidates.
235-
263+
- medusa_logits (torch.Tensor): Logits from a specialized Medusa structure, aiding in candidate selection.
264+
- logits (torch.Tensor): Standard logits from a language model.
265+
- tree_indices (list or torch.Tensor): Indices representing a tree structure, used for mapping candidates.
266+
- retrieve_indices (list or torch.Tensor): Indices for extracting specific candidate tokens.
267+
- temperature (float, optional): Controls the diversity of the sampling process. Defaults to 0.
268+
- posterior_threshold (float, optional): Threshold for typical sampling. Defaults to 0.3.
269+
- posterior_alpha (float, optional): Scaling factor for the entropy-based threshold in typical sampling. Defaults to 0.09.
270+
- top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
271+
- sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
272+
- fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
273+
236274
Returns:
237-
- tuple: Returns cartesian candidates and tree candidates.
275+
- tuple (torch.Tensor, torch.Tensor): A tuple containing two sets of candidates:
276+
1. Cartesian candidates derived from the combined original and Medusa logits.
277+
2. Tree candidates mapped from the Cartesian candidates using tree indices.
238278
"""
239-
240279
# Greedy decoding: Select the most probable candidate from the original logits.
241280
if temperature == 0 or fast:
242281
candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
@@ -309,16 +348,33 @@ def tree_decoding(
309348
return medusa_logits, logits, outputs
310349

311350
def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
351+
"""
352+
Generates a posterior mask for token candidates using nucleus (top-p) sampling.
353+
354+
This function applies nucleus sampling to a set of logits, and then generates a mask indicating
355+
which candidate tokens are selected. It adapts the sampling strategy to accommodate for
356+
temperature scaling and cumulative probability thresholding.
357+
358+
Args:
359+
logits (torch.Tensor): A tensor of logits from a language model output.
360+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
361+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
362+
top_p (float): The cumulative probability threshold for nucleus sampling.
312363
364+
Returns:
365+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
366+
"""
313367
# adapted from https://github.com/huggingface/transformers/blob/18a879f47576822aa1a5c49aecb27d89bfa5fa69/examples/run_generation.py#L79
314368

315369
# Apply temperature
316-
317370
logits = logits[:, :-1] / temperature
318-
319371
n_samples, n_tokens = logits.shape[0], logits.shape[1]
320372
logits = logits.view(n_samples*n_tokens, -1)
321-
373+
if top_p >= 1:
374+
sampled_tokens = torch.multinomial(F.softmax(logits, dim=-1), 1)
375+
sampled_tokens = sampled_tokens.view(n_samples, n_tokens)
376+
posterior_mask = (candidates[:, 1:] == sampled_tokens).int()
377+
return posterior_mask
322378
# Convert to probabilities (softmax)
323379
probs = F.softmax(logits, dim=-1)
324380
# Sort the probabilities
@@ -346,6 +402,17 @@ def get_nucleus_posterior_mask(logits, candidates, temperature, top_p):
346402
return posterior_mask
347403

348404
def get_typical_posterior_mask(logits, candidates, temperature, posterior_threshold, posterior_alpha):
405+
"""
406+
Args:
407+
logits (torch.Tensor): A tensor of logits from a language model output.
408+
candidates (torch.Tensor): A tensor of candidate tokens to compare against sampled tokens.
409+
temperature (float): A parameter to scale the logits, controlling randomness in sampling.
410+
posterior_threshold (float): The minimum threshold for probabilities to be considered in sampling.
411+
posterior_alpha (float): A scaling factor applied to the entropy-based adaptive threshold.
412+
413+
Returns:
414+
torch.Tensor: A posterior mask indicating which candidate tokens match the sampled tokens.
415+
"""
349416
logits = logits[:, :-1] / temperature
350417
n_samples, n_tokens = logits.shape[0], logits.shape[1]
351418
logits = logits.view(n_samples*n_tokens, -1)
@@ -381,7 +448,9 @@ def evaluate_posterior(
381448
- temperature (float): Softmax temperature for probability scaling. A value of 0 indicates greedy decoding.
382449
- posterior_threshold (float): Threshold for posterior probability.
383450
- posterior_alpha (float): Scaling factor for the threshold.
384-
451+
- top_p (float, optional): Cumulative probability threshold for nucleus sampling. Defaults to 0.8.
452+
- sampling (str, optional): Defines the sampling strategy ('typical' or 'nucleus'). Defaults to 'typical'.
453+
- fast (bool, optional): If True, enables faster, deterministic decoding for typical sampling. Defaults to False.
385454
Returns:
386455
- best_candidate (torch.Tensor): Index of the chosen best candidate.
387456
- accept_length (int): Length of the accepted candidate sequence.

0 commit comments

Comments
 (0)