|
| 1 | +use std::error::Error; |
| 2 | +use std::fs; |
| 3 | +use std::path::{Path, PathBuf}; |
| 4 | + |
| 5 | +use clap::{Parser, ValueEnum}; |
| 6 | +use hound::{SampleFormat, WavSpec, WavWriter}; |
| 7 | +use hypr_audacity::{Project, Track}; |
| 8 | +use hypr_audio_utils::{audio_file_metadata, resample_audio, source_from_path}; |
| 9 | + |
| 10 | +const TARGET_SAMPLE_RATE: u32 = 16_000; |
| 11 | + |
| 12 | +#[cfg(not(feature = "onnx"))] |
/// Fallback entry point used when the crate is built without the `onnx` feature.
///
/// # Errors
///
/// Always returns an error stating that the example needs the `onnx` feature.
fn main() -> Result<(), Box<dyn Error>> {
    let message = "the audacity example requires the `onnx` feature";
    Err(Box::<dyn Error>::from(message))
}
| 16 | + |
| 17 | +#[cfg(feature = "onnx")] |
| 18 | +use aec::{AEC, BLOCK_SIZE}; |
| 19 | + |
| 20 | +#[cfg(feature = "onnx")] |
| 21 | +#[derive(Clone, ValueEnum)] |
| 22 | +enum Mode { |
| 23 | + Batch, |
| 24 | + Streaming, |
| 25 | +} |
| 26 | + |
| 27 | +#[cfg(feature = "onnx")] |
| 28 | +#[derive(Parser)] |
| 29 | +struct Args { |
| 30 | + mic: PathBuf, |
| 31 | + lpb: PathBuf, |
| 32 | + |
| 33 | + #[arg(long)] |
| 34 | + out_dir: Option<PathBuf>, |
| 35 | + |
| 36 | + #[arg(long, value_enum, default_value = "streaming")] |
| 37 | + mode: Mode, |
| 38 | + |
| 39 | + #[arg(long, default_value_t = BLOCK_SIZE * 2)] |
| 40 | + chunk_size: usize, |
| 41 | +} |
| 42 | + |
| 43 | +#[cfg(feature = "onnx")] |
| 44 | +fn main() -> Result<(), Box<dyn Error>> { |
| 45 | + let args = Args::parse(); |
| 46 | + let out_dir = args |
| 47 | + .out_dir |
| 48 | + .unwrap_or_else(|| default_out_dir(&args.mic, &args.mode)); |
| 49 | + fs::create_dir_all(&out_dir)?; |
| 50 | + |
| 51 | + let mic = load_mono_16khz(&args.mic)?; |
| 52 | + let lpb = load_mono_16khz(&args.lpb)?; |
| 53 | + let len_audio = mic.len().min(lpb.len()); |
| 54 | + let mic = mic[..len_audio].to_vec(); |
| 55 | + let lpb = lpb[..len_audio].to_vec(); |
| 56 | + |
| 57 | + let processed = run_aec(&mic, &lpb, &args.mode, args.chunk_size)?; |
| 58 | + let removed = subtract(&mic, &processed); |
| 59 | + |
| 60 | + let mic_path = out_dir.join("mic_input.wav"); |
| 61 | + let lpb_path = out_dir.join("speaker_reference.wav"); |
| 62 | + let aec_path = out_dir.join("aec_output.wav"); |
| 63 | + let removed_path = out_dir.join("cancelled_from_mic.wav"); |
| 64 | + let summary_path = out_dir.join("summary.txt"); |
| 65 | + |
| 66 | + write_wav(&mic_path, &mic)?; |
| 67 | + write_wav(&lpb_path, &lpb)?; |
| 68 | + write_wav(&aec_path, &processed)?; |
| 69 | + write_wav(&removed_path, &removed)?; |
| 70 | + write_summary( |
| 71 | + &summary_path, |
| 72 | + &args.mic, |
| 73 | + &args.lpb, |
| 74 | + &args.mode, |
| 75 | + args.chunk_size, |
| 76 | + len_audio, |
| 77 | + &mic, |
| 78 | + &processed, |
| 79 | + &removed, |
| 80 | + )?; |
| 81 | + |
| 82 | + let bundle = Project::new() |
| 83 | + .with_track(Track::new(&mic_path).with_name("mic_input")) |
| 84 | + .with_track(Track::new(&aec_path).with_name("aec_output")) |
| 85 | + .with_track( |
| 86 | + Track::new(&removed_path) |
| 87 | + .with_name("cancelled_from_mic") |
| 88 | + .muted(true), |
| 89 | + ) |
| 90 | + .with_track( |
| 91 | + Track::new(&lpb_path) |
| 92 | + .with_name("speaker_reference") |
| 93 | + .muted(true), |
| 94 | + ) |
| 95 | + .write_bundle(&out_dir)?; |
| 96 | + |
| 97 | + println!("exported {}", out_dir.display()); |
| 98 | + println!(" {}", mic_path.display()); |
| 99 | + println!(" {}", lpb_path.display()); |
| 100 | + println!(" {}", aec_path.display()); |
| 101 | + println!(" {}", removed_path.display()); |
| 102 | + println!(" {}", summary_path.display()); |
| 103 | + println!(" {}", bundle.commands_path.display()); |
| 104 | + println!(" {}", bundle.script_path.display()); |
| 105 | + println!(); |
| 106 | + println!("with Audacity pipe scripting enabled:"); |
| 107 | + println!(" python3 {}", bundle.script_path.display()); |
| 108 | + |
| 109 | + Ok(()) |
| 110 | +} |
| 111 | + |
| 112 | +#[cfg(feature = "onnx")] |
| 113 | +fn default_out_dir(mic: &Path, mode: &Mode) -> PathBuf { |
| 114 | + let stem = mic |
| 115 | + .file_stem() |
| 116 | + .and_then(|stem| stem.to_str()) |
| 117 | + .filter(|stem| !stem.is_empty()) |
| 118 | + .unwrap_or("mic"); |
| 119 | + let mode = match mode { |
| 120 | + Mode::Batch => "batch", |
| 121 | + Mode::Streaming => "streaming", |
| 122 | + }; |
| 123 | + mic.parent() |
| 124 | + .unwrap_or_else(|| Path::new(".")) |
| 125 | + .join(format!("{stem}-aec-{mode}")) |
| 126 | +} |
| 127 | + |
| 128 | +#[cfg(feature = "onnx")] |
| 129 | +fn load_mono_16khz(path: &Path) -> Result<Vec<f32>, Box<dyn Error>> { |
| 130 | + let metadata = audio_file_metadata(path)?; |
| 131 | + let channels = metadata.channels as usize; |
| 132 | + let source = source_from_path(path)?; |
| 133 | + let samples = if metadata.sample_rate == TARGET_SAMPLE_RATE { |
| 134 | + source.collect::<Vec<_>>() |
| 135 | + } else { |
| 136 | + resample_audio(source, TARGET_SAMPLE_RATE)? |
| 137 | + }; |
| 138 | + |
| 139 | + Ok(downmix_to_mono(&samples, channels)) |
| 140 | +} |
| 141 | + |
| 142 | +#[cfg(feature = "onnx")] |
/// Averages interleaved multi-channel samples into a single channel.
///
/// With zero or one channel the input is returned as an unchanged copy.
/// Otherwise each complete frame of `channels` samples is replaced by its
/// mean; a trailing incomplete frame is dropped.
fn downmix_to_mono(samples: &[f32], channels: usize) -> Vec<f32> {
    if channels < 2 {
        return samples.to_vec();
    }

    let divisor = channels as f32;
    let mut mono = Vec::with_capacity(samples.len() / channels);
    for frame in samples.chunks_exact(channels) {
        let total: f32 = frame.iter().sum();
        mono.push(total / divisor);
    }
    mono
}
| 153 | + |
| 154 | +#[cfg(feature = "onnx")] |
| 155 | +fn run_aec( |
| 156 | + mic: &[f32], |
| 157 | + lpb: &[f32], |
| 158 | + mode: &Mode, |
| 159 | + chunk_size: usize, |
| 160 | +) -> Result<Vec<f32>, Box<dyn Error>> { |
| 161 | + let mut aec = AEC::new()?; |
| 162 | + |
| 163 | + match mode { |
| 164 | + Mode::Batch => Ok(aec.process(mic, lpb)?), |
| 165 | + Mode::Streaming => { |
| 166 | + let mut output = Vec::with_capacity(mic.len()); |
| 167 | + let chunk_size = chunk_size.max(1); |
| 168 | + let mut processed = 0; |
| 169 | + |
| 170 | + while processed < mic.len() { |
| 171 | + let end = (processed + chunk_size).min(mic.len()); |
| 172 | + output.extend(aec.process_streaming(&mic[processed..end], &lpb[processed..end])?); |
| 173 | + processed = end; |
| 174 | + } |
| 175 | + |
| 176 | + Ok(output) |
| 177 | + } |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +#[cfg(feature = "onnx")] |
/// Computes the per-sample difference `input - output`, clamped to `[-1.0, 1.0]`.
///
/// The result is as long as the shorter of the two slices; any extra samples
/// in the longer slice are ignored.
fn subtract(input: &[f32], output: &[f32]) -> Vec<f32> {
    let len = input.len().min(output.len());
    let mut residual = Vec::with_capacity(len);
    for idx in 0..len {
        let diff = input[idx] - output[idx];
        residual.push(diff.clamp(-1.0, 1.0));
    }
    residual
}
| 189 | + |
| 190 | +#[cfg(feature = "onnx")] |
| 191 | +fn write_wav(path: &Path, samples: &[f32]) -> Result<(), Box<dyn Error>> { |
| 192 | + let spec = WavSpec { |
| 193 | + channels: 1, |
| 194 | + sample_rate: TARGET_SAMPLE_RATE, |
| 195 | + bits_per_sample: 32, |
| 196 | + sample_format: SampleFormat::Float, |
| 197 | + }; |
| 198 | + let mut writer = WavWriter::create(path, spec)?; |
| 199 | + for sample in samples { |
| 200 | + writer.write_sample(*sample)?; |
| 201 | + } |
| 202 | + writer.finalize()?; |
| 203 | + Ok(()) |
| 204 | +} |
| 205 | + |
| 206 | +#[cfg(feature = "onnx")] |
| 207 | +fn write_summary( |
| 208 | + path: &Path, |
| 209 | + mic_path: &Path, |
| 210 | + lpb_path: &Path, |
| 211 | + mode: &Mode, |
| 212 | + chunk_size: usize, |
| 213 | + total_samples: usize, |
| 214 | + mic: &[f32], |
| 215 | + processed: &[f32], |
| 216 | + removed: &[f32], |
| 217 | +) -> Result<(), Box<dyn Error>> { |
| 218 | + let body = format!( |
| 219 | + "mic={}\nlpb={}\nmode={}\nchunk_size={}\nsample_rate={}\nduration_sec={:.3}\nmic_rms={:.6}\naec_rms={:.6}\nremoved_rms={:.6}\n", |
| 220 | + mic_path.display(), |
| 221 | + lpb_path.display(), |
| 222 | + match mode { |
| 223 | + Mode::Batch => "batch", |
| 224 | + Mode::Streaming => "streaming", |
| 225 | + }, |
| 226 | + chunk_size, |
| 227 | + TARGET_SAMPLE_RATE, |
| 228 | + total_samples as f64 / TARGET_SAMPLE_RATE as f64, |
| 229 | + rms(mic), |
| 230 | + rms(processed), |
| 231 | + rms(removed), |
| 232 | + ); |
| 233 | + |
| 234 | + fs::write(path, body)?; |
| 235 | + Ok(()) |
| 236 | +} |
| 237 | + |
| 238 | +#[cfg(feature = "onnx")] |
/// Returns the root-mean-square level of `samples`, or `0.0` for an empty slice.
///
/// The sum of squares is accumulated in `f64`. An `f32` accumulator loses
/// precision once the running sum is much larger than each added term, which
/// happens for recordings with many samples and skews the reported level.
fn rms(samples: &[f32]) -> f32 {
    if samples.is_empty() {
        return 0.0;
    }

    let sum_sq: f64 = samples
        .iter()
        .map(|&sample| {
            let value = f64::from(sample);
            value * value
        })
        .sum();

    // Narrowing back to f32 is intentional: callers only need sample precision.
    (sum_sq / samples.len() as f64).sqrt() as f32
}
0 commit comments