We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 444bf42 commit 289df63Copy full SHA for 289df63
moonshine/run_eval.py
@@ -50,7 +50,7 @@ def benchmark(batch, min_new_tokens=None):
50
output_mask = torch.arange(pred_ids.shape[-1]).repeat((pred_ids.shape[0], 1)).to(args.device)
51
output_mask = output_mask > max_new_tokens
52
53
- eot_token = 2
+ eot_token = model.config.eos_token_id
54
pred_ids.masked_fill(output_mask, eot_token)
55
56
# 3.2 Convert token ids to text transcription
0 commit comments