diff --git a/candle-transformers/examples/qwen_snac_tts_example.rs b/candle-transformers/examples/qwen_snac_tts_example.rs
new file mode 100644
index 0000000000..19a6dd65ef
--- /dev/null
+++ b/candle-transformers/examples/qwen_snac_tts_example.rs
@@ -0,0 +1,282 @@
+//! Example: Qwen + SNAC TTS Integration
+//!
+//! This example demonstrates how to build a text-to-speech system using:
+//! - the Qwen 0.5B language model for text-to-audio-token generation
+//! - the SNAC codec for decoding audio tokens into a waveform
+//!
+//! Usage:
+//! ```bash
+//! cargo run --example qwen_snac_tts_example -- \
+//!   --qwen-model-path ./models/qwen0.5b \
+//!   --snac-model-path ./models/snac_24khz \
+//!   --text "Hello, this is a test of SNAC-based text to speech synthesis."
+//! ```
+
+use anyhow::{Error as E, Result};
+use candle::{DType, Device, Tensor};
+use candle_nn::VarBuilder;
+use candle_transformers::models::{qwen2, snac, snac_tts_integration};
+use clap::{Arg, Command};
+use hf_hub::api::sync::Api;
+use tokenizers::Tokenizer;
+
+struct QwenSnacTts {
+    qwen_model: qwen2::Model,
+    tokenizer: Tokenizer,
+    snac_codec: snac_tts_integration::SnacTtsCodec,
+    device: Device,
+    max_seq_len: usize,
+}
+
+impl QwenSnacTts {
+    fn new(
+        qwen_model: qwen2::Model,
+        tokenizer: Tokenizer,
+        snac_codec: snac_tts_integration::SnacTtsCodec,
+        device: Device,
+    ) -> Self {
+        Self {
+            qwen_model,
+            tokenizer,
+            snac_codec,
+            device,
+            max_seq_len: 2048,
+        }
+    }
+
+    /// Generate speech from text input
+    fn synthesize(&mut self, text: &str, temperature: f64, top_p: f64) -> Result<Tensor> {
+        println!("Tokenizing input text...");
+        let tokens = self
+            .tokenizer
+            .encode(text, true)
+            .map_err(E::msg)?
+            .get_ids()
+            .to_vec();
+
+        let input_tokens = Tensor::new(tokens.as_slice(), &self.device)?.unsqueeze(0)?;
+
+        println!("Generating audio tokens with Qwen...");
+
+        // Generate audio tokens using the language model.
+        // In a real implementation, the model would be trained to output SNAC token sequences.
+        let audio_tokens = self.generate_audio_tokens(&input_tokens, temperature, top_p)?;
+
+        println!("Decoding tokens to audio with SNAC...");
+
+        // Convert the generated tokens to audio using SNAC.
+        let audio_waveform = self.snac_codec.tokens_to_audio(&audio_tokens)?;
+
+        Ok(audio_waveform)
+    }
+
+    /// Generate audio tokens from text tokens.
+    /// Note: a real implementation would use a model trained to map text to audio tokens.
+    fn generate_audio_tokens(
+        &mut self,
+        text_tokens: &Tensor,
+        _temperature: f64,
+        _top_p: f64,
+    ) -> Result<Tensor> {
+        // This is a placeholder implementation. In practice, you would:
+        // 1. Use a trained Qwen model that has learned to generate SNAC tokens.
+        // 2. Implement proper sampling with temperature and top-p.
+        // 3. Handle start/end tokens and padding appropriately.
+
+        let batch_size = text_tokens.dim(0)?;
+        let num_codebooks = self.snac_codec.num_codebooks();
+
+        // For demonstration, create dummy audio tokens.
+        // In reality, these would come from the trained model.
+        let seq_length = 100; // ~2 seconds of audio at 24 kHz (hop length 512)
+        let shape = (batch_size, num_codebooks, seq_length);
+
+        // Generate random tokens as a placeholder.
+        // A real implementation would use: self.qwen_model.forward(&text_tokens)?
+        let dummy_tokens =
+            Tensor::rand(0f32, 4096f32, shape, &self.device)?.to_dtype(DType::U32)?;
+
+        println!(
+            "Generated {} audio token sequences of length {}",
+            num_codebooks, seq_length
+        );
+
+        Ok(dummy_tokens)
+    }
+}
+
+/// Load the Qwen model and tokenizer for TTS
+fn load_qwen_model(model_path: &str, device: &Device) -> Result<(qwen2::Model, Tokenizer)> {
+    println!("Loading Qwen model from: {}", model_path);
+
+    // Load tokenizer
+    let api = Api::new()?;
+    let repo = api.model(model_path.to_string());
+    let tokenizer_filename = repo.get("tokenizer.json")?;
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+
+    // Load model config
+    let config_filename = repo.get("config.json")?;
+    let config: qwen2::Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+
+    // Load model weights
+    let filenames = candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?;
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F16, device)? };
+
+    // Create the model
+    let model = qwen2::Model::new(&config, vb)?;
+
+    Ok((model, tokenizer))
+}
+
+/// Load the SNAC codec model
+fn load_snac_codec(
+    model_path: &str,
+    device: &Device,
+) -> Result<snac_tts_integration::SnacTtsCodec> {
+    println!("Loading SNAC codec from: {}", model_path);
+
+    let api = Api::new()?;
+    let repo = api.model(model_path.to_string());
+
+    // Load config
+    let config_filename = repo.get("config.json")?;
+    let config: snac::Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
+
+    // Load model weights
+    let weights_filename = repo.get("pytorch_model.bin")?;
+    let vb = VarBuilder::from_pth(&weights_filename, DType::F32, device)?;
+
+    // Create the codec
+    let codec = snac_tts_integration::SnacTtsCodec::new(&config, vb)?;
+
+    Ok(codec)
+}
+
+/// Save an audio tensor to a WAV file
+fn save_audio_to_wav(audio: &Tensor, sample_rate: usize, filename: &str) -> Result<()> {
+    println!("Saving audio to: {}", filename);
+
+    // Convert the [1, 1, samples] tensor to a Vec<f32>
+    let audio_data = audio.squeeze(0)?.squeeze(0)?.to_vec1::<f32>()?;
+
+    // Create the WAV file
+    let spec = hound::WavSpec {
+        channels: 1,
+        sample_rate: sample_rate as u32,
+        bits_per_sample: 16,
+        sample_format: hound::SampleFormat::Int,
+    };
+
+    let mut writer = hound::WavWriter::create(filename, spec)?;
+
+    for sample in audio_data {
+        // Clamp before scaling so out-of-range floats don't wrap around.
+        let sample_i16 = (sample.clamp(-1.0, 1.0) * i16::MAX as f32) as i16;
+        writer.write_sample(sample_i16)?;
+    }
+
+    writer.finalize()?;
+    println!("Audio saved successfully!");
+
+    Ok(())
+}
+
+fn main() -> Result<()> {
+    let matches = Command::new("qwen-snac-tts")
+        .about("Generate speech using Qwen + SNAC TTS")
+        .arg(Arg::new("qwen-model-path")
+            .long("qwen-model-path")
+            .value_name("PATH")
+            .help("Path to Qwen model directory")
+            .required(true))
+        .arg(Arg::new("snac-model-path")
+            .long("snac-model-path")
+            .value_name("PATH")
+            .help("Path to SNAC model directory")
+            .required(true))
+        .arg(Arg::new("text")
+            .long("text")
+            .value_name("TEXT")
+            .help("Text to synthesize")
+            .required(true))
+        .arg(Arg::new("output")
+            .long("output")
+            .short('o')
+            .value_name("FILE")
+            .help("Output WAV file")
+            .default_value("output.wav"))
+        .arg(Arg::new("temperature")
+            .long("temperature")
+            .value_name("FLOAT")
+            .help("Generation temperature")
+            .default_value("0.7"))
+        .arg(Arg::new("top-p")
+            .long("top-p")
+            .value_name("FLOAT")
+            .help("Top-p sampling parameter")
+            .default_value("0.9"))
+        .arg(Arg::new("cpu")
+            .long("cpu")
+            .help("Use CPU instead of GPU")
+            .action(clap::ArgAction::SetTrue))
+        .get_matches();
+
+    let qwen_model_path = matches.get_one::<String>("qwen-model-path").unwrap();
+    let snac_model_path = matches.get_one::<String>("snac-model-path").unwrap();
+    let text = matches.get_one::<String>("text").unwrap();
+    let output_file = matches.get_one::<String>("output").unwrap();
+    let temperature: f64 = matches.get_one::<String>("temperature").unwrap().parse()?;
+    let top_p: f64 = matches.get_one::<String>("top-p").unwrap().parse()?;
+    let use_cpu = matches.get_flag("cpu");
+
+    // Set up the device
+    let device = if use_cpu {
+        Device::Cpu
+    } else {
+        Device::cuda_if_available(0)?
+    };
+
+    println!("Using device: {:?}", device);
+
+    // Load models
+    let (qwen_model, tokenizer) = load_qwen_model(qwen_model_path, &device)?;
+    let snac_codec = load_snac_codec(snac_model_path, &device)?;
+
+    // Create the TTS system
+    let mut tts_system = QwenSnacTts::new(qwen_model, tokenizer, snac_codec, device);
+
+    // Display codec information
+    let codec_info = tts_system.snac_codec.codec_info();
+    println!("SNAC Codec Info:");
+    println!("  Sample Rate: {} Hz", codec_info.sample_rate);
+    println!("  Codebooks: {}", codec_info.num_codebooks);
+    println!("  Compression Ratio: {}:1", codec_info.compression_ratio);
+
+    // Synthesize speech
+    println!("\nSynthesizing: \"{}\"", text);
+    let audio = tts_system.synthesize(text, temperature, top_p)?;
+
+    println!("Generated audio shape: {:?}", audio.shape());
+
+    // Save to file
+    save_audio_to_wav(&audio, codec_info.sample_rate, output_file)?;
+
+    // Display generation info (100 tokens matches the placeholder above)
+    let duration = tts_system.snac_codec.tokens_to_duration(100);
+    println!("Generated {:.2} seconds of audio", duration);
+    println!("Synthesis complete!");
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config_loading() {
+        // Test SNAC config creation
+        let config = snac::Config::default_tts();
+        assert_eq!(config.sampling_rate, 24000);
+        assert!(!config.encoder_rates.is_empty());
+    }
+
+    #[test]
+    fn test_memory_estimation() {
+        use candle_transformers::models::snac_tts_integration::utils::estimate_memory_usage;
+
+        let estimate = estimate_memory_usage(10.0, 24000, 4, 1);
+        assert!(estimate.audio_samples > 0);
+        assert!(estimate.token_count > 0);
+        assert!(estimate.estimated_bytes > 0);
+    }
+}
\ No newline at end of file
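
The example's `generate_audio_tokens` is a stated placeholder that returns random tokens. For reference, a minimal autoregressive sampling loop built on candle's `LogitsProcessor` might look like the sketch below. The `forward` closure (a stand-in for a trained checkpoint that emits SNAC-vocabulary logits), the single interleaved token stream, and the `eos_token` value are all assumptions for illustration, not APIs provided by this patch.

```rust
use candle::{Device, Result, Tensor};
use candle_transformers::generation::LogitsProcessor;

/// Sketch: sample audio tokens autoregressively. `forward(tokens, offset)` is
/// assumed to return last-position logits of shape [1, vocab] and to maintain
/// its own KV cache keyed by `offset`.
fn sample_audio_tokens(
    mut forward: impl FnMut(&Tensor, usize) -> Result<Tensor>,
    prompt: &Tensor, // [1, prompt_len] text tokens
    max_audio_tokens: usize,
    eos_token: u32,
    temperature: f64,
    top_p: f64,
    device: &Device,
) -> Result<Vec<u32>> {
    let mut processor = LogitsProcessor::new(42, Some(temperature), Some(top_p));
    let mut generated = Vec::with_capacity(max_audio_tokens);
    let mut input = prompt.clone();
    let mut offset = 0;
    for _ in 0..max_audio_tokens {
        let logits = forward(&input, offset)?; // [1, vocab]
        let next = processor.sample(&logits.squeeze(0)?)?;
        if next == eos_token {
            break;
        }
        generated.push(next);
        // After the first step, feed only the new token and rely on the KV cache.
        offset += input.dim(1)?;
        input = Tensor::new(&[next], device)?.unsqueeze(0)?;
    }
    Ok(generated)
}
```

The flat `Vec<u32>` would still need to be de-interleaved into the `[batch, codebooks, frames]` layout that `tokens_to_audio` expects; how codebooks are interleaved is a property of the trained model, so it is left out here.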
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index e54fea7144..e7c8ba7c45 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -111,6 +111,7 @@ pub mod segformer;
 pub mod segment_anything;
 pub mod siglip;
 pub mod snac;
+pub mod snac_tts_integration;
 pub mod stable_diffusion;
 pub mod stable_lm;
 pub mod starcoder2;
diff --git a/candle-transformers/src/models/snac.rs b/candle-transformers/src/models/snac.rs
index 65fcb97b41..eeed6c970c 100644
--- a/candle-transformers/src/models/snac.rs
+++ b/candle-transformers/src/models/snac.rs
@@ -27,6 +27,85 @@ pub struct Config {
     pub depthwise: bool,
 }
 
+impl Config {
+    /// Default configuration for 24 kHz speech.
+    pub fn default_24khz_speech() -> Self {
+        Self {
+            sampling_rate: 24000,
+            encoder_dim: 64,
+            encoder_rates: vec![2, 4, 8, 8],
+            decoder_dim: 1536,
+            decoder_rates: vec![8, 8, 4, 2],
+            attn_window_size: Some(32),
+            codebook_size: 4096,
+            codebook_dim: 8,
+            vq_strides: vec![8, 4, 2, 1],
+            noise: true,
+            depthwise: true,
+        }
+    }
+
+    /// 32 kHz general-audio configuration; identical to the speech preset
+    /// except for the sampling rate.
+    pub fn default_32khz_general() -> Self {
+        Self {
+            sampling_rate: 32000,
+            ..Self::default_24khz_speech()
+        }
+    }
+
+    pub fn default_tts() -> Self {
+        Self::default_24khz_speech()
+    }
+
+    pub fn high_quality_tts() -> Self {
+        Self {
+            sampling_rate: 24000,
+            encoder_dim: 128,
+            encoder_rates: vec![2, 4, 8, 8],
+            decoder_dim: 2048,
+            decoder_rates: vec![8, 8, 4, 2],
+            attn_window_size: Some(64),
+            codebook_size: 8192,
+            codebook_dim: 16,
+            vq_strides: vec![8, 4, 2, 1],
+            noise: true,
+            depthwise: true,
+        }
+    }
+
+    pub fn fast_tts() -> Self {
+        Self {
+            sampling_rate: 16000,
+            encoder_dim: 32,
+            encoder_rates: vec![4, 8, 8],
+            decoder_dim: 768,
+            decoder_rates: vec![8, 8, 4],
+            attn_window_size: None,
+            codebook_size: 2048,
+            codebook_dim: 8,
+            vq_strides: vec![4, 2, 1],
+            noise: false,
+            depthwise: false,
+        }
+    }
+
+    /// Frames per second produced by the encoder (sampling rate / hop length).
+    pub fn get_frame_rate(&self) -> f64 {
+        let hop_length: usize = self.encoder_rates.iter().product();
+        self.sampling_rate as f64 / hop_length as f64
+    }
+
+    /// Number of audio samples represented by one frame (the hop length).
+    pub fn get_compression_ratio(&self) -> usize {
+        self.encoder_rates.iter().product()
+    }
+}
+
 // Equivalent to torch.repeat_interleave
 pub fn repeat_interleave(
     img: &Tensor,
@@ -811,4 +890,90 @@ impl Model {
     pub fn num_codebooks(&self) -> usize {
         self.quantizer.quantizers.len()
     }
+
+    /// Encode audio into the stacked [batch, codebooks, frames] token layout
+    /// used by TTS models.
+    pub fn encode_for_tts(&self, audio_data: &Tensor) -> Result<Tensor> {
+        let codes = self.encode(audio_data)?;
+        self.flatten_codes(&codes)
+    }
+
+    /// Decode stacked TTS tokens back into a waveform.
+    pub fn decode_from_tts_tokens(&self, tokens: &Tensor) -> Result<Tensor> {
+        let codes = self.unflatten_codes(tokens)?;
+        let code_refs: Vec<&Tensor> = codes.iter().collect();
+        self.decode(&code_refs)
+    }
+
+    /// Stack per-codebook codes of shape [batch, frames] into a single
+    /// [batch, codebooks, frames] tensor.
+    pub fn flatten_codes(&self, codes: &[Tensor]) -> Result<Tensor> {
+        if codes.is_empty() {
+            candle::bail!("Cannot flatten empty codes");
+        }
+
+        let (batch_size, seq_len) = codes[0].dims2()?;
+        for code in codes {
+            let (b, s) = code.dims2()?;
+            if b != batch_size || s != seq_len {
+                candle::bail!("All codes must have the same batch size and sequence length");
+            }
+        }
+
+        Tensor::stack(codes, 1)
+    }
+
+    /// Split a [batch, codebooks, frames] tensor back into per-codebook codes.
+    pub fn unflatten_codes(&self, flattened: &Tensor) -> Result<Vec<Tensor>> {
+        let (_batch_size, num_codebooks, _seq_len) = flattened.dims3()?;
+        let expected_codebooks = self.num_codebooks();
+
+        if num_codebooks != expected_codebooks {
+            candle::bail!(
+                "Expected {} codebooks, got {}",
+                expected_codebooks,
+                num_codebooks
+            );
+        }
+
+        let mut codes = Vec::with_capacity(num_codebooks);
+        for i in 0..num_codebooks {
+            // narrow + squeeze avoids needing the IndexOp trait in scope.
+            let code = flattened.narrow(1, i, 1)?.squeeze(1)?;
+            codes.push(code);
+        }
+
+        Ok(codes)
+    }
+
+    pub fn encode_batch(&self, audios: &[Tensor]) -> Result<Vec<Vec<Tensor>>> {
+        let mut batch_codes = Vec::with_capacity(audios.len());
+        for audio in audios {
+            let codes = self.encode(audio)?;
+            batch_codes.push(codes);
+        }
+        Ok(batch_codes)
+    }
+
+    pub fn decode_batch(&self, codes_batch: &[Vec<&Tensor>]) -> Result<Vec<Tensor>> {
+        let mut batch_audio = Vec::with_capacity(codes_batch.len());
+        for codes in codes_batch {
+            let audio = self.decode(codes)?;
+            batch_audio.push(audio);
+        }
+        Ok(batch_audio)
+    }
+
+    pub fn get_sample_rate(&self) -> usize {
+        self.config.sampling_rate
+    }
+
+    pub fn get_hop_length(&self) -> usize {
+        self.hop_length
+    }
+
+    pub fn frames_to_samples(&self, frames: usize) -> usize {
+        frames * self.hop_length
+    }
+
+    pub fn samples_to_frames(&self, samples: usize) -> usize {
+        samples.div_ceil(self.hop_length)
+    }
 }
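
As a quick sanity check on the new `get_frame_rate` and `get_compression_ratio` helpers: the hop length is the product of `encoder_rates`, so the 24 kHz preset compresses 2·4·8·8 = 512 samples into one frame, i.e. 24000 / 512 = 46.875 frames per second (which is why the 100-frame placeholder in the example above is roughly two seconds of audio). A small illustrative snippet, assuming the presets land as written:

```rust
use candle_transformers::models::snac;

fn main() {
    let cfg = snac::Config::default_24khz_speech();
    assert_eq!(cfg.get_compression_ratio(), 2 * 4 * 8 * 8); // hop length 512
    assert!((cfg.get_frame_rate() - 46.875).abs() < 1e-9); // 24_000 / 512

    // fast_tts trades quality for speed: hop 4 * 8 * 8 = 256 at 16 kHz,
    // i.e. 62.5 frames per second.
    let fast = snac::Config::fast_tts();
    assert_eq!(fast.get_compression_ratio(), 256);
    assert!((fast.get_frame_rate() - 62.5).abs() < 1e-9);
}
```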
diff --git a/candle-transformers/src/models/snac_tts_integration.rs b/candle-transformers/src/models/snac_tts_integration.rs
new file mode 100644
index 0000000000..cf69afef46
--- /dev/null
+++ b/candle-transformers/src/models/snac_tts_integration.rs
@@ -0,0 +1,332 @@
+//! SNAC integration utilities for text-to-speech models
+//!
+//! This module provides convenient abstractions and utilities for integrating
+//! SNAC (Multi-Scale Neural Audio Codec) with text-to-speech systems.
+//!
+//! ## Usage Examples
+//!
+//! ### Basic TTS Integration
+//! ```rust,ignore
+//! use candle_transformers::models::snac;
+//! use candle_transformers::models::snac_tts_integration::*;
+//!
+//! // Create a SNAC codec for TTS
+//! let config = snac::Config::default_tts();
+//! let codec = SnacTtsCodec::new(&config, vb)?;
+//!
+//! // Use it in a TTS pipeline
+//! let audio_tokens = your_tts_model.generate_tokens(&text)?;
+//! let audio_waveform = codec.tokens_to_audio(&audio_tokens)?;
+//! ```
+//!
+//! ### Qwen-based TTS with SNAC
+//! ```rust,ignore
+//! let mut tts_pipeline = SnacTtsPipeline::new(qwen_model, snac_codec, None);
+//! let audio = tts_pipeline.synthesize("Hello, world!", voice_prompt)?;
+//! ```
+
+use crate::models::snac::{self, Model as SnacModel};
+use candle::{Device, Result, Tensor};
+use candle_nn::VarBuilder;
+
+/// A convenient wrapper around SNAC specifically geared towards TTS use cases
+#[derive(Debug, Clone)]
+pub struct SnacTtsCodec {
+    model: SnacModel,
+    device: Device,
+}
+
+impl SnacTtsCodec {
+    /// Create a new SNAC TTS codec with the given configuration
+    pub fn new(config: &snac::Config, vb: VarBuilder) -> Result<Self> {
+        let device = vb.device().clone();
+        let model = SnacModel::new(config, vb)?;
+        Ok(Self { model, device })
+    }
+
+    /// Create a SNAC TTS codec with default TTS settings (24 kHz speech)
+    pub fn new_default_tts(vb: VarBuilder) -> Result<Self> {
+        let config = snac::Config::default_tts();
+        Self::new(&config, vb)
+    }
+
+    /// Create a high-quality SNAC TTS codec for production use
+    pub fn new_high_quality(vb: VarBuilder) -> Result<Self> {
+        let config = snac::Config::high_quality_tts();
+        Self::new(&config, vb)
+    }
+
+    /// Create a fast SNAC TTS codec for real-time applications
+    pub fn new_fast(vb: VarBuilder) -> Result<Self> {
+        let config = snac::Config::fast_tts();
+        Self::new(&config, vb)
+    }
+
+    /// Convert audio tokens from a TTS model to a waveform.
+    ///
+    /// Expected input shape: [batch_size, num_codebooks, sequence_length]
+    /// Output shape: [batch_size, 1, audio_samples]
+    pub fn tokens_to_audio(&self, tokens: &Tensor) -> Result<Tensor> {
+        self.model.decode_from_tts_tokens(tokens)
+    }
+
+    /// Convert an audio waveform to tokens for training TTS models.
+    ///
+    /// Input shape: [batch_size, 1, audio_samples]
+    /// Output shape: [batch_size, num_codebooks, sequence_length]
+    pub fn audio_to_tokens(&self, audio: &Tensor) -> Result<Tensor> {
+        self.model.encode_for_tts(audio)
+    }
+
+    /// Decode a batch of audio-token tensors to waveforms
+    pub fn batch_tokens_to_audio(&self, token_batches: &[Tensor]) -> Result<Vec<Tensor>> {
+        let mut results = Vec::with_capacity(token_batches.len());
+        for tokens in token_batches {
+            let audio = self.tokens_to_audio(tokens)?;
+            results.push(audio);
+        }
+        Ok(results)
+    }
+
+    /// Number of codebooks (token streams) this codec uses
+    pub fn num_codebooks(&self) -> usize {
+        self.model.num_codebooks()
+    }
+
+    /// Sample rate of the codec
+    pub fn sample_rate(&self) -> usize {
+        self.model.get_sample_rate()
+    }
+
+    /// Convert a duration in seconds to the expected number of tokens
+    pub fn duration_to_tokens(&self, duration_seconds: f64) -> usize {
+        let samples = (duration_seconds * self.sample_rate() as f64) as usize;
+        self.model.samples_to_frames(samples)
+    }
+
+    /// Convert a number of tokens to a duration in seconds
+    pub fn tokens_to_duration(&self, num_tokens: usize) -> f64 {
+        let samples = self.model.frames_to_samples(num_tokens);
+        samples as f64 / self.sample_rate() as f64
+    }
+
+    /// Pad token sequences to a target length (useful for batching)
+    pub fn pad_tokens(&self, tokens: &Tensor, target_length: usize, pad_value: u32) -> Result<Tensor> {
+        let (batch_size, num_codebooks, seq_len) = tokens.dims3()?;
+
+        if seq_len >= target_length {
+            return Ok(tokens.clone());
+        }
+
+        let pad_length = target_length - seq_len;
+        let pad_shape = (batch_size, num_codebooks, pad_length);
+        let pad_tensor = Tensor::full(pad_value, pad_shape, &self.device)?;
+
+        Tensor::cat(&[tokens, &pad_tensor], 2)
+    }
+
+    /// Truncate token sequences to a maximum length
+    pub fn truncate_tokens(&self, tokens: &Tensor, max_length: usize) -> Result<Tensor> {
+        let (_, _, seq_len) = tokens.dims3()?;
+
+        if seq_len <= max_length {
+            return Ok(tokens.clone());
+        }
+
+        tokens.narrow(2, 0, max_length)
+    }
+}
+
+/// Configuration for TTS models using SNAC
+#[derive(Debug, Clone)]
+pub struct TtsConfig {
+    pub max_audio_length_seconds: f64,
+    pub sample_rate: usize,
+    pub pad_token_id: u32,
+    pub bos_token_id: u32,
+    pub eos_token_id: u32,
+}
+
+impl Default for TtsConfig {
+    fn default() -> Self {
+        Self {
+            max_audio_length_seconds: 30.0,
+            sample_rate: 24000,
+            pad_token_id: 0,
+            bos_token_id: 1,
+            eos_token_id: 2,
+        }
+    }
+}
+
+/// Trait for TTS models that can drive the SNAC codec
+pub trait SnacTtsModel {
+    /// Generate audio tokens from text input
+    fn generate_tokens(&mut self, text: &str, voice_prompt: Option<&Tensor>) -> Result<Tensor>;
+
+    /// The model's vocabulary size for each codebook
+    fn vocab_size(&self) -> usize;
+
+    /// Clear any internal caches or state
+    fn clear_cache(&mut self);
+}
+
+/// A complete TTS pipeline combining a language model with the SNAC codec
+#[derive(Debug)]
+pub struct SnacTtsPipeline<T: SnacTtsModel> {
+    tts_model: T,
+    codec: SnacTtsCodec,
+    _config: TtsConfig,
+}
+
+impl<T: SnacTtsModel> SnacTtsPipeline<T> {
+    /// Create a new TTS pipeline
+    pub fn new(tts_model: T, codec: SnacTtsCodec, config: Option<TtsConfig>) -> Self {
+        let config = config.unwrap_or_default();
+        Self {
+            tts_model,
+            codec,
+            _config: config,
+        }
+    }
+
+    /// Synthesize speech from text input
+    pub fn synthesize(&mut self, text: &str, voice_prompt: Option<&Tensor>) -> Result<Tensor> {
+        // Clear any previous state
+        self.tts_model.clear_cache();
+
+        // Generate audio tokens using the TTS model
+        let tokens = self.tts_model.generate_tokens(text, voice_prompt)?;
+
+        // Convert tokens to an audio waveform using SNAC
+        let audio = self.codec.tokens_to_audio(&tokens)?;
+
+        Ok(audio)
+    }
+
+    /// Synthesize multiple texts, one after the other
+    pub fn synthesize_batch(
+        &mut self,
+        texts: &[&str],
+        voice_prompts: Option<&[Tensor]>,
+    ) -> Result<Vec<Tensor>> {
+        let mut results = Vec::with_capacity(texts.len());
+
+        for (i, text) in texts.iter().enumerate() {
+            let voice_prompt = voice_prompts.and_then(|prompts| prompts.get(i));
+            let audio = self.synthesize(text, voice_prompt)?;
+            results.push(audio);
+        }
+
+        Ok(results)
+    }
+
+    /// Get codec information
+    pub fn codec_info(&self) -> CodecInfo {
+        self.codec.codec_info()
+    }
+}
+
+impl SnacTtsCodec {
+    /// Get codec information
+    pub fn codec_info(&self) -> CodecInfo {
+        CodecInfo {
+            sample_rate: self.sample_rate(),
+            num_codebooks: self.num_codebooks(),
+            // The encoder hop length equals the samples-per-frame compression ratio.
+            compression_ratio: self.model.get_hop_length(),
+        }
+    }
+}
+
+/// Information about the audio codec
+#[derive(Debug, Clone)]
+pub struct CodecInfo {
+    pub sample_rate: usize,
+    pub num_codebooks: usize,
+    pub compression_ratio: usize,
+}
+
+/// Utility functions for SNAC TTS integration
+pub mod utils {
+    use super::*;
+
+    /// Create a voice-prompt token tensor from a reference audio sample
+    pub fn create_voice_embedding(codec: &SnacTtsCodec, reference_audio: &Tensor) -> Result<Tensor> {
+        // Extract voice characteristics as tokens
+        let tokens = codec.audio_to_tokens(reference_audio)?;
+
+        // For voice cloning, we typically use the first few frames as the prompt.
+        let voice_frames = 50; // ~1 second at 24 kHz with a hop length of 512
+        let (_batch_size, _num_codebooks, seq_len) = tokens.dims3()?;
+        let prompt_len = voice_frames.min(seq_len);
+
+        tokens.narrow(2, 0, prompt_len)
+    }
+
+    /// Validate that audio tokens have the expected format for SNAC
+    pub fn validate_tokens(tokens: &Tensor, expected_codebooks: usize) -> Result<()> {
+        let shape = tokens.shape();
+        if shape.rank() != 3 {
+            candle::bail!(
+                "Expected 3D tensor [batch, codebooks, sequence], got {:?}",
+                shape
+            );
+        }
+
+        let (_, codebooks, _) = tokens.dims3()?;
+        if codebooks != expected_codebooks {
+            candle::bail!(
+                "Expected {} codebooks, got {}",
+                expected_codebooks,
+                codebooks
+            );
+        }
+
+        Ok(())
+    }
+
+    /// Estimate the memory requirements for processing audio of a given duration
+    pub fn estimate_memory_usage(
+        duration_seconds: f64,
+        sample_rate: usize,
+        num_codebooks: usize,
+        batch_size: usize,
+    ) -> MemoryEstimate {
+        let samples = (duration_seconds * sample_rate as f64) as usize;
+        // Rough default; the actual ratio is the product of encoder_rates
+        // (512 for the 24 kHz config, 256 for fast_tts).
+        let compression_ratio = 256;
+        let tokens = samples / compression_ratio;
+
+        MemoryEstimate {
+            audio_samples: batch_size * samples,
+            token_count: batch_size * num_codebooks * tokens,
+            // 4 bytes per f32 sample / u32 token
+            estimated_bytes: batch_size * (samples * 4 + num_codebooks * tokens * 4),
+        }
+    }
+
+    #[derive(Debug, Clone)]
+    pub struct MemoryEstimate {
+        pub audio_samples: usize,
+        pub token_count: usize,
+        pub estimated_bytes: usize,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config_creation() {
+        let config = snac::Config::default_tts();
+        assert_eq!(config.sampling_rate, 24000);
+        assert!(!config.encoder_rates.is_empty());
+
+        let hq_config = snac::Config::high_quality_tts();
+        assert!(hq_config.encoder_dim >= config.encoder_dim);
+    }
+
+    #[test]
+    fn test_duration_calculations() {
+        // This would require an actual SNAC model to test properly.
+        // In a real implementation, you'd load a model and check e.g.:
+        // let codec = SnacTtsCodec::new_default_tts(vb)?;
+        // assert_eq!(codec.duration_to_tokens(1.0), expected_tokens);
+    }
+}
\ No newline at end of file
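
To show how the `SnacTtsModel` trait and `SnacTtsPipeline` fit together, here is a minimal sketch with a hypothetical `DummyTtsModel` that emits all-zero tokens. It only exercises the plumbing (a real checkpoint is needed for the decoded audio to be meaningful) and assumes the module lands as written above.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_transformers::models::snac_tts_integration::{
    SnacTtsCodec, SnacTtsModel, SnacTtsPipeline,
};

/// Hypothetical stand-in model: one all-zero token per codebook per input char.
struct DummyTtsModel {
    num_codebooks: usize,
    device: Device,
}

impl SnacTtsModel for DummyTtsModel {
    fn generate_tokens(&mut self, text: &str, _voice_prompt: Option<&Tensor>) -> Result<Tensor> {
        let seq_len = text.len().max(1); // arbitrary: one frame per character
        Tensor::zeros((1, self.num_codebooks, seq_len), DType::U32, &self.device)
    }

    fn vocab_size(&self) -> usize {
        4096 // matches codebook_size in the default config
    }

    fn clear_cache(&mut self) {}
}

fn run(codec: SnacTtsCodec) -> Result<Tensor> {
    let model = DummyTtsModel {
        num_codebooks: codec.num_codebooks(),
        device: Device::Cpu,
    };
    let mut pipeline = SnacTtsPipeline::new(model, codec, None);
    // Returns a [1, 1, samples] waveform tensor.
    pipeline.synthesize("Hello, world!", None)
}
```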