import copy
import math
import torch
import scipy.signal
import torchaudio
import numpy as np
import torch.nn.functional as F
from typing import Optional, List

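# PyTorch implementation of an ITU-R BS.1770 (K-weighted) loudness meter:
# FFT-based FIR filtering on GPU, exact IIR filtering (lfilter) on CPU.
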
# Thin wrappers around the torch.fft API (available since PyTorch 1.8).
def _new_rfft(x: torch.Tensor):
    z = torch.fft.rfft(x, dim=-1)
    return torch.view_as_real(z)

def _new_irfft(x: torch.Tensor, length: int):
    x = torch.view_as_complex(x)
    return torch.fft.irfft(x, length, dim=-1)

def _compl_mul_conjugate(a: torch.Tensor, b: torch.Tensor):
    # Multiply `a` by the complex conjugate of `b` in the frequency domain
    # (a cross-correlation in the time domain).
    return torch.view_as_real(torch.view_as_complex(a) * torch.view_as_complex(b).conj())

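# Split the last dimension of `input` into overlapping frames of length
# `kernel_size` taken every `stride` samples, returned as a strided view
# (no copy); the signal is zero-padded so the last frame is complete.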
def unfold(input, kernel_size: int, stride: int):

    shape = list(input.shape)
    length = shape.pop(-1)

    n_frames = math.ceil((max(length, kernel_size) - kernel_size) / stride) + 1
    tgt_length = (n_frames - 1) * stride + kernel_size

    padded = F.pad(input, (0, tgt_length - length)).contiguous()
    strides: List[int] = []

    for dim in range(padded.dim()):
        strides.append(padded.stride(dim))

    last_stride = strides.pop(-1)
    assert last_stride == 1, 'data should be contiguous'

    strides = strides + [stride, 1]
    return padded.as_strided(shape + [n_frames, kernel_size], strides)

# Convert the signal and the filter to the frequency domain, multiply them, then
# inverse-FFT back to the time domain; faster than a sliding window over the
# time domain for long kernels. Note: the element-wise `_compl_mul_conjugate`
# above does not sum over input channels, so unlike a full conv1d this routine
# is only correct for single-channel inputs and kernels, which is how the
# meter below uses it.
def fft_conv1d(
        input: torch.Tensor, weight: torch.Tensor,
        bias: Optional[torch.Tensor] = None, stride: int = 1, padding: int = 0,
        block_ratio: float = 5):

    input = F.pad(input, (padding, padding))
    batch, _, length = input.shape
    out_channels, _, kernel_size = weight.shape

    _rfft = _new_rfft
    _irfft = _new_irfft

    if length < kernel_size:
        raise RuntimeError(f"Input should be at least as large as the kernel size {kernel_size}, "
                           f"but it is only {length} samples long.")
    if block_ratio < 1:
        raise RuntimeError("Block ratio must be at least 1.")

    # Process the input block by block, as it is faster and less memory
    # intensive (the likely culprit is `torch.einsum`).
    block_size: int = min(int(kernel_size * block_ratio), length)
    fold_stride = block_size - kernel_size + 1
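    # Each frame of `block_size` samples yields `block_size - kernel_size + 1`
    # valid output samples (overlap-save style block convolution).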
    # Zero-pad the kernel up to the block size.
    weight = F.pad(weight, (0, block_size - weight.shape[-1]), mode="constant", value=0)
    weight_z = _rfft(weight)

    # Pad the input and split it into overlapping frames, on which the
    # frequency-domain multiplication is applied.
    frames = unfold(input, block_size, fold_stride)

    frames_z = _rfft(frames)
    out_z = _compl_mul_conjugate(frames_z, weight_z)
    out = _irfft(out_z, block_size)
    # The last `kernel_size - 1` samples of each frame are invalid, because the
    # FFT performs a circular convolution.
    out = out[..., :-kernel_size + 1]
    out = out.reshape(batch, out_channels, -1)
    out = out[..., ::stride]
    target_length = (length - kernel_size) // stride + 1
    out = out[..., :target_length]
    if bias is not None:
        out += bias[:, None]
    return out

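# Second-order IIR (biquad) filter used to build the K-weighting curve:
# a high-shelf boost around 1.5 kHz and a high-pass around 38 Hz (see Meter).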
class IIRfilter(object):

    def __init__(self, G, Q, fc, rate, filter_type, passband_gain=1.0):
        self.G = G
        self.Q = Q
        self.fc = fc
        self.rate = rate
        self.filter_type = filter_type
        self.passband_gain = passband_gain

    def generate_coefficients(self):
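        # Biquad coefficients following the RBJ Audio-EQ Cookbook formulas,
        # normalized by a0 before returning (b, a).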
        A = 10 ** (self.G / 40.0)
        w0 = 2.0 * np.pi * (self.fc / self.rate)
        alpha = np.sin(w0) / (2.0 * self.Q)

        if self.filter_type == 'high_shelf':
            b0 = A * ((A + 1) + (A - 1) * np.cos(w0) + 2 * np.sqrt(A) * alpha)
            b1 = -2 * A * ((A - 1) + (A + 1) * np.cos(w0))
            b2 = A * ((A + 1) + (A - 1) * np.cos(w0) - 2 * np.sqrt(A) * alpha)
            a0 = (A + 1) - (A - 1) * np.cos(w0) + 2 * np.sqrt(A) * alpha
            a1 = 2 * ((A - 1) - (A + 1) * np.cos(w0))
            a2 = (A + 1) - (A - 1) * np.cos(w0) - 2 * np.sqrt(A) * alpha

        elif self.filter_type == 'high_pass':
            b0 = (1 + np.cos(w0)) / 2
            b1 = -(1 + np.cos(w0))
            b2 = (1 + np.cos(w0)) / 2
            a0 = 1 + alpha
            a1 = -2 * np.cos(w0)
            a2 = 1 - alpha

        else:
            raise ValueError(f"Unknown filter type: {self.filter_type}")

        return np.array([b0, b1, b2]) / a0, np.array([a0, a1, a2]) / a0

    def apply_filter(self, data):
        b, a = self.b_and_a
        return self.passband_gain * scipy.signal.lfilter(b, a, data)

    @property
    def b_and_a(self):
        return self.generate_coefficients()

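# ITU-R BS.1770 loudness meter: applies the K-weighting filter chain and gated
# mean-square averaging over 400 ms blocks to compute integrated loudness (LUFS).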
class Meter(torch.nn.Module):

    def __init__(
        self,
        rate: int,
        filter_class: str = "K-weighting",
        block_size: float = 0.400,
        zeros: int = 512,
        use_fir: bool = False,
    ):
        super().__init__()

        self.rate = rate
        self.filter_class = filter_class
        self.block_size = block_size
        self.use_fir = use_fir

        # Per-channel gains G_i from BS.1770 (L, R, C, Ls, Rs).
        G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
        self.register_buffer("G", G)

        self._filters = {}
        self._filters['high_shelf'] = IIRfilter(4.0, 1 / np.sqrt(2), 1500.0, self.rate, 'high_shelf')
        self._filters['high_pass'] = IIRfilter(0.0, 0.5, 38.0, self.rate, 'high_pass')

        # Precompute truncated impulse responses of the IIR filters so that, at
        # runtime, filtering can be done with a (GPU-friendly) convolution
        # instead of lfilter.
        impulse = np.zeros((zeros,))
        impulse[..., 0] = 1.0

        firs = np.zeros((len(self._filters), 1, zeros))
        passband_gain = torch.tensor([f.passband_gain for f in self._filters.values()])

        for i, (_, filter_stage) in enumerate(self._filters.items()):
            b, a = filter_stage.b_and_a
            firs[i] = scipy.signal.lfilter(b, a, impulse)

        # Flip the FIRs, since fft_conv1d computes a correlation
        # (multiplication by the conjugate spectrum).
        firs = torch.from_numpy(firs[..., ::-1].copy()).float()

        self.register_buffer("firs", firs)
        self.register_buffer("passband_gain", passband_gain)

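    # FIR approximation of the K-weighting filters, applied with FFT
    # convolution (fast on GPU).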
    def apply_filter_gpu(self, data: torch.Tensor):

        # Data is of shape (nb, nt, nch); reshape to (nb * nch, 1, nt)
        # so each channel is filtered independently.
        nb, nt, nch = data.shape
        data = data.permute(0, 2, 1)
        data = data.reshape(nb * nch, 1, nt)

        # Pad by the FIR length on both sides.
        pad_length = self.firs.shape[-1]

        # Apply the filter stages in sequence.
        for i in range(self.firs.shape[0]):
            data = F.pad(data, (pad_length, pad_length))
            data = fft_conv1d(data, self.firs[i, None, ...])
            data = self.passband_gain[i] * data
            data = data[..., 1 : nt + 1]

        # Reshape back to (nb, nt, nch).
        data = data.reshape(nb, nch, nt)
        data = data.permute(0, 2, 1)
        data = data[:, :nt, :]
        return data

    def apply_filter_cpu(self, data: torch.Tensor):
        # Exact IIR filtering with torchaudio's lfilter (CPU path).
        for _, filter_stage in self._filters.items():
            passband_gain = filter_stage.passband_gain
            b, a = filter_stage.b_and_a

            a_coeffs = torch.from_numpy(a).float().to(data.device)
            b_coeffs = torch.from_numpy(b).float().to(data.device)

            _data = data.permute(0, 2, 1)
            filtered = torchaudio.functional.lfilter(
                _data, a_coeffs, b_coeffs, clamp=False
            )
            data = passband_gain * filtered.permute(0, 2, 1)
        return data

    def apply_filter(self, data: torch.Tensor):
        if data.is_cuda or self.use_fir:
            data = self.apply_filter_gpu(data)
        else:
            data = self.apply_filter_cpu(data)
        return data

    def forward(self, data: torch.Tensor):
        return self.integrated_loudness(data)

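    # Split the filtered signal into overlapping gating blocks
    # (T_g = 400 ms with 75% overlap, per BS.1770).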
    def _unfold(self, input_data):
        T_g = self.block_size
        overlap = 0.75  # overlap of 75% of the block duration
        step = 1.0 - overlap  # step size as a fraction of the block

        kernel_size = int(T_g * self.rate)
        stride = int(T_g * self.rate * step)
        unfolded = unfold(input_data.permute(0, 2, 1), kernel_size, stride)
        unfolded = unfolded.transpose(-1, -2)

        return unfolded

    def integrated_loudness(self, data: torch.Tensor):

        if not torch.is_tensor(data):
            data = torch.from_numpy(data).float()
        else:
            data = data.float()

        input_data = copy.copy(data)
        # Data always has a batch and a channel dimension,
        # i.e. it is of shape (nb, nt, nch).
        if input_data.ndim < 2:
            input_data = input_data.unsqueeze(-1)
        if input_data.ndim < 3:
            input_data = input_data.unsqueeze(0)

        nb, _, nch = input_data.shape

        # Apply frequency weighting filters - accounts for the acoustic
        # response of the head and auditory system.
        input_data = self.apply_filter(input_data)

        G = self.G  # channel gains
        T_g = self.block_size  # 400 ms gating block (BS.1770 standard)
        Gamma_a = -70.0  # -70 LKFS absolute loudness threshold

        unfolded = self._unfold(input_data)

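        # Mean-square power in each gating block (eq. 1) and the loudness of
        # each block (eq. 4).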
        z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
        l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
        l = l.expand_as(z)

        # Find gating blocks above the absolute threshold.
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        masked = l > Gamma_a
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # Calculate the relative threshold value (see eq. 6).
        Gamma_r = (
            -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
        )
        Gamma_r = Gamma_r[:, None, None]
        Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])

        # Find gating blocks above both the relative and absolute thresholds (end of eq. 7).
        z_avg_gated = z
        z_avg_gated[l <= Gamma_a] = 0
        z_avg_gated[l <= Gamma_r] = 0
        masked = (l > Gamma_a) * (l > Gamma_r)
        z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)

        # Cannot use nan_to_num (PyTorch 1.8 does not ship with a GCP-supported
        # CUDA version), so replace NaN/inf manually.
        # z_avg_gated = torch.nan_to_num(z_avg_gated)
        z_avg_gated = torch.where(
            z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
        )
        z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
        z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)

        LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
        return LUFS.float()

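# Measure the integrated loudness of `audio_data` (shape (nb, nch, nt)) with the
# BS.1770 meter above and apply the gain needed to reach `target_loudness` LUFS.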
def loudness(
    audio_data, sample_rate: int, target_loudness: float, filter_class: str = "K-weighting",
    block_size: float = 0.400, **kwargs
):
    MIN_LOUDNESS = -70
    device = audio_data.device

    original_length = audio_data.shape[-1]
    signal_duration = original_length / sample_rate

    # Pad signals shorter than 0.5 s so the meter has enough audio to gate.
    if signal_duration < 0.5:
        pad_len = int((0.5 - signal_duration) * sample_rate)
        audio_data = torch.nn.functional.pad(audio_data, (0, pad_len), mode="constant", value=0)

    # Create the BS.1770 meter.
    meter = Meter(
        sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
    )
    meter = meter.to(audio_data.device)

    # Measure loudness; the meter expects (nb, nt, nch).
    loudness = meter.integrated_loudness(audio_data.permute(0, 2, 1))
    audio_data = audio_data[..., :original_length]
    min_loudness = (
        torch.ones_like(loudness, device=loudness.device) * MIN_LOUDNESS
    )
    _loudness = torch.maximum(loudness, min_loudness)

    _loudness = _loudness.to(device)

    delta_loudness = target_loudness - _loudness
    gain = torch.pow(torch.tensor(10.0, device=device, dtype=audio_data.dtype), delta_loudness / 20.0)

    # Broadcast the per-item gain over channels and samples.
    output = gain[:, None, None] * audio_data

    if torch.max(torch.abs(output)) >= 1.0:
        import warnings
        warnings.warn("Possible clipped samples in output.")

    return output
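
if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the shapes and sample rate below
    # are assumptions, not part of the original module): a 1 s stereo noise
    # burst, measured and then normalized to -24 LUFS.
    audio = 0.1 * torch.randn(1, 2, 44100)
    meter = Meter(44100)
    print("input loudness (LUFS):", meter.integrated_loudness(audio.permute(0, 2, 1)))
    normalized = loudness(audio, 44100, target_loudness=-24)
    print("output loudness (LUFS):", meter.integrated_loudness(normalized.permute(0, 2, 1)))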