Commit 0061b13

fix bug
1 parent 077977a commit 0061b13

3 files changed: 14 additions (+), 6 deletions (−)

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -169,4 +169,6 @@ test_medusa*
 
 # test
 notebooks/test*.ipynb
-notebooks/*.pdf
+notebooks/*.pdf
+*.sh
+llm_judge/data/mt_bench_test

llm_judge/README.md

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ We report the 3 times running results of the Medusa X Vicuna v1.3 7/13/33b on a
 
 
 ```
-export CUDA_VISIBLE_DEVICES= 0 # set the GPU id
+export CUDA_VISIBLE_DEVICES=0 # set the GPU id
 python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-7b-v1.3 --model-id medusa-vicuna-7b-v1.3-0
 python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-13b-v1.3 --model-id medusa-vicuna-13b-v1.3-0
 python gen_model_answer_medusa.py --model-path FasterDecoding/medusa-vicuna-33b-v1.3 --model-id medusa-vicuna-33b-v1.3-0
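The one-character README fix matters in POSIX shells: with a space after `=`, the word splits so the variable is exported as the empty string (which hides every GPU from CUDA) and `0` is parsed as a separate token. A minimal sketch of the difference:

```shell
# Word-splitting demo: the buggy line `export CUDA_VISIBLE_DEVICES= 0`
# assigns the empty string, so CUDA sees no devices at all.
export CUDA_VISIBLE_DEVICES=
echo "buggy: [${CUDA_VISIBLE_DEVICES}]"   # buggy: []

# The fix: no whitespace around '='.
export CUDA_VISIBLE_DEVICES=0
echo "fixed: [${CUDA_VISIBLE_DEVICES}]"   # fixed: [0]
```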

llm_judge/gen_model_answer_medusa.py

Lines changed: 10 additions & 4 deletions
@@ -22,7 +22,7 @@
 from medusa.model.utils import *
 from medusa.model.medusa_model import MedusaModel
 from medusa.model.kv_cache import initialize_past_key_values
-from medusa.model.medusa_choices import medusa_choices
+from medusa.model.medusa_choices import *
 
 def medusa_forward(input_ids, model, tokenizer, medusa_choices, temperature, posterior_threshold, posterior_alpha, max_steps = 512):
     assert input_ids.shape[0] == 1, "Only support batch size 1 for now!!"
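The import hunk suggests `medusa.model.medusa_choices` is a module containing several named tree configurations (such as `mc_sim_7b_63`) rather than a single object called `medusa_choices`, so the old line would raise `ImportError` at startup. A stand-in module built at runtime (hypothetical names and contents, not the real library) reproduces that failure mode:

```python
import sys
import types

# Build a fake module that, like the real one presumably does, exposes
# individually named choice lists but no attribute `medusa_choices`.
mod = types.ModuleType("fake_medusa_choices")
mod.mc_sim_7b_63 = [(0,), (0, 0)]  # placeholder contents
sys.modules["fake_medusa_choices"] = mod

try:
    from fake_medusa_choices import medusa_choices  # old, broken import
except ImportError as exc:
    print("old import fails:", exc)

# The named configurations are importable individually (or via `import *`).
from fake_medusa_choices import mc_sim_7b_63
print(mc_sim_7b_63)  # [(0,), (0, 0)]
```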
@@ -191,7 +191,7 @@ def get_model_answers(
     tokenizer = model.get_tokenizer()
 
     model.eval()
-    print('Check model state:',model.training)
+    print('Check model training state:',model.training)
 
     cuda_visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES')
     print('CUDA VISIBLE DEVICES:', cuda_visible_devices)
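This hunk only rewords a diagnostic print, but the neighboring `CUDA_VISIBLE_DEVICES` check is worth a note: `os.environ.get` returns `None` when the variable was never exported, which distinguishes "unset" from "set to the empty string" (the state the README typo produced). A small sketch of the three cases:

```python
import os

# The variable may be unset, empty, or a device list; .get() separates
# the first case (None) from the other two.
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
print("unset:", os.environ.get("CUDA_VISIBLE_DEVICES"))        # unset: None

os.environ["CUDA_VISIBLE_DEVICES"] = ""                        # the README typo
print("empty:", repr(os.environ.get("CUDA_VISIBLE_DEVICES")))  # empty: ''

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
print("set:", os.environ.get("CUDA_VISIBLE_DEVICES"))          # set: 0
```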
@@ -456,14 +456,20 @@ def reorg_answer_file(answer_file):
         help="The posterior alpha for medusa sampling.",
     )
 
+    parser.add_argument(
+        "--medusa-choices",
+        type=str,
+        default="mc_sim_7b_63",
+        help="The medusa choices for medusa sampling.",
+    )
 
 
 
 
     args = parser.parse_args()
 
     args.model_id = args.model_id+"-temperature-"+str(args.temperature)+"-posterior_threshold-"+str(args.posterior_threshold)+"-posterior_alpha-"+str(args.posterior_alpha)
-
+    args.medusa_choices = eval(args.medusa_choices)
     if args.num_gpus_total // args.num_gpus_per_model > 1:
         import ray
 
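The new `--medusa-choices` flag carries the *name* of a choices list, and the commit resolves that name with `eval()`. A hedged sketch of the same resolution using an explicit lookup table instead (the tree contents below are placeholders, not the real `mc_sim_7b_63`):

```python
import argparse

# Placeholder for the named tree configuration the real code imports
# from medusa.model.medusa_choices; the actual contents differ.
mc_sim_7b_63 = [(0,), (0, 0), (1,)]

parser = argparse.ArgumentParser()
parser.add_argument(
    "--medusa-choices",
    type=str,
    default="mc_sim_7b_63",
    help="The medusa choices for medusa sampling.",
)
args = parser.parse_args([])  # use the default for this demo

# The commit does: args.medusa_choices = eval(args.medusa_choices)
# A dict lookup resolves the same name without executing arbitrary input.
known_choices = {"mc_sim_7b_63": mc_sim_7b_63}
args.medusa_choices = known_choices[args.medusa_choices]
print(args.medusa_choices)  # [(0,), (0, 0), (1,)]
```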
@@ -493,7 +499,7 @@ def reorg_answer_file(answer_file):
         args.temperature,
         args.posterior_threshold,
         args.posterior_alpha,
-        medusa_choices,
+        args.medusa_choices,
     )
 
     reorg_answer_file(answer_file)
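Earlier in the same file, the commit keeps the line that tags `model_id` with the sampling hyper-parameters so each answer file records its settings. With illustrative values (not taken from the commit), the concatenation behaves like:

```python
# Illustrative values only; the real ones come from argparse.
model_id = "medusa-vicuna-7b-v1.3-0"
temperature = 0.0
posterior_threshold = 0.09
posterior_alpha = 0.3

# Same string-building pattern as the diff's args.model_id line.
model_id = (model_id + "-temperature-" + str(temperature)
            + "-posterior_threshold-" + str(posterior_threshold)
            + "-posterior_alpha-" + str(posterior_alpha))
print(model_id)
# medusa-vicuna-7b-v1.3-0-temperature-0.0-posterior_threshold-0.09-posterior_alpha-0.3
```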
