diff --git a/batbot/spectrogram/__init__.py b/batbot/spectrogram/__init__.py index 074030a..4bc2b2e 100644 --- a/batbot/spectrogram/__init__.py +++ b/batbot/spectrogram/__init__.py @@ -236,8 +236,36 @@ def generate_waveplot( def load_stft( - wav_filepath, sr=250e3, n_fft=512, window='blackmanharris', win_length=256, hop_length=16 + wav_filepath, + sr=250e3, + n_fft=512, + window='blackmanharris', + win_length=256, + hop_length=16, + noise_reduction=None, + noise_reduction_params=None, ): + """ + Load a WAV file and compute the Short-Time Fourier Transform (STFT). + + Args: + wav_filepath: Path to the WAV file + sr: Target sample rate (default: 250kHz for ultrasonic bat calls) + n_fft: FFT window size (default: 512) + window: Window function (default: 'blackmanharris') + win_length: Window length in samples (default: 256) + hop_length: Number of samples between frames (default: 16) + noise_reduction: Noise reduction method to apply. Options: + - None: No noise reduction (default, original behavior) + - 'spectral': Spectral subtraction + - 'wiener': Wiener filtering + - 'adaptive': Adaptive noise floor estimation + noise_reduction_params: Dict of parameters for the chosen noise reduction method. + See spectral_subtraction(), wiener_filter(), or adaptive_noise_floor() for options. 
def spectral_subtraction(stft_db, noise_frames=10, oversubtraction=1.0, floor_db=-80.0):
    """
    Reduce stationary noise by power spectral subtraction.

    The noise profile is the mean *linear power* of the first ``noise_frames``
    frames of every frequency bin (these frames are assumed to hold no call).
    The scaled profile is subtracted from every frame in the linear power
    domain and the result is converted back to dB.

    The subtraction is deliberately not done on the dB values: they are
    typically negative (dB relative to the maximum), so subtracting a negative
    profile would raise every bin, and a larger ``oversubtraction`` would raise
    it even more.

    Args:
        stft_db: Power spectrogram in dB, shape (freq_bins, time_frames).
        noise_frames: Number of leading frames used for the noise estimate
            (default: 10). Clamped to the number of available frames.
        oversubtraction: Multiplier applied to the noise power before
            subtraction; larger values remove more energy (default: 1.0).
        floor_db: Lowest dB value present in the output. Bins whose power
            falls below the noise estimate are set to this value (default: -80.0).

    Returns:
        Noise-reduced spectrogram in dB with the same shape as the input.
        The input array is never modified. If no frame is available for the
        noise estimate, an unchanged copy is returned.
    """
    stft_db = stft_db.copy()

    # Nothing to estimate the noise from: hand back the untouched copy
    noise_frames = min(noise_frames, stft_db.shape[1])
    if noise_frames < 1:
        return stft_db

    # dB -> linear power, where noise and signal are (approximately) additive
    stft_power = np.power(10.0, stft_db / 10.0)
    noise_power = np.mean(stft_power[:, :noise_frames], axis=1, keepdims=True)

    # Subtract, then floor in the linear domain so the logarithm stays finite
    floor_power = 10.0 ** (floor_db / 10.0)
    cleaned_power = np.maximum(stft_power - oversubtraction * noise_power, floor_power)

    # The final maximum guards the documented floor against rounding error
    return np.maximum(10.0 * np.log10(cleaned_power), floor_db)


def wiener_filter(stft_db, noise_frames=10, smoothing_freq=3, smoothing_time=3, gain_floor=0.1):
    """
    Attenuate noise with a Wiener-style gain computed per time-frequency bin.

    The noise power of each frequency bin is the mean linear power of the
    first ``noise_frames`` frames. Every bin is then scaled by
    ``G = P / (P + N)``, where ``P`` is the observed power and ``N`` the noise
    estimate. Because the observed power is used as the signal estimate, a bin
    at the noise level receives a gain of about 0.5 rather than 0.

    Args:
        stft_db: Power spectrogram in dB, shape (freq_bins, time_frames).
        noise_frames: Number of leading frames used for the noise estimate
            (default: 10). Clamped to the number of available frames.
        smoothing_freq: Gaussian sigma (in bins) used to smooth the gain along
            the frequency axis; values <= 0 disable it (default: 3).
        smoothing_time: Gaussian sigma (in frames) used to smooth the gain
            along the time axis; values <= 0 disable it (default: 3).
        gain_floor: Lower bound of the gain, which prevents complete
            suppression of a bin (default: 0.1).

    Returns:
        Filtered spectrogram in dB with the same shape as the input. The input
        array is never modified. If no frame is available for the noise
        estimate, an unchanged copy is returned.
    """
    stft_db = stft_db.copy()

    noise_frames = min(noise_frames, stft_db.shape[1])
    if noise_frames < 1:
        return stft_db

    # dB -> linear power: power = 10^(dB / 10)
    stft_power = np.power(10.0, stft_db / 10.0)

    # Noise power is the MEAN power of the noise-only frames. The variance of
    # the power has units of power^2 and is orders of magnitude too small for
    # a max-normalised spectrogram, which would push every gain to ~1.
    noise_power = np.mean(stft_power[:, :noise_frames], axis=1, keepdims=True)
    # Keep the denominator strictly positive
    noise_power = np.maximum(noise_power, 1e-10)

    # G = SNR / (1 + SNR) with SNR = P / N, written without the intermediate
    # ratio so that very large SNR values cannot overflow
    wiener_gain = stft_power / (stft_power + noise_power)
    wiener_gain = np.maximum(wiener_gain, gain_floor)

    # Smooth the gain to reduce musical-noise artifacts. Each axis is handled
    # on its own because gaussian_filter1d cannot build a kernel for sigma == 0.
    if smoothing_freq > 0:
        wiener_gain = gaussian_filter1d(wiener_gain, smoothing_freq, axis=0, mode='nearest')
    if smoothing_time > 0:
        wiener_gain = gaussian_filter1d(wiener_gain, smoothing_time, axis=1, mode='nearest')

    # Apply the gain in the linear domain; the epsilon prevents log10(0)
    filtered_power = np.maximum(stft_power * wiener_gain, 1e-10)
    return 10.0 * np.log10(filtered_power)


def adaptive_noise_floor(
    stft_db, window_frames=32, percentile=10, min_db=-60.0, *, max_chunk_elements=1 << 22
):
    """
    Subtract a locally estimated noise floor from a dB spectrogram.

    For every frequency bin and frame, the floor is a percentile of the dB
    values in a window of ``window_frames`` frames that starts
    ``window_frames // 2`` frames before the current one (edges are padded by
    repeating the first/last frame). The floor is bounded below by ``min_db``,
    subtracted from the spectrogram, and negative results are clipped to 0.

    Args:
        stft_db: Power spectrogram in dB, shape (freq_bins, time_frames).
        window_frames: Sliding window length in frames (default: 32). Clamped
            to the number of available frames.
        percentile: Percentile (0-100) of the window used as the floor
            estimate, 5-20 typical (default: 10).
        min_db: Lower bound applied to the estimated floor (default: -60.0).
        max_chunk_elements: Upper bound on the number of window elements that
            are materialised at once while computing percentiles. It only
            affects peak memory use, never the result (default: 2**22).

    Returns:
        Array with the same shape as the input holding the dB excess above the
        local noise floor (always >= 0). The input array is never modified.
        Inputs without frames or frequency bins are returned as an unchanged copy.
    """
    stft_db = stft_db.copy()
    freq_bins, time_frames = stft_db.shape

    window_frames = min(window_frames, time_frames)
    if window_frames < 1 or freq_bins < 1:
        return stft_db

    # Pad along time so every frame owns a full window
    pad_width = window_frames // 2
    stft_padded = np.pad(stft_db, ((0, 0), (pad_width, pad_width)), mode='edge')

    # Bounds-checked, read-only view of shape (freq_bins, >= time_frames, window_frames).
    # Even window sizes yield one surplus window, which is sliced off.
    windows = np.lib.stride_tricks.sliding_window_view(stft_padded, window_frames, axis=1)
    windows = windows[:, :time_frames]

    # np.percentile copies its input, so reducing the whole view at once needs
    # freq_bins * time_frames * window_frames values. Work through the time
    # axis in slabs to keep the temporary copy bounded.
    chunk_frames = max(1, int(max_chunk_elements) // (freq_bins * window_frames))
    noise_floor = np.concatenate(
        [
            np.percentile(windows[:, start : start + chunk_frames], percentile, axis=2)
            for start in range(0, time_frames, chunk_frames)
        ],
        axis=1,
    )

    # Never treat anything below min_db as the floor
    noise_floor = np.maximum(noise_floor, min_db)

    # Keep only the energy above the floor
    return np.maximum(stft_db - noise_floor, 0.0)
wav_filepath, + annotations=None, + output_folder='.', + bitdepth=16, + debug=True, + noise_reduction=None, + noise_reduction_params=None, + **kwargs ): """ Compute the spectrograms for a given input WAV and saves them to disk. @@ -1262,10 +1454,18 @@ def compute_wrapper( Args: wav_filepath (str): WAV filepath (relative or absolute) to compute spectrograms for. - ext (str, optional): The file extension of the resulting spectrogram files. If this value is - not specified, it will use the same extension as `wav_filepath`. Passed as input - to :meth:`batbot.spectrogram.spectrogram_filepath`. Defaults to :obj:`None`. - **kwargs: keyword arguments passed to :meth:`batbot.spectrogram.spectrogram_grid` + annotations: Optional list of (start, stop) time annotations. + output_folder (str): Directory to save output files. Defaults to '.'. + bitdepth (int): Output bit depth, 8 or 16. Defaults to 16. + debug (bool): Enable debug output. Defaults to True. + noise_reduction (str or None): Noise reduction method to apply. Options: + - None: No noise reduction (default, original behavior) + - 'spectral': Spectral subtraction + - 'wiener': Wiener filtering + - 'adaptive': Adaptive noise floor estimation + noise_reduction_params (dict or None): Parameters for the noise reduction method. + See spectral_subtraction(), wiener_filter(), or adaptive_noise_floor() for options. + **kwargs: Additional keyword arguments. 
from os.path import abspath, join

import numpy as np
import pytest


def test_spectrogram_compute():
    """Smoke test: the full pipeline runs on the bundled example recording."""
    from batbot.spectrogram import compute

    wav_filepath = abspath(join('examples', 'example1.wav'))
    output_paths, metadata_path, metadata = compute(wav_filepath)


class TestNoiseReduction:
    """Tests for noise reduction algorithms."""

    @pytest.fixture
    def synthetic_spectrogram(self):
        """Create a synthetic dB spectrogram: Gaussian noise plus a downward chirp."""
        # A local generator keeps the data reproducible without reseeding the
        # process-wide NumPy RNG that other tests may rely on
        rng = np.random.default_rng(42)
        freq_bins, time_frames = 64, 200

        # Noise floor centred on -60 dB (frames 0-49 contain noise only)
        noise_level = -60.0
        noise = rng.standard_normal((freq_bins, time_frames)) * 5 + noise_level

        # Add a synthetic bat chirp signal (frequency sweep) in frames 50-149
        signal = np.zeros((freq_bins, time_frames))
        for t in range(50, 150):
            freq_idx = int(freq_bins * 0.8 - (t - 50) * 0.4)
            if 0 <= freq_idx < freq_bins:
                signal[freq_idx, t] = 20.0  # Signal 20 dB above noise

        return noise + signal

    def test_spectral_subtraction_preserves_shape(self, synthetic_spectrogram):
        """Test that spectral subtraction preserves input shape."""
        from batbot.spectrogram import spectral_subtraction

        result = spectral_subtraction(synthetic_spectrogram)
        assert result.shape == synthetic_spectrogram.shape

    def test_spectral_subtraction_reduces_noise(self, synthetic_spectrogram):
        """Test that spectral subtraction reduces noise in noise-only regions."""
        from batbot.spectrogram import spectral_subtraction

        result = spectral_subtraction(synthetic_spectrogram, noise_frames=20)

        # Noise region (first 20 frames) should be reduced
        noise_region_before = synthetic_spectrogram[:, :20].mean()
        noise_region_after = result[:, :20].mean()
        assert noise_region_after < noise_region_before

    def test_spectral_subtraction_with_oversubtraction(self, synthetic_spectrogram):
        """Test spectral subtraction with different oversubtraction factors."""
        from batbot.spectrogram import spectral_subtraction

        result_1x = spectral_subtraction(synthetic_spectrogram, oversubtraction=1.0)
        result_2x = spectral_subtraction(synthetic_spectrogram, oversubtraction=2.0)

        # Higher oversubtraction should result in lower mean values
        assert result_2x.mean() < result_1x.mean()

    def test_wiener_filter_preserves_shape(self, synthetic_spectrogram):
        """Test that Wiener filter preserves input shape."""
        from batbot.spectrogram import wiener_filter

        result = wiener_filter(synthetic_spectrogram)
        assert result.shape == synthetic_spectrogram.shape

    def test_wiener_filter_reduces_noise(self, synthetic_spectrogram):
        """Test that Wiener filter reduces noise while preserving signal."""
        from batbot.spectrogram import wiener_filter

        result = wiener_filter(synthetic_spectrogram, noise_frames=20)

        # Signal region should still have higher values than noise region
        signal_region = result[:, 50:150].max()
        noise_region = result[:, :20].mean()
        assert signal_region > noise_region

        # The noise-only frames must actually lose energy
        assert noise_region < synthetic_spectrogram[:, :20].mean()

    def test_wiener_filter_gain_floor(self, synthetic_spectrogram):
        """Test that Wiener filter respects gain floor."""
        from batbot.spectrogram import wiener_filter

        result_high_floor = wiener_filter(synthetic_spectrogram, gain_floor=0.5)
        result_low_floor = wiener_filter(synthetic_spectrogram, gain_floor=0.01)

        # Higher gain floor should preserve more energy
        assert result_high_floor.mean() > result_low_floor.mean()

    def test_adaptive_noise_floor_preserves_shape(self, synthetic_spectrogram):
        """Test that adaptive noise floor preserves input shape."""
        from batbot.spectrogram import adaptive_noise_floor

        result = adaptive_noise_floor(synthetic_spectrogram)
        assert result.shape == synthetic_spectrogram.shape

    def test_adaptive_noise_floor_non_negative(self, synthetic_spectrogram):
        """Test that adaptive noise floor output is non-negative."""
        from batbot.spectrogram import adaptive_noise_floor

        result = adaptive_noise_floor(synthetic_spectrogram)
        assert result.min() >= 0.0

    def test_adaptive_noise_floor_percentile_effect(self, synthetic_spectrogram):
        """Test that different percentiles affect the output."""
        from batbot.spectrogram import adaptive_noise_floor

        # The synthetic noise sits around -60 dB, so the default min_db of -60
        # would clamp both percentile floors to the same value and hide the
        # effect under test. Lower the bound far enough that it never applies.
        result_low = adaptive_noise_floor(synthetic_spectrogram, percentile=5, min_db=-120.0)
        result_high = adaptive_noise_floor(synthetic_spectrogram, percentile=25, min_db=-120.0)

        # Lower percentile = lower noise floor = more signal preserved
        assert result_low.mean() > result_high.mean()

    def test_noise_reduction_none_unchanged(self, synthetic_spectrogram):
        """Test that the noise reduction functions never modify their input."""
        from batbot.spectrogram import (
            spectral_subtraction,
            wiener_filter,
            adaptive_noise_floor,
        )

        # Each function should return a copy, not modify in place
        original = synthetic_spectrogram.copy()

        spectral_subtraction(synthetic_spectrogram)
        assert np.allclose(synthetic_spectrogram, original)

        wiener_filter(synthetic_spectrogram)
        assert np.allclose(synthetic_spectrogram, original)

        adaptive_noise_floor(synthetic_spectrogram)
        assert np.allclose(synthetic_spectrogram, original)

    def test_empty_spectrogram_handling(self):
        """Test that noise reduction handles tiny and zero-frame inputs gracefully."""
        from batbot.spectrogram import (
            spectral_subtraction,
            wiener_filter,
            adaptive_noise_floor,
        )

        # Very small spectrogram (seeded so a failure is reproducible)
        rng = np.random.default_rng(0)
        small = rng.standard_normal((4, 5)) - 40.0

        result1 = spectral_subtraction(small, noise_frames=2)
        result2 = wiener_filter(small, noise_frames=2)
        result3 = adaptive_noise_floor(small, window_frames=2)

        assert result1.shape == small.shape
        assert result2.shape == small.shape
        assert result3.shape == small.shape

        # Spectrogram without any frames: returned as an (empty) copy
        empty = np.empty((4, 0))
        for reduce_noise in (spectral_subtraction, wiener_filter, adaptive_noise_floor):
            assert reduce_noise(empty).shape == empty.shape


class TestLoadStftNoiseReduction:
    """Tests for load_stft with noise reduction options."""

    @pytest.fixture
    def wav_filepath(self):
        """Absolute path of the example recording shipped with the repository."""
        return abspath(join('examples', 'example1.wav'))

    def test_load_stft_no_noise_reduction(self, wav_filepath):
        """Test load_stft with no noise reduction (default)."""
        from batbot.spectrogram import load_stft

        result = load_stft(wav_filepath, noise_reduction=None)
        assert len(result) == 6  # stft_db, waveplot, sr, bands, duration, min_index

    def test_load_stft_spectral_subtraction(self, wav_filepath):
        """Test load_stft with spectral subtraction."""
        from batbot.spectrogram import load_stft

        result = load_stft(wav_filepath, noise_reduction='spectral')
        stft_db = result[0]
        assert stft_db is not None
        assert len(stft_db.shape) == 2

    def test_load_stft_wiener_filter(self, wav_filepath):
        """Test load_stft with Wiener filter."""
        from batbot.spectrogram import load_stft

        result = load_stft(wav_filepath, noise_reduction='wiener')
        stft_db = result[0]
        assert stft_db is not None
        assert len(stft_db.shape) == 2

    def test_load_stft_adaptive_noise_floor(self, wav_filepath):
        """Test load_stft with adaptive noise floor."""
        from batbot.spectrogram import load_stft

        result = load_stft(wav_filepath, noise_reduction='adaptive')
        stft_db = result[0]
        assert stft_db is not None
        assert len(stft_db.shape) == 2

    def test_load_stft_with_custom_params(self, wav_filepath):
        """Test load_stft with custom noise reduction parameters."""
        from batbot.spectrogram import load_stft

        params = {'noise_frames': 20, 'oversubtraction': 1.5}
        result = load_stft(
            wav_filepath,
            noise_reduction='spectral',
            noise_reduction_params=params,
        )
        stft_db = result[0]
        assert stft_db is not None