diff --git a/mlx_audio/tts/generate.py b/mlx_audio/tts/generate.py
index fc1d2a0a..80b10a08 100644
--- a/mlx_audio/tts/generate.py
+++ b/mlx_audio/tts/generate.py
@@ -35,6 +35,12 @@ def parse_args():
         "--join_audio", action="store_true", help="Join all audio files into one"
     )
     parser.add_argument("--play", action="store_true", help="Play the output audio")
+    parser.add_argument(
+        "--ref_audio", type=str, default=None, help="Path to reference audio"
+    )
+    parser.add_argument(
+        "--ref_text", type=str, default=None, help="Caption for reference audio"
+    )
     args = parser.parse_args()
 
     if args.text is None:
@@ -50,6 +56,29 @@ def parse_args():
 def main():
     args = parse_args()
     try:
+        # load reference audio for voice matching if specified
+
+        ref_audio = None
+        ref_text = None
+
+        if args.ref_audio:
+            if not os.path.exists(args.ref_audio):
+                raise FileNotFoundError(
+                    f"Reference audio file not found: {args.ref_audio}"
+                )
+            if not args.ref_text:
+                raise ValueError(
+                    "Reference text is required when using reference audio."
+                )
+
+            ref_audio, ref_sr = sf.read(args.ref_audio)
+            if ref_sr != 24000:
+                raise ValueError(
+                    f"Reference audio sample rate must be 24000 Hz, but got {ref_sr} Hz."
+                )
+            ref_audio = mx.array(ref_audio, dtype=mx.float32)
+            ref_text = args.ref_text
+
         player = AudioPlayer() if args.play else None
 
         model = load_model(model_path=args.model)
@@ -66,6 +95,8 @@ def main():
                 voice=args.voice,
                 speed=args.speed,
                 lang_code=args.lang_code,
+                ref_audio=ref_audio,
+                ref_text=ref_text,
                 verbose=True,
             )
             print(
diff --git a/mlx_audio/tts/models/kokoro/kokoro.py b/mlx_audio/tts/models/kokoro/kokoro.py
index 5a29100c..29f95dd2 100644
--- a/mlx_audio/tts/models/kokoro/kokoro.py
+++ b/mlx_audio/tts/models/kokoro/kokoro.py
@@ -249,6 +249,7 @@ def generate(
         lang_code: str = "af",
         split_pattern: str = r"\n+",
         verbose: bool = False,
+        **kwargs,
     ):
         pipeline = KokoroPipeline(
             model=self,
diff --git a/mlx_audio/tts/models/sesame/__init__.py b/mlx_audio/tts/models/sesame/__init__.py
new file mode 100644
index 00000000..b7ac7f94
--- /dev/null
+++ b/mlx_audio/tts/models/sesame/__init__.py
@@ -0,0 +1,3 @@
+from .model import Model
+
+__all__ = ["Model"]
diff --git a/mlx_audio/tts/models/sesame/attention.py b/mlx_audio/tts/models/sesame/attention.py
new file mode 100644
index 00000000..14b96895
--- /dev/null
+++ b/mlx_audio/tts/models/sesame/attention.py
@@ -0,0 +1,213 @@
+import math
+from typing import Any, Optional
+
+import mlx.core as mx
+from mlx import nn
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.llama import ModelArgs
+
+
+class Llama3ScaledRoPE(nn.Module):
+    """Rotary position embedding with Llama-3 style frequency scaling.
+
+    The cos/sin table is precomputed for ``max_seq_len`` positions at
+    construction time and sliced by ``offset`` on every call.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        max_seq_len: int = 2048,
+        base: float = 500_000.0,
+        scale_factor: float = 32.0,
+        low_freq_factor: int = 1,
+        high_freq_factor: int = 4,
+        old_context_len: int = 8192,
+    ) -> None:
+        super().__init__()
+        self.dim = dim
+        self.base = base
+        self.max_seq_len = max_seq_len
+
+        self.scale_factor = scale_factor
+        self.low_freq_factor = low_freq_factor
+        self.high_freq_factor = high_freq_factor
+        self.old_context_len = old_context_len
+        self.is_cache_built = False
+        self.rope_init()
+
+    def rope_init(self):
+        """Compute the scaled frequencies and build the cos/sin cache."""
+        freqs = 1.0 / (
+            self.base
+            ** (
+                mx.arange(0, self.dim, 2)[: (self.dim // 2)].astype(mx.float32)
+                / self.dim
+            )
+        )
+
+        theta = self.apply_scaling(
+            freqs,
+            self.scale_factor,
+            self.low_freq_factor,
+            self.high_freq_factor,
+            self.old_context_len,
+        )
+        self._theta = theta
+        self.build_rope_cache(self.max_seq_len)
+        self.is_cache_built = True
+
+    def build_rope_cache(self, max_seq_len: int = 4096) -> None:
+        """Precompute the ``(max_seq_len, dim // 2, 2)`` cos/sin table."""
+        seq_idx = mx.arange(max_seq_len, dtype=self._theta.dtype)
+        idx_theta = mx.einsum("i, j -> ij", seq_idx, self._theta).astype(mx.float32)
+        cache = mx.stack([mx.cos(idx_theta), mx.sin(idx_theta)], axis=-1)
+        self._cache = cache
+
+    def apply_scaling(
+        self,
+        freqs: mx.array,
+        scale_factor: float,
+        low_freq_factor: int,
+        high_freq_factor: int,
+        old_context_len: int,
+    ):
+        """Rescale ``freqs`` one wavelength band at a time.
+
+        Short wavelengths are kept, long wavelengths are divided by
+        ``scale_factor``, and the band in between is interpolated.
+        """
+        low_freq_wavelen = old_context_len / low_freq_factor
+        high_freq_wavelen = old_context_len / high_freq_factor
+        new_freqs = []
+        for freq in freqs:
+            wavelen = 2 * math.pi / freq
+            if wavelen < high_freq_wavelen:
+                new_freqs.append(freq)
+            elif wavelen > low_freq_wavelen:
+                new_freqs.append(freq / scale_factor)
+            else:
+                assert low_freq_wavelen != high_freq_wavelen
+                smooth = (old_context_len / wavelen - low_freq_factor) / (
+                    high_freq_factor - low_freq_factor
+                )
+                new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
+        return mx.array(new_freqs, dtype=freqs.dtype)
+
+    def __call__(self, x: mx.array, *, offset: int) -> mx.array:
+        """Rotate ``x`` with the cached table, starting at ``offset``.
+
+        ``x`` is indexed as ``(batch, seq, heads, head_dim)``.
+        """
+        if not self.is_cache_built:
+            raise RuntimeError(
+                "RoPE cache is not built. Please call rope_init() first."
+            )
+
+        seq_len = x.shape[1]
+        rope_cache = (
+            self._cache[:seq_len]
+            if offset is None
+            else self._cache[None, offset : offset + seq_len]
+        )
+        xshaped = x.astype(mx.float32).reshape(*x.shape[:-1], -1, 2)
+        rope_cache = rope_cache.reshape(-1, xshaped.shape[1], 1, xshaped.shape[3], 2)
+
+        x_out = mx.stack(
+            [
+                xshaped[..., 0] * rope_cache[..., 0]
+                - xshaped[..., 1] * rope_cache[..., 1],
+                xshaped[..., 1] * rope_cache[..., 0]
+                + xshaped[..., 0] * rope_cache[..., 1],
+            ],
+            -1,
+        )
+
+        x_out = x_out.flatten(3)
+        return x_out.astype(x.dtype)
+
+
+class Attention(nn.Module):
+    """Multi-head attention with grouped key/value heads and ``Llama3ScaledRoPE``."""
+
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        dim = args.hidden_size
+        self.n_heads = n_heads = args.num_attention_heads
+        self.n_kv_heads = n_kv_heads = args.num_key_value_heads or n_heads
+
+        self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
+
+        self.scale = head_dim**-0.5
+        attention_bias = getattr(args, "attention_bias", False)
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        # Read every scaling knob from the config. The fallbacks equal the
+        # Llama3ScaledRoPE defaults, so a config that only sets "factor"
+        # behaves as before.
+        rope_scaling = args.rope_scaling or {}
+        self.rope = Llama3ScaledRoPE(
+            self.head_dim,
+            base=args.rope_theta,
+            scale_factor=rope_scaling.get("factor", 1.0),
+            low_freq_factor=rope_scaling.get("low_freq_factor", 1),
+            high_freq_factor=rope_scaling.get("high_freq_factor", 4),
+            old_context_len=rope_scaling.get(
+                "original_max_position_embeddings", 8192
+            ),
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        """Self-attend over ``x`` of shape ``(batch, seq, hidden)``.
+
+        When ``cache`` is given, its ``offset`` positions the rotary
+        embedding and keys/values go through ``cache.update_and_fetch``.
+        """
+        b, s_x, _ = x.shape
+        # Compare with None: an "empty" cache object must still be used.
+        offset = cache.offset if cache is not None else 0
+        q_per_kv = self.n_heads // self.n_kv_heads
+
+        q = self.q_proj(x)
+        q = q.reshape(b, s_x, self.n_kv_heads * q_per_kv, self.head_dim)
+        k = self.k_proj(x).reshape(b, s_x, -1, self.head_dim)
+        v = self.v_proj(x).reshape(b, s_x, -1, self.head_dim)
+
+        # Rotate while the sequence is still on axis 1, then move heads there.
+        q = self.rope(q, offset=offset).swapaxes(1, 2)
+        k = self.rope(k, offset=offset).swapaxes(1, 2)
+        v = v.swapaxes(1, 2)
+
+        if cache is not None:
+            k, v = cache.update_and_fetch(k, v)
+
+        if self.n_heads != self.n_kv_heads:
+            # Repeat each key/value head q_per_kv times to match the query heads.
+            k = mx.expand_dims(k, axis=2)
+            v = mx.expand_dims(v, axis=2)
+
+            k_expand_shape = (b, self.n_kv_heads, q_per_kv) + k.shape[3:]
+            v_expand_shape = (b, self.n_kv_heads, q_per_kv) + v.shape[3:]
+
+            k = mx.broadcast_to(k, k_expand_shape)
+            v = mx.broadcast_to(v, v_expand_shape)
+
+            k = k.reshape(b, self.n_kv_heads * q_per_kv, *k.shape[3:])
+            v = v.reshape(b, self.n_kv_heads * q_per_kv, *v.shape[3:])
+
+        output = scaled_dot_product_attention(
+            q, k, v, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.swapaxes(1, 2).reshape(b, s_x, -1)
+        return self.o_proj(output)
diff --git a/mlx_audio/tts/models/sesame/model.py b/mlx_audio/tts/models/sesame/model.py
new file mode 100644
index 00000000..1daa2df9
--- /dev/null
+++ b/mlx_audio/tts/models/sesame/model.py
@@ -0,0 +1,553 @@
+from __future__ import annotations
+
+import time
+from dataclasses import dataclass
+from typing import Callable, Dict, List, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.llama import LlamaModel
+from mlx_lm.models.llama import ModelArgs as LlamaModelArgs
+from mlx_lm.sample_utils import make_sampler
+from tokenizers.processors import TemplateProcessing
+from transformers import AutoTokenizer
+
+from mlx_audio.codec import Mimi
+
+from ..base import GenerationResult
+from .attention import Attention
+
+try:
+    from .watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
+except ImportError:
+    # Keep the names bound so later references cannot raise NameError.
+    CSM_1B_GH_WATERMARK = load_watermarker = watermark = None
+    print(
+        "Watermarking module not found. Please install silentcipher to use watermarking."
+    )
+
+MIMI_REPO = "kyutai/moshiko-pytorch-bf16"
+TOKENIZER_REPO = "unsloth/Llama-3.2-1B"
+
+# Per-flavor sizes; every other Llama setting is shared between flavors.
+_LLAMA_FLAVORS = {
+    "llama-1B": {
+        "num_hidden_layers": 16,
+        "num_attention_heads": 32,
+        "num_key_value_heads": 8,
+        "head_dim": 64,
+        "hidden_size": 2048,
+    },
+    "llama-100M": {
+        "num_hidden_layers": 4,
+        "num_attention_heads": 8,
+        "num_key_value_heads": 2,
+        "head_dim": 128,
+        "hidden_size": 1024,
+    },
+}
+
+
+def create_causal_mask(seq_len: int) -> mx.array:
+    """Return a ``(seq_len, seq_len)`` lower-triangular boolean mask."""
+    return mx.tril(mx.ones((seq_len, seq_len), dtype=mx.bool_))
+
+
+def index_causal_mask(mask: mx.array, input_pos: mx.array) -> mx.array:
+    """Pick the mask rows at ``input_pos`` and add a head axis.
+
+    Each row is truncated to its first ``input_pos.shape[1]`` columns.
+    """
+    mask_indexed = mx.take(mask, input_pos, axis=0)
+
+    seq_len = input_pos.shape[1]
+    mask_indexed = mask_indexed[:, :, :seq_len]
+
+    # reshape to (batch_size, 1, seq_len, seq_len) for broadcasting across heads
+    return mx.expand_dims(mask_indexed, axis=1)
+
+
+@dataclass
+class SesameModelArgs:
+    """Configuration fields consumed by ``SesameModel``."""
+
+    model_type: str
+    backbone_flavor: str
+    decoder_flavor: str
+    text_vocab_size: int
+    audio_vocab_size: int
+    audio_num_codebooks: int
+
+
+def create_llama_model_args(flavor: str) -> LlamaModelArgs:
+    """Build the Llama ``ModelArgs`` for ``flavor``.
+
+    Raises:
+        ValueError: If ``flavor`` is not a key of ``_LLAMA_FLAVORS``.
+    """
+    try:
+        sizes = _LLAMA_FLAVORS[flavor]
+    except KeyError:
+        raise ValueError(f"Unknown flavor: {flavor}") from None
+
+    return LlamaModelArgs(
+        model_type="llama",
+        **sizes,
+        intermediate_size=8192,
+        rms_norm_eps=1e-5,
+        vocab_size=128_256,
+        max_position_embeddings=2048,
+        attention_bias=False,
+        mlp_bias=False,
+        rope_theta=500_000,
+        # Built per call, so callers never share one mutable dict.
+        rope_scaling={
+            "factor": 32.0,
+            "low_freq_factor": 1.0,
+            "high_freq_factor": 4.0,
+            "original_max_position_embeddings": 8192,
+            "rope_type": "llama3",
+        },
+    )
+
+
+class SesameModel(nn.Module):
+    """Backbone and decoder Llama stacks that sample one audio frame per step."""
+
+    def __init__(self, config):
+        super().__init__()
+        args = SesameModelArgs(**config)
+        self.args = args
+
+        backbone_args = create_llama_model_args(args.backbone_flavor)
+        decoder_args = create_llama_model_args(args.decoder_flavor)
+
+        self.backbone = LlamaModel(backbone_args)
+        self.decoder = LlamaModel(decoder_args)
+
+        backbone_dim = backbone_args.hidden_size
+        decoder_dim = decoder_args.hidden_size
+
+        # Embedding happens in this class, so the Llama stacks take
+        # embeddings directly.
+        self.backbone.embed_tokens = nn.Identity()
+        self.decoder.embed_tokens = nn.Identity()
+
+        for layer in self.backbone.layers:
+            layer.self_attn = Attention(backbone_args)
+        for layer in self.decoder.layers:
+            layer.self_attn = Attention(decoder_args)
+
+        self.text_embeddings = nn.Embedding(args.text_vocab_size, backbone_dim)
+        self.audio_embeddings = nn.Embedding(
+            args.audio_vocab_size * args.audio_num_codebooks, backbone_dim
+        )
+
+        self.projection = nn.Linear(backbone_dim, decoder_dim, bias=False)
+        self.codebook0_head = nn.Linear(backbone_dim, args.audio_vocab_size, bias=False)
+        self.audio_head = mx.zeros(
+            (args.audio_num_codebooks - 1, decoder_dim, args.audio_vocab_size)
+        )
+
+        self._backbone_causal_mask = None
+        self._decoder_causal_mask = None
+
+        self.backbone_cache = None
+        self.decoder_cache = None
+        self.caches_enabled = False
+
+    def setup_caches(self, max_batch_size: int):
+        """Create the causal masks and fresh prompt caches.
+
+        ``max_batch_size`` is accepted but not used here.
+        """
+        backbone_args = create_llama_model_args(self.args.backbone_flavor)
+
+        self._backbone_causal_mask = create_causal_mask(
+            backbone_args.max_position_embeddings
+        )
+        self._decoder_causal_mask = create_causal_mask(self.args.audio_num_codebooks)
+
+        self.backbone_cache = make_prompt_cache(self.backbone)
+        self.decoder_cache = make_prompt_cache(self.decoder)
+        self.caches_enabled = True
+
+    def caches_are_enabled(self):
+        """Return whether ``setup_caches`` has run."""
+        return self.caches_enabled
+
+    def reset_caches(self):
+        """Replace any existing prompt cache with a fresh one."""
+        if self.backbone_cache is not None:
+            self.backbone_cache = make_prompt_cache(self.backbone)
+
+        if self.decoder_cache is not None:
+            self.decoder_cache = make_prompt_cache(self.decoder)
+
+    def generate_frame(
+        self,
+        tokens: mx.array,
+        tokens_mask: mx.array,
+        input_pos: mx.array,
+        sampler: Callable[..., mx.array],
+    ) -> mx.array:
+        """Sample one frame of audio codes.
+
+        Codebook 0 is sampled from the backbone output; the remaining
+        codebooks are sampled one at a time from the decoder.
+
+        Returns:
+            The sampled codes concatenated along axis 1, one per codebook.
+        """
+        assert self.caches_are_enabled(), "backbone caches are not enabled"
+
+        curr_backbone_mask = index_causal_mask(self._backbone_causal_mask, input_pos)
+        embeds = self._embed_tokens(tokens)
+        masked_embeds = embeds * mx.expand_dims(tokens_mask, -1)
+        h = mx.sum(masked_embeds, axis=2)
+        h = self.backbone(h, mask=curr_backbone_mask, cache=self.backbone_cache)
+
+        last_h = h[:, -1, :]
+        c0_logits = self.codebook0_head(last_h)
+        c0_sample = mx.expand_dims(sampler(c0_logits), axis=-1)
+        c0_embed = self._embed_audio(0, c0_sample)
+
+        curr_h = mx.concat([mx.expand_dims(last_h, 1), c0_embed], axis=1)
+        curr_sample = c0_sample
+        curr_pos = mx.arange(curr_h.shape[1], dtype=mx.int32)
+        curr_pos = mx.expand_dims(curr_pos, 0)
+        curr_pos = mx.broadcast_to(curr_pos, (curr_h.shape[0], curr_h.shape[1]))
+
+        # reset decoder cache for new frame
+
+        self.decoder_cache = make_prompt_cache(self.decoder)
+
+        for i in range(1, self.args.audio_num_codebooks):
+            curr_decoder_mask = index_causal_mask(self._decoder_causal_mask, curr_pos)
+            decoder_h = self.decoder(
+                self.projection(curr_h),
+                mask=curr_decoder_mask,
+                cache=self.decoder_cache,
+            )
+
+            ci_logits = mx.matmul(decoder_h[:, -1, :], self.audio_head[i - 1])
+            ci_sample = mx.expand_dims(sampler(ci_logits), axis=-1)
+            ci_embed = self._embed_audio(i, ci_sample)
+
+            curr_h = ci_embed
+            curr_sample = mx.concat([curr_sample, ci_sample], axis=1)
+            curr_pos = curr_pos[:, -1:] + 1
+
+        return curr_sample
+
+    def _embed_audio(self, codebook: int, tokens: mx.array) -> mx.array:
+        """Embed ``tokens`` of one codebook via its slice of the shared table."""
+        return self.audio_embeddings(tokens + codebook * self.args.audio_vocab_size)
+
+    def _embed_tokens(self, tokens: mx.array) -> mx.array:
+        """Embed frames whose last column is text and the others audio codes."""
+        text_embeds = self.text_embeddings(tokens[:, :, -1])
+        text_embeds = mx.expand_dims(text_embeds, axis=-2)
+
+        codebook_indices = mx.arange(self.args.audio_num_codebooks, dtype=mx.int32)
+        codebook_offsets = codebook_indices * self.args.audio_vocab_size
+
+        audio_tokens = tokens[:, :, :-1] + mx.reshape(codebook_offsets, (1, 1, -1))
+        audio_embeds_flat = self.audio_embeddings(audio_tokens.flatten())
+
+        audio_embeds = mx.reshape(
+            audio_embeds_flat,
+            (tokens.shape[0], tokens.shape[1], self.args.audio_num_codebooks, -1),
+        )
+
+        return mx.concat([audio_embeds, text_embeds], axis=-2)
+
+
+@dataclass
+class Segment:
+    """One utterance of context: who spoke, what was said, and its audio."""
+
+    speaker: int
+    text: str
+    # (num_samples,), sample_rate = 24_000
+    audio: mx.array
+
+
+def load_llama3_tokenizer(path_or_hf_repo: str):
+    """Load a tokenizer and make it wrap every sequence in BOS ... EOS."""
+    tokenizer = AutoTokenizer.from_pretrained(path_or_hf_repo)
+    bos = tokenizer.bos_token
+    eos = tokenizer.eos_token
+    tokenizer._tokenizer.post_processor = TemplateProcessing(
+        single=f"{bos}:0 $A:0 {eos}:0",
+        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+        special_tokens=[
+            (f"{bos}", tokenizer.bos_token_id),
+            (f"{eos}", tokenizer.eos_token_id),
+        ],
+    )
+    return tokenizer
+
+
+class Model(nn.Module):
+    """Text-to-speech wrapper: tokenizes text/audio, samples frames, decodes audio."""
+
+    def __init__(
+        self,
+        config: Dict,
+    ):
+        # nn.Module.__init__ sets up the module's bookkeeping state.
+        super().__init__()
+        self.model = SesameModel(config)
+        self.model.setup_caches(1)
+
+        self._text_tokenizer = load_llama3_tokenizer(TOKENIZER_REPO)
+        mimi = Mimi.from_pretrained(MIMI_REPO)
+        self._audio_tokenizer = mimi
+
+        # Watermarking is best effort: any failure disables it.
+        try:
+            self._watermarker = load_watermarker()
+        except Exception:
+            self._watermarker = None
+
+        self.sample_rate = mimi.cfg.sample_rate
+
+    def _tokenize_text_segment(
+        self, text: str, speaker: int
+    ) -> Tuple[mx.array, mx.array]:
+        """Tokenize ``[speaker]text`` into frames with only the text column set.
+
+        Returns:
+            (seq_len, 33), (seq_len, 33)
+        """
+        frame_tokens = []
+        frame_masks = []
+
+        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
+        text_frame = mx.zeros((len(text_tokens), 33)).astype(mx.int32)
+        text_frame_mask = mx.zeros((len(text_tokens), 33)).astype(mx.bool_)
+        text_frame[:, -1] = mx.array(text_tokens)
+        text_frame_mask[:, -1] = True
+
+        frame_tokens.append(text_frame)
+        frame_masks.append(text_frame_mask)
+
+        return mx.concat(frame_tokens, axis=0), mx.concat(frame_masks, axis=0)
+
+    def _tokenize_audio(self, audio: mx.array) -> Tuple[mx.array, mx.array]:
+        """Encode ``audio`` into frames with only the audio columns set.
+
+        An all-zero frame is appended as the end-of-sequence marker.
+        """
+        frame_tokens = []
+        frame_masks = []
+
+        # (K, T)
+        audio_tokens = self._audio_tokenizer.encode(
+            mx.expand_dims(mx.expand_dims(audio, 0), 0)
+        )[0]
+
+        # add EOS frame (same dtype as the codes, so concat does not promote)
+        eos_frame = mx.zeros((audio_tokens.shape[0], 1), dtype=audio_tokens.dtype)
+        audio_tokens = mx.concat([audio_tokens, eos_frame], axis=1)
+
+        audio_frame = mx.zeros((audio_tokens.shape[1], 33)).astype(mx.int32)
+        audio_frame_mask = mx.zeros((audio_tokens.shape[1], 33)).astype(mx.bool_)
+        audio_frame[:, :-1] = audio_tokens.swapaxes(0, 1)
+        audio_frame_mask[:, :-1] = True
+
+        frame_tokens.append(audio_frame)
+        frame_masks.append(audio_frame_mask)
+
+        return mx.concat(frame_tokens, axis=0), mx.concat(frame_masks, axis=0)
+
+    def _tokenize_segment(self, segment: Segment) -> Tuple[mx.array, mx.array]:
+        """
+        Returns:
+            (seq_len, 33), (seq_len, 33)
+        """
+        text_tokens, text_masks = self._tokenize_text_segment(
+            segment.text, segment.speaker
+        )
+        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
+
+        return mx.concat([text_tokens, audio_tokens], axis=0), mx.concat(
+            [text_masks, audio_masks], axis=0
+        )
+
+    def sanitize(self, weights):
+        """Return ``weights`` unchanged."""
+        return weights
+
+    def load_weights(self, weights):
+        """Load ``weights`` into the wrapped ``SesameModel``."""
+        self.model.load_weights(weights)
+
+    def generate(
+        self,
+        text: str,
+        speaker: int = 0,
+        context: Optional[List[Segment]] = None,
+        max_audio_length_ms: float = 90_000,
+        sampler: Callable[..., mx.array] = None,
+        ref_audio: mx.array = None,
+        ref_text: str = None,
+        **kwargs,
+    ):
+        """Synthesize speech for ``text``.
+
+        Args:
+            text: Text to speak.
+            speaker: Speaker id written into the prompt as ``[speaker]``.
+            context: Earlier segments to condition on; ``None`` means none.
+            max_audio_length_ms: Upper bound on output length (80 ms per frame).
+            sampler: Maps logits to token ids; defaults to temp 0.9, top-k 50.
+            ref_audio: Reference audio, used only when ``context`` is empty.
+            ref_text: Transcript of ``ref_audio``; both must be given.
+            **kwargs: Ignored.
+
+        Returns:
+            A one-element list holding the ``GenerationResult``.
+
+        Raises:
+            ValueError: If the prompt is too long or no frame was generated.
+        """
+        self.model.reset_caches()
+
+        if context is None:
+            context = []
+
+        # if reference audio is provided, use it as the first segment
+
+        if len(context) == 0 and ref_audio is not None and ref_text is not None:
+            context = [Segment(speaker=speaker, text=ref_text, audio=ref_audio)]
+
+        start_time = time.time()
+
+        sampler = sampler or make_sampler(temp=0.9, top_k=50)
+        max_audio_frames = int(max_audio_length_ms / 80)
+
+        tokens, tokens_mask = [], []
+        for segment in context:
+            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
+            tokens.append(segment_tokens)
+            tokens_mask.append(segment_tokens_mask)
+
+        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(
+            text, speaker
+        )
+        tokens.append(gen_segment_tokens)
+        tokens_mask.append(gen_segment_tokens_mask)
+
+        prompt_tokens = mx.concat(tokens, axis=0).astype(mx.int32)
+        prompt_tokens_mask = mx.concat(tokens_mask, axis=0).astype(mx.bool_)
+
+        samples = []
+        curr_tokens = mx.expand_dims(prompt_tokens, axis=0)
+        curr_tokens_mask = mx.expand_dims(prompt_tokens_mask, axis=0)
+        curr_pos = mx.expand_dims(mx.arange(0, prompt_tokens.shape[0]), axis=0).astype(
+            mx.int32
+        )
+
+        max_seq_len = 2048 - max_audio_frames
+        if curr_tokens.shape[1] >= max_seq_len:
+            raise ValueError(
+                f"Inputs too long, must be below max_seq_len - max_audio_frames: {max_seq_len}"
+            )
+
+        for _ in range(max_audio_frames):
+            sample = self.model.generate_frame(
+                curr_tokens, curr_tokens_mask, curr_pos, sampler
+            )
+            if mx.all(sample == 0):
+                break  # eos
+
+            samples.append(sample)
+
+            curr_tokens = mx.expand_dims(
+                mx.concat([sample, mx.zeros((1, 1)).astype(mx.int32)], axis=1), axis=1
+            )
+            curr_tokens_mask = mx.expand_dims(
+                mx.concat(
+                    [
+                        mx.ones_like(sample).astype(mx.bool_),
+                        mx.zeros((1, 1)).astype(mx.bool_),
+                    ],
+                    axis=1,
+                ),
+                axis=1,
+            )
+            curr_pos = curr_pos[:, -1:] + 1
+
+        # EOS on the very first frame leaves nothing to stack or decode.
+        if not samples:
+            raise ValueError("No audio generated")
+
+        transposed = mx.transpose(mx.stack(samples), axes=[1, 2, 0])
+        audio = self._audio_tokenizer.decode(transposed).squeeze(0).squeeze(0)
+
+        # This applies an imperceptible watermark to identify audio as AI-generated.
+        # Watermarking ensures transparency, dissuades misuse, and enables traceability.
+        # Please be a responsible AI citizen and keep the watermarking in place.
+        # If using CSM 1B in another application, use your own private key and keep it secret.
+        if self._watermarker is not None:
+            audio = watermark(
+                self._watermarker,
+                audio,
+                self.sample_rate,
+                CSM_1B_GH_WATERMARK,
+            )
+            audio = mx.array(audio, dtype=mx.float32)
+
+        mx.eval(audio)
+
+        segment_time = time.time() - start_time
+
+        samples = audio.shape[0] if audio is not None else 0
+        assert samples > 0, "No audio generated"
+
+        # Calculate token count
+        # NOTE(review): this is the width of the last frame, not the number
+        # of generated tokens -- confirm which one the stats should report.
+        token_count = curr_tokens.shape[2]
+
+        # Calculate audio duration in seconds, at the codec's own sample rate
+        audio_duration_seconds = samples / self.sample_rate
+
+        # Calculate real-time factor (RTF)
+        rtf = segment_time / audio_duration_seconds if audio_duration_seconds > 0 else 0
+
+        # Format duration as HH:MM:SS.mmm
+        duration_hours = int(audio_duration_seconds // 3600)
+        # Minutes within the hour; the hours are reported separately.
+        duration_mins = int((audio_duration_seconds % 3600) // 60)
+        duration_secs = int(audio_duration_seconds % 60)
+        duration_ms = int((audio_duration_seconds % 1) * 1000)
+        duration_str = f"{duration_hours:02d}:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
+
+        return [
+            GenerationResult(
+                audio=audio,
+                samples=samples,
+                segment_idx=0,
+                token_count=token_count,
+                audio_duration=duration_str,
+                real_time_factor=round(rtf, 2),
+                prompt={
+                    "tokens": token_count,
+                    "tokens-per-sec": (
+                        round(token_count / segment_time, 2) if segment_time > 0 else 0
+                    ),
+                },
+                audio_samples={
+                    "samples": samples,
+                    "samples-per-sec": (
+                        round(samples / segment_time, 2) if segment_time > 0 else 0
+                    ),
+                },
+                processing_time_seconds=segment_time,
+                peak_memory_usage=mx.metal.get_peak_memory() / 1e9,
+            )
+        ]
diff --git a/mlx_audio/tts/models/sesame/watermarking.py b/mlx_audio/tts/models/sesame/watermarking.py
new file mode 100644
index 00000000..eec76981
--- /dev/null
+++ b/mlx_audio/tts/models/sesame/watermarking.py
@@ -0,0 +1,122 @@
+import argparse
+
+import mlx.core as mx
+import numpy as np
+import silentcipher
+import soundfile as sf
+from scipy import signal
+
+# This watermark key is public, it is not secure.
+# If using CSM 1B in another application, use a new private key and keep it secret.
+CSM_1B_GH_WATERMARK = [212, 211, 146, 56, 201]
+
+
+def cli_check_audio() -> None:
+    """Parse ``--audio_path`` and report whether that file is watermarked."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--audio_path", type=str, required=True)
+    args = parser.parse_args()
+    check_audio_from_file(args.audio_path)
+
+
+def load_watermarker() -> silentcipher.server.Model:
+    """Load the silentcipher ``44.1k`` model."""
+    model = silentcipher.get_model(
+        model_type="44.1k",
+    )
+    return model
+
+
+def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
+    """Resample ``audio`` from ``orig_sr`` to ``target_sr`` with a polyphase filter."""
+    gcd = np.gcd(orig_sr, target_sr)
+    up = target_sr // gcd
+    down = orig_sr // gcd
+    resampled = signal.resample_poly(audio, up, down, padtype="edge")
+    return resampled
+
+
+def watermark(
+    watermarker: silentcipher.server.Model,
+    audio_array: mx.array,
+    sample_rate: int,
+    watermark_key: list[int],
+) -> np.ndarray:
+    """Embed ``watermark_key`` into ``audio_array``.
+
+    The audio is resampled to 44.1 kHz for encoding and back to
+    ``sample_rate`` afterwards.
+
+    Returns:
+        The encoded audio as returned by ``encode_wav`` (resampled if needed).
+    """
+    audio_array = np.array(audio_array, dtype=np.float32)
+
+    if sample_rate != 44100:
+        audio_array_44khz = resample_audio(audio_array, sample_rate, 44100)
+    else:
+        audio_array_44khz = audio_array
+
+    encoded, *_ = watermarker.encode_wav(
+        audio_array_44khz, 44100, watermark_key, calc_sdr=False, message_sdr=36
+    )
+
+    if sample_rate != 44100:
+        encoded = resample_audio(encoded, 44100, sample_rate)
+
+    return encoded
+
+
+def verify(
+    watermarker: silentcipher.server.Model,
+    watermarked_audio: mx.array,
+    sample_rate: int,
+    watermark_key: list[int],
+) -> bool:
+    """Return True if the audio carries a watermark equal to ``watermark_key``."""
+    # Convert up front, as watermark() does, so the NumPy/SciPy code below
+    # never receives an mx.array.
+    watermarked_audio = np.array(watermarked_audio, dtype=np.float32)
+
+    if sample_rate != 44100:
+        watermarked_audio_44khz = resample_audio(watermarked_audio, sample_rate, 44100)
+    else:
+        watermarked_audio_44khz = watermarked_audio
+
+    result = watermarker.decode_wav(
+        watermarked_audio_44khz, 44100, phase_shift_decoding=True
+    )
+
+    is_watermarked = result["status"]
+    if is_watermarked:
+        is_csm_watermarked = result["messages"][0] == watermark_key
+    else:
+        is_csm_watermarked = False
+
+    return is_watermarked and is_csm_watermarked
+
+
+def check_audio_from_file(audio_path: str) -> None:
+    """Print whether the file at ``audio_path`` carries the CSM watermark."""
+    watermarker = load_watermarker()
+    audio_array, sample_rate = load_audio(audio_path)
+    is_watermarked = verify(watermarker, audio_array, sample_rate, CSM_1B_GH_WATERMARK)
+    outcome = "Watermarked" if is_watermarked else "Not watermarked"
+    print(f"{outcome}: {audio_path}")
+
+
+def load_audio(audio_path: str) -> tuple[mx.array, int]:
+    """Read an audio file and return ``(mono samples, sample rate)``."""
+    audio_array_np, sample_rate = sf.read(audio_path, always_2d=True)
+
+    # Averaging over the channel axis also covers mono input and, unlike
+    # squeeze(), keeps a one-sample file 1-D.
+    audio_array_np = audio_array_np.mean(axis=1)
+
+    audio_array = mx.array(audio_array_np)
+
+    return audio_array, int(sample_rate)
+
+
+if __name__ == "__main__":
+    cli_check_audio()
diff --git a/mlx_audio/tts/utils.py b/mlx_audio/tts/utils.py
index 0a35aa9e..81df5bc0 100644
--- a/mlx_audio/tts/utils.py
+++ b/mlx_audio/tts/utils.py
@@ -11,7 +11,7 @@ from mlx.utils import tree_flatten, tree_unflatten
 from mlx_lm.utils import get_model_path, load_config, make_shards
 
 
-MODEL_REMAPPING = {}
+MODEL_REMAPPING = {"mlx-community/csm-1b": "sesame"}
 MAX_FILE_SIZE_GB = 5
 
 