diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 4070141958..41acfeae66 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1578,7 +1578,17 @@ def loudness(waveform: Tensor, sample_rate: int): gated_blocks = loudness > gamma_abs gated_blocks = gated_blocks.unsqueeze(-2) - energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) + # Compute numerator and denominator + sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1) + count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1) + + # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0 + energy_filtered = torch.where( + count_gated_blocks > 0, + sum_gated_energy / count_gated_blocks, + torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device) + ) + energy_weighted = torch.sum(g * energy_filtered, dim=-1) gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10 @@ -1586,7 +1596,17 @@ def loudness(waveform: Tensor, sample_rate: int): gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1)) gated_blocks = gated_blocks.unsqueeze(-2) - energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1) + # Compute numerator and denominator + sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1) + count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1) + + # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0 + energy_filtered = torch.where( + count_gated_blocks > 0, + sum_gated_energy / count_gated_blocks, + torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device) + ) + energy_weighted = torch.sum(g * energy_filtered, dim=-1) LKFS = kweight_bias + 10 * torch.log10(energy_weighted) return LKFS