Skip to content

Commit e3de627

Browse files
committed
fix
1 parent 31127b8 commit e3de627

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

examples/simple-tts/simple-tts.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -586,12 +586,11 @@ int main(int argc, char ** argv) {
586586
return 1;
587587
}
588588

589-
std::vector<llama_sampler> samplers(n_parallel);
589+
std::vector<llama_sampler *> smpl(n_parallel);
590590
for (int i = 0; i < n_parallel; ++i) {
591-
llama_sampler * smpl = &samplers[i];
592-
smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());
593-
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
594-
llama_sampler_chain_add(smpl, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
591+
smpl[i] = llama_sampler_chain_init(llama_sampler_chain_default_params());
592+
llama_sampler_chain_add(smpl[i], llama_sampler_init_greedy());
593+
llama_sampler_chain_add(smpl[i], llama_sampler_init_dist(LLAMA_DEFAULT_SEED));
595594
}
596595

597596
outetts_version tts_version = get_tts_version(model);
@@ -664,7 +663,7 @@ int main(int argc, char ** argv) {
664663
continue;
665664
}
666665

667-
llama_token new_token_id = llama_sampler_sample(&samplers[i], ctx, i_batch[i]);
666+
llama_token new_token_id = llama_sampler_sample(smpl[i], ctx, i_batch[i]);
668667

669668
//guide tokens help prevent hallucinations by forcing the TTS to use the correct word
670669
if (!guide_tokens.empty() && next_token_uses_guide_token && !llama_vocab_is_control(vocab, new_token_id) && !llama_vocab_is_eog(vocab, new_token_id)) {
@@ -676,7 +675,7 @@ int main(int argc, char ** argv) {
676675
//this is the token id that always precedes a new word
677676
next_token_uses_guide_token = (new_token_id == 198);
678677

679-
llama_sampler_accept(&samplers[i], new_token_id);
678+
llama_sampler_accept(smpl[i], new_token_id);
680679

681680
codes.push_back(new_token_id);
682681

0 commit comments

Comments
 (0)