diff --git a/.gitignore b/.gitignore
index 2c67ad7f7c609..41fe1f31271d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,6 +107,7 @@ examples/server/*.gz.hpp
 !examples/*/*/*.kts
 !examples/sycl/*.bat
 !examples/sycl/*.sh
+/*.wav
 
 # Server Web UI temporary files
 node_modules
diff --git a/common/common.cpp b/common/common.cpp
index 18ffb4e738aee..30870980a148d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2055,3 +2055,31 @@ common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
     }
     return out;
 }
+
+//
+// Audio utils
+//
+
+bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
+    std::ofstream file(fname, std::ios::binary);
+    if (!file) {
+        LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
+        return false;
+    }
+
+    wav_header header;
+    header.sample_rate = sample_rate;
+    header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
+    header.block_align = header.num_channels * (header.bits_per_sample / 8);
+    header.data_size = data.size() * (header.bits_per_sample / 8);
+    header.chunk_size = 36 + header.data_size;
+
+    file.write(reinterpret_cast<const char *>(&header), sizeof(header));
+
+    for (const auto & sample : data) {
+        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+        file.write(reinterpret_cast<const char *>(&pcm_sample), sizeof(pcm_sample));
+    }
+
+    return file.good();
+}
diff --git a/common/common.h b/common/common.h
index 1c0f199774976..0c67693149285 100644
--- a/common/common.h
+++ b/common/common.h
@@ -683,3 +683,25 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count";
 const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 
 }
+
+//
+// Audio utils
+//
+
+struct wav_header {
+    char riff[4] = {'R', 'I', 'F', 'F'};
+    uint32_t chunk_size;
+    char wave[4] = {'W', 'A', 'V', 'E'};
+    char fmt[4] = {'f', 'm', 't', ' '};
+    uint32_t fmt_chunk_size = 16;
+    uint16_t audio_format = 1; // PCM
+    uint16_t num_channels = 1; // Mono
+    uint32_t sample_rate;
+    uint32_t byte_rate;
+    uint16_t block_align;
+    uint16_t bits_per_sample = 16;
+    char data[4] = {'d', 'a', 't', 'a'};
+    uint32_t data_size;
+};
+
+bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate);
diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index c72bd814c3b31..371c3bbf7434d 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -3,3 +3,10 @@ add_executable(${TARGET} tts.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-mimi)
+add_executable(${TARGET} mimi.cpp mimi-model.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+# for using C++ designated initializers, TODO: can be changed back to C++17 in the future
+target_compile_features(${TARGET} PRIVATE cxx_std_20)
diff --git a/examples/tts/README-mimi.md b/examples/tts/README-mimi.md
new file mode 100644
index 0000000000000..6576a118291ad
--- /dev/null
+++ b/examples/tts/README-mimi.md
@@ -0,0 +1,50 @@
+# llama.cpp/example/mimi
+
+This demonstrates running [Kyutai's Mimi](https://huggingface.co/kyutai/mimi) model via GGML.
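+
+This example only decodes: it turns a list of Mimi audio codes into a waveform. The codes themselves can be produced by any Mimi-based model (Moshi, Hibiki, Sesame, etc). The Quickstart below uses a ready-made codes file; if you want to generate a `codes.txt` from an audio clip yourself, a sketch along the following lines may work. It relies on the Hugging Face `transformers` Mimi implementation, and the exact API names (`AutoFeatureExtractor`, `encode().audio_codes`) are assumptions to verify against your installed version. Also note that the decoder consumes codes in chunks of 125 frames (about 10 seconds), so longer clips should be written chunk by chunk in the same codebook-major layout.
+
+```python
+import torch
+import torchaudio
+from transformers import MimiModel, AutoFeatureExtractor
+
+model = MimiModel.from_pretrained("kyutai/mimi")
+feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")
+
+wav, sr = torchaudio.load("input.wav")
+wav = torchaudio.functional.resample(wav, sr, 24000).mean(dim=0)  # 24 kHz, mono
+
+inputs = feature_extractor(raw_audio=wav.numpy(), sampling_rate=24000, return_tensors="pt")
+with torch.no_grad():
+    codes = model.encode(inputs["input_values"]).audio_codes  # shape: (1, 32, n_frames)
+
+# codebook-major layout: all semantic (codebook 0) codes first, then each acoustic codebook
+with open("codes.txt", "w") as f:
+    for code in codes[0].flatten().tolist():
+        f.write(f"{code}\n")
+```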
+ +## Quickstart + +Convert model to GGUF (no need to download, the script will automatically download the `safetensors` file) + +```sh +python examples/tts/convert_mimi_to_gguf.py + +# output file: kyutai-mimi.gguf + +# optionally, use q8_0 quantization for faster speed +python examples/tts/convert_mimi_to_gguf.py --outtype q8_0 +``` + +Then compile, run it: + +```sh +cmake --build build -j --target llama-mimi + +./build/bin/llama-mimi kyutai-mimi.gguf codes.txt + +# output: output.wav + +# alternatively, use "dummy1" to get a "wah hello there" sample output file +./build/bin/llama-mimi kyutai-mimi.gguf dummy1 +``` + +Example of code file (one code per line): + +``` +1263 +1597 +1596 +1477 +1540 +1720 +1433 +118 +1066 +1968 +1096 +232 +418 +566 +1653 +2010 +``` diff --git a/examples/tts/convert_mimi_to_gguf.py b/examples/tts/convert_mimi_to_gguf.py new file mode 100644 index 0000000000000..5dce72a398a91 --- /dev/null +++ b/examples/tts/convert_mimi_to_gguf.py @@ -0,0 +1,191 @@ +import gguf +import argparse +import logging +import torch +from typing import Union +from pathlib import Path +from torch import Tensor +from transformers import MimiModel, PreTrainedModel + +logger = logging.getLogger("mimi") + + +class MimiModelConverter: + mimi_model: PreTrainedModel + gguf_writer: gguf.GGUFWriter + fname_out: Path + ftype: gguf.LlamaFileType + + def __init__(self, + pretrained_model_name_or_path: Union[Path, str], + fname_out: Path, + ftype: gguf.LlamaFileType, + is_big_endian: bool,): + self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path) + self.fname_out = fname_out + self.ftype = ftype + endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.gguf_writer = gguf.GGUFWriter( + path=None, + arch="if you see this, you are using the wrong file", + endianess=endianess) + + assert self.mimi_model.config.architectures[0] == "MimiModel" + + # load tensors + for name, data_torch in self.mimi_model.state_dict().items(): + # convert any unsupported data types to float32 + old_dtype = data_torch.dtype + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + self.add_tensor(name, data_torch, old_dtype) + + def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype): + is_1d = len(data_torch.shape) == 1 + is_bias = ".bias" in name + can_quantize = not is_1d and not is_bias + data_qtype = gguf.GGMLQuantizationType.F32 + + n_head = self.mimi_model.config.num_attention_heads + n_kv_head = self.mimi_model.config.num_key_value_heads + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = self.undo_permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = self.undo_permute(data_torch, n_head, n_kv_head) + + # process codebook + if ".codebook.initialized" in name: + # "initialized" tensor + state_dict = self.mimi_model.state_dict() + embed_sum = state_dict[name.replace(".initialized", ".embed_sum")] + cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")] + # see modeling_mimi.py --> MimiEuclideanCodebook + data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None] + name = name.replace(".initialized", "") + + # ignore processed tensors + if ".cluster_usage" in name or ".embed_sum" in name: + return + + # transpose some tensors + if ".conv.bias" in name: + data_torch = data_torch.view((1, data_torch.shape[0])) + data_torch = data_torch.transpose(0, 1) + + # change view 3d to 2d + if 
"quantizer" in name and "_proj." in name: + assert data_torch.shape[2] == 1 + data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1])) + + # shorten name, otherwise it will be too long for ggml to read + name = name.replace("_residual_vector_quantizer", "_rvq") + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + + # Conv kernels are always F16 + if ".conv.weight" in name: + data_qtype = gguf.GGMLQuantizationType.F16 + + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def undo_permute(weights: Tensor, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Mimi safetensors model to GGUF",) + parser.add_argument( + "--outfile", type=Path, default="kyutai-mimi.gguf", + help="path to write to", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory or model ID containing model file (if model ID is specified, download from Hugging Face hub)", + nargs="?", + default="kyutai/mimi", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model}") + + with torch.inference_mode(): + converter = MimiModelConverter( + pretrained_model_name_or_path=dir_model, + fname_out=args.outfile, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + converter.write() + + 
+if __name__ == '__main__': + main() + diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp new file mode 100644 index 0000000000000..427aeff8658bf --- /dev/null +++ b/examples/tts/mimi-model.cpp @@ -0,0 +1,733 @@ +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" + +#include "common.h" +#include "mimi-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Implementation of Kyutai's Mimi model using GGML. + * Based on this research: https://github.com/ngxson/ggml-easy/blob/master/demo/kyutai-mimi.cpp + * + * NOTE: only decoder is working for now. + * + * Background: + * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc + * - Audio codes must be in the order: N semantic codes followed by (N*31) acoustic codes + * (In other words, input matrix has shape 32 cols x N rows) + * + * How it works? + * 1. Audio code passed to RVQ (mimi_residual_vector_quantizer) to get the latent code + * 2. The latent code is passed to a mimi_conv_transpose_1d (depthwise) to upscale + * 3. The upscaled code is passed to transformer, it converts N frames to N frames + * 4. The output embeddings is then passed to SEANet (mimi_encoder_decoder) to get the final waveform + * 5. Waveform is written to a file + */ + +// copied from https://huggingface.co/kyutai/mimi/blob/main/config.json +struct mimi_config_t { + bool causal = true; + int sample_rate = 24000; + int max_position_embeddings = 8000; + int num_hidden_layers = 8; + int n_embd = 512; + int n_ffn = 2048; + int n_head = 8; + int n_head_kv = 8; + int n_rot = 64; + float norm_eps = 1e-5; + float rope_theta = 10000.0f; + int sliding_window = 250; + std::array upsampling_ratio = {8, 6, 5, 4}; + std::array downsampling_ratio = {4, 5, 6, 8}; // reverse of upsampling_ratio + // vector quantizer + float frame_rate = 12.5; + int audio_channels = 1; + int codebook_size = 2048; + int codebook_dim = 256; + int n_semantic_components = 1; + int n_acoustic_components = 31; + // decode + float trim_right_ratio = 1.0f; + int n_codes_per_frame = (sliding_window / 2) * (n_semantic_components + n_acoustic_components); +} mimi_config; + +// Adapted from https://github.com/ngxson/ggml-easy/blob/master/ggml-easy.h +struct mimi_ggml_ctx { + gguf_context * ctx_gguf = nullptr; + ggml_context * ctx_data = nullptr; + ggml_context * ctx_gf = nullptr; + + // CPU-only for now, as many kernels are missing and we actually get less performance with GPU + ggml_backend_t backend = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_backend_sched_ptr sched; + + ggml_cgraph * gf = nullptr; + std::vector buf_compute_meta; + int max_nodes = 16 * 1024; + + std::unordered_map tensors; + + mimi_ggml_ctx() { + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + auto buft = ggml_backend_get_default_buffer_type(backend); + sched.reset( + ggml_backend_sched_new(&backend, &buft, 1, max_nodes, false) + ); + buf_compute_meta.resize(max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); + } + + void load_gguf(const char * fname) { + ggml_context * meta = nullptr; + + gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &meta, + }; + + ctx_gguf = gguf_init_from_file(fname, params); + + // load tensors + const int n_tensors = gguf_get_n_tensors(ctx_gguf); + + std::vector read_buf; + ggml_init_params ggml_params = { + /*.mem_size =*/ 
(n_tensors + 1) * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ctx_data = ggml_init(ggml_params); + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + ggml_free(meta); + throw std::runtime_error("cannot open model file for loading tensors"); + } + + // add tensors to context + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * t = ggml_get_tensor(meta, name); + ggml_tensor * cur = ggml_dup_tensor(ctx_data, t); + ggml_set_name(cur, name); + tensors.insert({name, cur}); + } + + // alloc memory and offload data + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, buft); + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + const size_t offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i); + // printf("%s: Loading tensor \"%s\"\n", __func__, name); + fin.seekg(offset, std::ios::beg); + if (!fin) { + ggml_free(meta); + throw std::runtime_error(string_format("failed to seek for tensor: %s", name)); + } + int num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + } + printf("%s: Loaded %d tensors from %s\n", __func__, n_tensors, fname); + fin.close(); + + ggml_free(meta); + } + + /** + * Build a cgraph using the given builder function. + * + * The built cgraph will be stored in `ctx.gf` + */ + void build_graph(std::function builder_fn) { + ggml_free(ctx_gf); + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_gf = ggml_init(params); + ggml_backend_sched_reset(sched.get()); + gf = ggml_new_graph_custom(ctx_gf, max_nodes, false); + + builder_fn(ctx_gf, gf); + ggml_backend_sched_alloc_graph(sched.get(), gf); + } + + ggml_status compute() { + ggml_status status = ggml_backend_sched_graph_compute(sched.get(), gf); + return status; + } + + void set_tensor_data(const std::string & name, const void * data) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t)); + } + + std::pair> get_tensor_data(const std::string & name) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + std::vector data(ggml_nbytes(t)); + ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); + return std::make_pair(t, data); + } + + ggml_tensor * get_weight(const char *fmt, ...) 
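+    // fetch a weight tensor by printf-style formatted name, e.g. get_weight("decoder.layers.%d.conv.weight", il);
+    // throws if the tensor is not present in the loaded GGUF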
{ + std::vector str(128); + va_list va; + va_start(va, fmt); + vsnprintf(str.data(), 128, fmt, va); + va_end(va); + auto it = tensors.find(str.data()); + if (it == tensors.end()) { + throw std::runtime_error(string_format("weight tensor not found: %s", str.data())); + } + return it->second; + } + + ~mimi_ggml_ctx() { + ggml_free(ctx_data); + gguf_free(ctx_gguf); + ggml_backend_buffer_free(buf); + } +}; + +/////////////////////////////////////////////////////////////////////////// +// extension to ggml.h +// TODO: add these ops to the library (ofc with a more optimized kernel) + + +// mode: (0) constant, (1) reflect, (2) replicate, (3) circular +// value is only used in "constant" +// only "constant" with 0.0f and "replicate" are implemented here +static ggml_tensor * ggml_pad_ext(ggml_context * ctx0, ggml_tensor * x, int mode, + int64_t pad_left, int64_t pad_right, float value = 0.0f) { + GGML_ASSERT(value == 0.0f); // we can technically use ggml_arange, but for simplication we only support 0.0f + GGML_ASSERT(mode == 0 || mode == 2); + if (pad_left > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_left, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], 0); // get first column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, tmp, x, 0); + } + if (pad_right > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_right, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + int64_t last = x->ne[0] - 1; + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], last * ggml_element_size(x)); // get last column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, x, tmp, 0); + } + return x; +} + + + + +/////////////////////////////////////////////////////////////////////////// +// MimiConv and MimiConvTranspose + +static int64_t div_ceil(int64_t a, int64_t b) { + return a / b + (a % b ? 1 : 0); +} + +static ggml_tensor * mimi_conv_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool pad_zero = true) { + int64_t kernel_size = (kernel->ne[0] - 1) * dilation + 1; + int64_t p_total = kernel_size - stride; // padding total + int64_t p_half = p_total / 2; + + int64_t n_frames = div_ceil(x->ne[0] - kernel_size + p_total, stride); + int64_t ideal_len = n_frames * stride + kernel_size - p_total; + int64_t p_extra = ideal_len - x->ne[0]; + + int64_t p_right = (mimi_config.causal ? 0 : p_half) + p_extra; + int64_t p_left = p_total - (mimi_config.causal ? 0 : p_half); + + x = ggml_pad_ext(ctx0, x, pad_zero ? 0 : 2, p_left, p_right); + + x = ggml_conv_1d(ctx0, kernel, x, stride, 0, dilation); + if (bias) { + x = ggml_add(ctx0, x, bias); + } + ggml_set_name(x, "mimi_conv_1d"); + return x; +} + +static ggml_tensor * mimi_conv_transpose_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool depthwise) { + GGML_ASSERT(x->ne[1] == kernel->ne[2]); + int64_t n_rows = x->ne[1]; + int64_t kernel_size = kernel->ne[0]; + int64_t p_total = kernel_size - stride; // padding total + + int64_t p_right = mimi_config.causal + ? 
(float)p_total / mimi_config.trim_right_ratio + : p_total / 2; + int64_t p_left = p_total - p_right; + + ggml_tensor * out = nullptr; + + if (depthwise) { + for (int64_t ir = 0; ir < n_rows; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, x, + x->ne[0], ir*x->ne[0]*ggml_element_size(x)); + ggml_tensor * krn = ggml_view_1d(ctx0, kernel, + kernel->ne[0], ir*kernel->ne[0]*ggml_element_size(kernel)); + row = ggml_conv_transpose_1d(ctx0, krn, row, stride, 0, dilation); + // unpad (remove p_right and p_left columns) + row = ggml_view_1d(ctx0, row, row->ne[0] - p_total, p_left*ggml_element_size(row)); + + // TODO: concat can be slow, we should use ggml_view_1d/ggml_cpy to avoid realloc + out = out ? ggml_concat(ctx0, out, row, 1) : row; + } + + } else { + out = ggml_conv_transpose_1d(ctx0, kernel, x, stride, 0, dilation); + // unpad + out = ggml_view_2d(ctx0, out, + out->ne[0] - p_total, out->ne[1], + out->nb[1], p_left*ggml_element_size(out)); + } + + if (bias) { + out = ggml_add(ctx0, out, bias); + } + + return out; +} + + + +/////////////////////////////////////////////////////////////////////////// + +// based on MimiEncoder +// SEANet encoder as used by Mimi. +struct mimi_encoder_decoder { + mimi_ggml_ctx & ctx; + struct layer { + bool is_elu = false; + bool is_resnet = false; + bool is_transposed_conv = false; + ggml_tensor * conv_0_w = nullptr; + ggml_tensor * conv_0_b = nullptr; + ggml_tensor * conv_1_w = nullptr; + ggml_tensor * conv_1_b = nullptr; + int stride = 1; + }; + std::vector layers; + + std::array repeated_pattern = {1, 4, 7, 10}; + + mimi_encoder_decoder(mimi_ggml_ctx & ctx): ctx(ctx) { + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.0.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.0.conv.bias"), + }); + for (int i = 0; i < (int)repeated_pattern.size(); ++i) { + int i_start = repeated_pattern[i]; + // upsampling layers + layers.push_back({ + .is_elu = true, // layer (i_start) + }); + layers.push_back({ + .is_transposed_conv = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1), + .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1), + .stride = mimi_config.upsampling_ratio[i], + }); + // residual layers + layers.push_back({ + .is_resnet = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.block.1.conv.weight", i_start + 2), + .conv_0_b = ctx.get_weight("decoder.layers.%d.block.1.conv.bias", i_start + 2), + .conv_1_w = ctx.get_weight("decoder.layers.%d.block.3.conv.weight", i_start + 2), + .conv_1_b = ctx.get_weight("decoder.layers.%d.block.3.conv.bias", i_start + 2), + }); + } + layers.push_back({ + .is_elu = true, // layer 13 + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.14.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.14.conv.bias"), + }); + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input) { + ggml_tensor * x = input; + + for (auto & layer : layers) { + if (layer.is_elu) { + x = ggml_elu(ctx0, x); + } else if (layer.is_resnet) { + ggml_tensor * residual = x; + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, 1, 1); + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_1_w, layer.conv_1_b, 1, 1); + x = ggml_add(ctx0, x, residual); + } else { + x = layer.is_transposed_conv + ? 
mimi_conv_transpose_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1, false) + : mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1); + } + } + + return x; + } +}; + +struct mimi_transformer { + struct layer { + ggml_tensor * inp_norm_w = nullptr; + ggml_tensor * inp_norm_b = nullptr; + + ggml_tensor * attn_q = nullptr; + ggml_tensor * attn_k = nullptr; + ggml_tensor * attn_v = nullptr; + ggml_tensor * attn_o = nullptr; + ggml_tensor * attn_post_norm_w = nullptr; + ggml_tensor * attn_post_norm_b = nullptr; + ggml_tensor * attn_layer_scale = nullptr; + + ggml_tensor * ffn_up = nullptr; + ggml_tensor * ffn_down = nullptr; + ggml_tensor * mlp_layer_scale = nullptr; + }; + std::vector layers; + + mimi_transformer(mimi_ggml_ctx & ctx, const char * prefix, int n_layers) { + for (int il = 0; il < n_layers; il++) { + layers.push_back({ + .inp_norm_w = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.weight", prefix, il), + .inp_norm_b = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.bias", prefix, il), + + .attn_q = ctx.get_weight("%s_transformer.layers.%d.self_attn.q_proj.weight", prefix, il), + .attn_k = ctx.get_weight("%s_transformer.layers.%d.self_attn.k_proj.weight", prefix, il), + .attn_v = ctx.get_weight("%s_transformer.layers.%d.self_attn.v_proj.weight", prefix, il), + .attn_o = ctx.get_weight("%s_transformer.layers.%d.self_attn.o_proj.weight", prefix, il), + .attn_post_norm_w = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.weight", prefix, il), + .attn_post_norm_b = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.bias", prefix, il), + .attn_layer_scale = ctx.get_weight("%s_transformer.layers.%d.self_attn_layer_scale.scale", prefix, il), + + .ffn_up = ctx.get_weight("%s_transformer.layers.%d.mlp.fc1.weight", prefix, il), + .ffn_down = ctx.get_weight("%s_transformer.layers.%d.mlp.fc2.weight", prefix, il), + .mlp_layer_scale = ctx.get_weight("%s_transformer.layers.%d.mlp_layer_scale.scale", prefix, il), + }); + } + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input, ggml_tensor * inp_pos) { + int n_tokens = input->ne[1]; + ggml_tensor * x = input; + + auto layer_norm = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) { + x = ggml_norm(ctx0, x, mimi_config.norm_eps); + x = ggml_mul(ctx0, x, w); + x = ggml_add(ctx0, x, b); + return x; + }; + + ggml_tensor * residual = input; + + for (auto & layer : layers) { + residual = x; + + // input layer norm + x = layer_norm(x, layer.inp_norm_w, layer.inp_norm_b); + + // self attention + { + ggml_tensor * q = ggml_mul_mat(ctx0, layer.attn_q, x); + ggml_tensor * k = ggml_mul_mat(ctx0, layer.attn_k, x); + ggml_tensor * v = ggml_mul_mat(ctx0, layer.attn_v, x); + + int n_embd_head = mimi_config.n_embd / mimi_config.n_head; + q = ggml_reshape_3d(ctx0, q, n_embd_head, mimi_config.n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, mimi_config.n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, mimi_config.n_head_kv, n_tokens); + + int n_rot = n_embd_head; + q = ggml_rope_inplace(ctx0, q, inp_pos, n_rot, 0); + q = ggml_cont(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3)); + + k = ggml_rope_inplace(ctx0, k, inp_pos, n_rot, 0); + k = ggml_cont(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // mimic behavior of llama.cpp + kq = ggml_scale_inplace(ctx0, kq, 1.0f / std::sqrt(n_embd_head)); + ggml_tensor * kq_masked = ggml_diag_mask_inf_inplace(ctx0, kq, 
n_tokens); + kq = ggml_soft_max_inplace(ctx0, kq_masked); + + v = ggml_cont(ctx0, ggml_permute(ctx0, v, 1, 2, 0, 3)); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + kqv = ggml_reshape_3d(ctx0, kqv, n_embd_head, n_tokens, mimi_config.n_head); + kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + kqv = ggml_cont_2d(ctx0, kqv, mimi_config.n_embd, n_tokens); + + x = ggml_mul_mat(ctx0, layer.attn_o, kqv); + } + + // residual + x = ggml_mul(ctx0, x, layer.attn_layer_scale); + x = ggml_add(ctx0, x, residual); + + residual = x; + x = layer_norm(x, layer.attn_post_norm_w, layer.attn_post_norm_b); + + // mlp + { + x = ggml_mul_mat(ctx0, layer.ffn_up, x); + x = ggml_gelu(ctx0, x); + x = ggml_mul_mat(ctx0, layer.ffn_down, x); + } + + // residual + x = ggml_mul(ctx0, x, layer.mlp_layer_scale); + x = ggml_add(ctx0, x, residual); + } + + return x; + } +}; + +struct mimi_residual_vector_quantizer { + struct component { + ggml_tensor * codebook; + }; + + ggml_tensor * semantic_inp_proj; + std::vector semantic_components; + ggml_tensor * semantic_out_proj; + + ggml_tensor * acoustic_inp_proj; + std::vector acoustic_components; + ggml_tensor * acoustic_out_proj; + + mimi_residual_vector_quantizer(mimi_ggml_ctx & ctx) { + semantic_inp_proj = ctx.get_weight("quantizer.semantic_rvq.input_proj.weight"); + semantic_out_proj = ctx.get_weight("quantizer.semantic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_semantic_components; i++) { + semantic_components.push_back({ + .codebook = ctx.get_weight("quantizer.semantic_rvq.layers.%d.codebook", i), + }); + } + acoustic_inp_proj = ctx.get_weight("quantizer.acoustic_rvq.input_proj.weight"); + acoustic_out_proj = ctx.get_weight("quantizer.acoustic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_acoustic_components; i++) { + acoustic_components.push_back({ + .codebook = ctx.get_weight("quantizer.acoustic_rvq.layers.%d.codebook", i), + }); + } + } + + // the input has shape [n_codes, n_codes_per_embd] + // first row is semantic, the rest are acoustic + // example: [ [semantic], [acoustic1], [acoustic2], ... 
] + ggml_tensor * decode(ggml_context * ctx0, ggml_tensor * input) { + GGML_ASSERT(input->type == GGML_TYPE_I32); + + size_t n_semantic = semantic_components.size(); + int64_t n_codes_per_embd = (n_semantic + acoustic_components.size()); + int64_t n_codes = input->ne[0] / n_codes_per_embd; + + GGML_ASSERT(input->ne[0] % n_codes_per_embd == 0); + + ggml_tensor * out_s = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + ggml_tensor * out_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + out_s = ggml_scale(ctx0, out_s, 0.0f); // clear + out_a = ggml_scale(ctx0, out_a, 0.0f); // clear + + for (size_t ir = 0; ir < (size_t)n_codes_per_embd; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, input, n_codes, ir*n_codes*ggml_element_size(input)); + if (ir < n_semantic) { + // semantic + ggml_tensor * codebook = semantic_components[ir].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_s = ggml_add(ctx0, out_s, embd); + } else { + // acoustic + ggml_tensor * codebook = acoustic_components[ir-n_semantic].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_a = ggml_add(ctx0, out_a, embd); + } + } + + out_s = ggml_mul_mat(ctx0, semantic_out_proj, out_s); + out_a = ggml_mul_mat(ctx0, acoustic_out_proj, out_a); + + return ggml_add(ctx0, out_s, out_a); + } +}; + + +mimi_model::mimi_model(const char * fname, bool verbose) : verbose(verbose) { + ctx.reset(new mimi_ggml_ctx()); + ctx->load_gguf(fname); + + // initialize components + seanet_dec .reset(new mimi_encoder_decoder(*ctx)); + transformer_dec.reset(new mimi_transformer(*ctx, "decoder", mimi_config.num_hidden_layers)); + quantizer .reset(new mimi_residual_vector_quantizer(*ctx)); +} + +mimi_model::~mimi_model() { +} + +std::vector mimi_model::decode_frame(const std::vector & codes, int & n_past) { + // build cgraph + int n_pos = -1; + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd"); + + ctx->build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { + ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); + ggml_set_name(inp_dec, "inp_dec"); + ggml_set_input(inp_dec); + + // RVQ + ggml_tensor * embeddings = quantizer->decode(ctx_gf, inp_dec); + + // upsample + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = mimi_conv_transpose_1d(ctx_gf, embeddings, ctx->get_weight("upsample.conv.weight"), nullptr, 2, 1, true); + + // transformer + n_pos = embeddings->ne[0]; + ggml_tensor * pos_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_dec, "pos_dec"); + ggml_set_input(pos_dec); + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = transformer_dec->forward(ctx_gf, embeddings, pos_dec); + + // SEANET decoder + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + ggml_tensor * output = seanet_dec->forward(ctx_gf, embeddings); + + ggml_set_name(output, "output"); + ggml_set_output(output); + ggml_build_forward_expand(gf, output); + }); + + // position data + GGML_ASSERT(n_pos <= mimi_config.sliding_window); + std::vector pos_data(n_pos); + for (int i = 0; i < (int)pos_data.size(); i++) { + pos_data[i] = i + n_past; + } + if (verbose) { + printf("%s: n_pos: %d, n_past: %d\n", __func__, n_pos, n_past); + } + n_past += n_pos; + 
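+    // NOTE: RoPE positions are absolute across decoded frames: pos_data starts at the
+    // previous n_past, and n_past is advanced by n_pos (the number of upsampled
+    // transformer positions in this frame) so the next frame continues where this one ended.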
ctx->set_tensor_data("pos_dec", pos_data.data()); + + // code data + ctx->set_tensor_data("inp_dec", codes.data()); + + ctx->compute(); + + auto output = ctx->get_tensor_data("output"); + // auto output_tensor = output.first; + auto output_data = output.second; + // printf("Output shape: [%lld, %lld]\n", output_tensor->ne[0], output_tensor->ne[1]); + + std::vector wav_data(output_data.size() / sizeof(float)); + for (size_t i = 0; i < wav_data.size(); i++) { + wav_data[i] = ((float *)output_data.data())[i]; + } + + return wav_data; +} + +std::vector mimi_model::decode(const std::vector & codes) { + std::vector output; + + if (verbose) { + printf("%s: n_codes: %zu\n", __func__, codes.size()); + } + + int64_t t_start = ggml_time_ms(); + int n_frames = 0; + + int n_past = 0; + for (size_t i = 0; i < codes.size(); i += mimi_config.n_codes_per_frame) { + size_t remaining = std::min((size_t)mimi_config.n_codes_per_frame, codes.size() - i); + std::vector frame(codes.begin() + i, codes.begin() + i + remaining); + + auto wav_data = decode_frame(frame, n_past); + output.insert(output.end(), wav_data.begin(), wav_data.end()); + + n_frames++; + } + + int64_t t_end = ggml_time_ms(); + if (verbose) { + printf("%s: n_frames: %d, time: %" PRId64 "ms, per_frame: %" PRId64 "ms\n", __func__, n_frames, t_end - t_start, (t_end - t_start) / n_frames); + } + + return output; +} + +std::vector mimi_model::transpose_input(const std::vector & codes) { + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd"); + + std::vector codes_T(n_codes_per_embd * n_codes); + for (int i = 0; i < n_codes / n_codes_per_embd; i++) { + for (int j = 0; j < n_codes_per_embd; j++) { + int src_idx = i * n_codes_per_embd + j; + int dst_idx = j * (n_codes / n_codes_per_embd) + i; + codes_T[dst_idx] = codes[src_idx]; + } + } + + return codes_T; +} + +int mimi_model::get_sample_rate() const { + return mimi_config.sample_rate; +} diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h new file mode 100644 index 0000000000000..96945981513c0 --- /dev/null +++ b/examples/tts/mimi-model.h @@ -0,0 +1,38 @@ +#pragma once + +#include "ggml.h" +#include +#include + +struct mimi_ggml_ctx; +struct mimi_encoder_decoder; +struct mimi_transformer; +struct mimi_residual_vector_quantizer; + +struct mimi_model { + bool verbose = false; + std::unique_ptr ctx; + + std::unique_ptr seanet_dec; + std::unique_ptr transformer_dec; + std::unique_ptr quantizer; + + mimi_model(const char * fname, bool verbose = false); + ~mimi_model(); + + int get_sample_rate() const; + + // transpose layout: + // - from: (1 semantic code followed by 31 acoustic codes) repeast N times + // - to: N semantic codes followed by (N*31) acoustic codes + std::vector transpose_input(const std::vector & codes); + + // layout of codes: N semantic codes followed by (N*31) acoustic codes + std::vector decode(const std::vector & codes); + + // TODO: implement encoding pass + // std::vector encode(const std::vector & wav_data); + +private: + std::vector decode_frame(const std::vector & codes, int & n_past); +}; diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp new file mode 100644 index 0000000000000..502e0150634b7 --- /dev/null +++ b/examples/tts/mimi.cpp @@ -0,0 +1,112 @@ +#include "common.h" +#include "mimi-model.h" + +#include +#include +#include // strcmp + + +/** + * This file is used for testing and 
showcase how to use "mimi_model" class. + * Please keep it simple and easy to understand. + */ + +int main(int argc, const char ** argv) { + if (argc < 3) { + fprintf(stderr, "Usage: %s model.gguf codes.txt [output.wav]\n", argv[0]); + fprintf(stderr, " Format of codes.txt file: one code per line\n"); + fprintf(stderr, " Replace codes.txt with dummy0 and dummy1 for testing\n"); + fprintf(stderr, " dummy0: using code 1, 2, 3,..., 96, used for logits matching\n"); + fprintf(stderr, " dummy1: using code that will outputs 'wah hello there' sound\n"); + return 1; + } + + const char * model_path = argv[1]; + const char * codes_path = argv[2]; + const char * out_path = argc < 4 ? "output.wav" : argv[3]; + + // load codes + std::vector codes; + if (strcmp(codes_path, "dummy0") == 0) { + printf("Using dummy0 codes\n"); + codes.resize(32 * 3); // [n_codes_per_embd = 32, n_codes = 3] + for (int i = 0; i < (int)codes.size(); i++) { + codes[i] = i; + } + } else if (strcmp(codes_path, "dummy1") == 0) { + printf("Using dummy1 codes\n"); + codes = { + 1049 ,1415 ,1962 ,914 ,1372 ,704 ,1922 ,2036 ,288 ,968 ,193 ,1139 ,897 ,897 ,1243 ,1511 , + 1597 ,175 ,1280 ,1202 ,1911 ,85 ,47 ,692 ,632 ,251 ,1553 ,1735 ,1577 ,132 ,471 ,433 , + 1325 ,1539 ,1943 ,1601 ,141 ,257 ,564 ,1435 ,876 ,1096 ,636 ,61 ,1497 ,1010 ,485 ,284 , + 839 ,776 ,878 ,1719 ,1069 ,1302 ,893 ,2005 ,875 ,908 ,586 ,2001 ,186 ,1932 ,1765 ,721 , + 592 ,1046 ,1588 ,1670 ,1485 ,1141 ,34 ,1465 ,1156 ,1938 ,435 ,753 ,1418 ,277 ,391 ,1741 , + 1440 ,117 ,723 ,412 ,642 ,1717 ,131 ,37 ,345 ,112 ,1979 ,2034 ,1822 ,1536 ,1281 ,56 , + 1341 ,803 ,568 ,568 ,1370 ,1995 ,1063 ,892 ,273 ,895 ,1226 ,354 ,1726 ,1541 ,1607 ,615 , + 985 ,1499 ,1736 ,1838 ,702 ,1345 ,1657 ,511 ,1774 ,1787 ,945 ,1927 ,947 ,952 ,1418 ,916 , + 1239 ,1457 ,1021 ,341 ,284 ,882 ,474 ,1559 ,1923 ,273 ,1330 ,1406 ,1782 ,19 ,116 ,887 , + 1146 ,1307 ,983 ,1237 ,1407 ,1350 ,1960 ,1255 ,878 ,1979 ,1500 ,1939 ,1415 ,88 ,1702 ,1253 , + 1778 ,2 ,10 ,1279 ,999 ,1549 ,1049 ,373 ,1355 ,1200 ,1466 ,1009 ,75 ,2042 ,1725 ,916 , + 1636 ,1135 ,833 ,830 ,1758 ,2015 ,1275 ,1675 ,287 ,744 ,89 ,430 ,1724 ,1232 ,1692 ,535 , + 1485 ,1287 ,973 ,1815 ,314 ,2020 ,424 ,1085 ,982 ,1994 ,1563 ,1269 ,1769 ,1681 ,1082 ,1666 , + 1622 ,1039 ,1209 ,32 ,679 ,732 ,976 ,1462 ,805 ,402 ,1150 ,170 ,1529 ,2013 ,350 ,1175 , + 757 ,1124 ,1091 ,1369 ,1061 ,415 ,1217 ,1135 ,1360 ,1578 ,1205 ,1785 ,1835 ,1241 ,14 ,716 , + 480 ,716 ,681 ,1686 ,1624 ,335 ,865 ,1356 ,1688 ,307 ,366 ,541 ,1262 ,1167 ,59 ,269 , + 1899 ,1798 ,1606 ,1307 ,1549 ,1814 ,114 ,483 ,958 ,1919 ,1179 ,898 ,834 ,1526 ,386 ,447 , + 1481 ,201 ,779 ,419 ,430 ,1451 ,1000 ,156 ,1062 ,615 ,1353 ,414 ,1214 ,1487 ,882 ,32 , + 840 ,1517 ,334 ,1143 ,823 ,454 ,725 ,1298 ,1325 ,649 ,1737 ,913 ,685 ,761 ,2010 ,63 , + 1397 ,1299 ,765 ,1158 ,1809 ,1299 ,1585 ,1776 ,625 ,1539 ,830 ,1563 ,461 ,308 ,1438 ,321 , + 82 ,886 ,1836 ,325 ,1976 ,761 ,359 ,1136 ,1720 ,2036 ,904 ,719 ,526 ,1567 ,145 ,1860 , + 1565 ,1786 ,1400 ,1696 ,232 ,1736 ,512 ,518 ,1895 ,1854 ,1584 ,1393 ,1869 ,1702 ,789 ,1986 , + 116 ,521 ,150 ,1597 ,727 ,1916 ,815 ,1826 ,1382 ,653 ,1596 ,286 ,1373 ,177 ,1397 ,1009 , + 1449 ,353 ,877 ,93 ,266 ,1853 ,1255 ,872 ,1974 ,556 ,1885 ,857 ,992 ,5 ,1921 ,1849 , + 1038 ,1912 ,464 ,795 ,747 ,56 ,124 ,431 ,1868 ,609 ,855 ,1522 ,912 ,1709 ,1507 ,1062 , + 1015 ,1357 ,1487 ,4 ,253 ,1871 ,933 ,215 ,1228 ,633 ,1306 ,2024 ,1453 ,900 ,457 ,471 , + 436 ,1311 ,870 ,1032 ,134 ,984 ,1983 ,1103 ,1627 ,1627 ,414 ,1845 ,583 ,1699 ,1458 ,2018 , + 150 ,450 ,1114 ,369 ,267 ,1273 ,1136 ,1578 ,1063 
,1820 ,120 ,779 ,652 ,1266 ,1929 ,1213 , + 159 ,297 ,1703 ,819 ,93 ,247 ,1366 ,144 ,1617 ,1428 ,812 ,121 ,1637 ,1620 ,289 ,1557 , + 1414 ,971 ,476 ,1685 ,428 ,1802 ,653 ,1290 ,614 ,1663 ,1528 ,1344 ,798 ,1027 ,1305 ,990 , + 1740 ,1154 ,1839 ,912 ,731 ,602 ,1064 ,1508 ,834 ,1387 ,252 ,745 ,1034 ,1102 ,965 ,696 , + 1971 ,1729 ,666 ,282 ,1993 ,1551 ,1703 ,1124 ,1628 ,1725 ,107 ,808 ,1096 ,1753 ,500 ,677 , + }; + } else { + std::ifstream fin(codes_path); + if (!fin) { + fprintf(stderr, "Error: cannot open codes file: %s\n", codes_path); + return 1; + } + std::string line; + while (std::getline(fin, line)) { + // Skip empty lines + if (line.empty()) continue; + try { + int code = std::stoi(line); + codes.push_back(code); + } catch (const std::exception& e) { + fprintf(stderr, "Error parsing code: %s\n", line.c_str()); + return 1; + } + } + if (codes.empty()) { + fprintf(stderr, "Error: no codes found in file: %s\n", codes_path); + return 1; + } + + printf("Loaded %d codes from %s\n", (int)codes.size(), codes_path); + } + + mimi_model model(model_path, true); + std::vector wav_data = model.decode(codes); + + // print first 20 values + printf("Number of output samples: %d\n", (int)wav_data.size()); + printf("First 20 samples:\n"); + for (int i = 0; i < 20; i++) { + printf("%2.4f, ", wav_data[i]); + } + printf("...\n"); + + // write to wav + printf("Writing to %s\n", out_path); + save_wav16(out_path, wav_data, model.get_sample_rate()); +} diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 4cc42e1674ccc..b3461b5d273ef 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -71,46 +71,6 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -struct wav_header { - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t chunk_size; - char wave[4] = {'W', 'A', 'V', 'E'}; - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_chunk_size = 16; - uint16_t audio_format = 1; // PCM - uint16_t num_channels = 1; // Mono - uint32_t sample_rate; - uint32_t byte_rate; - uint16_t block_align; - uint16_t bits_per_sample = 16; - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size; -}; - -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { - std::ofstream file(fname, std::ios::binary); - if (!file) { - LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); - return false; - } - - wav_header header; - header.sample_rate = sample_rate; - header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); - header.block_align = header.num_channels * (header.bits_per_sample / 8); - header.data_size = data.size() * (header.bits_per_sample / 8); - header.chunk_size = 36 + header.data_size; - - file.write(reinterpret_cast(&header), sizeof(header)); - - for (const auto & sample : data) { - int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); - file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); - } - - return file.good(); -} - static void fill_hann_window(int length, bool periodic, float * output) { int offset = -1; if (periodic) {