Skip to content

Commit eae5f0e

Browse files
committed
add mimi_model::transpose_input
1 parent 891273c commit eae5f0e

File tree

2 files changed

+23
-9
lines changed

2 files changed

+23
-9
lines changed

examples/tts/mimi-model.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ std::vector<float> mimi_model::decode_frame(const std::vector<int> & codes, int
617617
int n_pos = -1;
618618
int n_codes = codes.size();
619619
int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components;
620-
GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd");
620+
GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd");
621621

622622
ctx->build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) {
623623
ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes);
@@ -661,14 +661,6 @@ std::vector<float> mimi_model::decode_frame(const std::vector<int> & codes, int
661661
ctx->set_tensor_data("pos_dec", pos_data.data());
662662

663663
// code data
664-
/*std::vector<int> codes_t(n_codes_per_embd * n_codes);
665-
for (int i = 0; i < n_codes / n_codes_per_embd; i++) {
666-
for (int j = 0; j < n_codes_per_embd; j++) {
667-
int src_idx = i * n_codes_per_embd + j;
668-
int dst_idx = j * (n_codes / n_codes_per_embd) + i;
669-
codes_t[dst_idx] = codes[src_idx];
670-
}
671-
}*/
672664
ctx->set_tensor_data("inp_dec", codes.data());
673665

674666
ctx->compute();
@@ -715,6 +707,23 @@ std::vector<float> mimi_model::decode(const std::vector<int> & codes) {
715707
return output;
716708
}
717709

710+
std::vector<int> mimi_model::transpose_input(const std::vector<int> & codes) {
711+
int n_codes = codes.size();
712+
int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components;
713+
GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd");
714+
715+
std::vector<int> codes_T(n_codes_per_embd * n_codes);
716+
for (int i = 0; i < n_codes / n_codes_per_embd; i++) {
717+
for (int j = 0; j < n_codes_per_embd; j++) {
718+
int src_idx = i * n_codes_per_embd + j;
719+
int dst_idx = j * (n_codes / n_codes_per_embd) + i;
720+
codes_T[dst_idx] = codes[src_idx];
721+
}
722+
}
723+
724+
return codes_T;
725+
}
726+
718727
int mimi_model::get_sample_rate() const {
719728
return mimi_config.sample_rate;
720729
}

examples/tts/mimi-model.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ struct mimi_model {
2222

2323
int get_sample_rate() const;
2424

25+
// transpose layout:
26+
// - from: (1 semantic code followed by 31 acoustic codes) repeast N times
27+
// - to: N semantic codes followed by (N*31) acoustic codes
28+
std::vector<int> transpose_input(const std::vector<int> & codes);
29+
2530
// layout of codes: N semantic codes followed by (N*31) acoustic codes
2631
std::vector<float> decode(const std::vector<int> & codes);
2732

0 commit comments

Comments
 (0)