From 35c7f31eaa3a0ee63e0261e4a46db52985483ef5 Mon Sep 17 00:00:00 2001
From: ChenMengqi <160377674+Az-CMQ@users.noreply.github.com>
Date: Sun, 14 Sep 2025 12:38:29 +0800
Subject: [PATCH] [Fix] loudness: prevent NaN when all blocks are below absolute threshold

---
 src/torchaudio/functional/functional.py | 24 ++++++++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py
index 4070141958..41acfeae66 100644
--- a/src/torchaudio/functional/functional.py
+++ b/src/torchaudio/functional/functional.py
@@ -1578,7 +1578,17 @@ def loudness(waveform: Tensor, sample_rate: int):
     gated_blocks = loudness > gamma_abs
     gated_blocks = gated_blocks.unsqueeze(-2)
 
-    energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
+    # Compute numerator and denominator
+    sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1)
+    count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1)
+
+    # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0
+    energy_filtered = torch.where(
+        count_gated_blocks > 0,
+        sum_gated_energy / count_gated_blocks,
+        torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device)
+    )
+
     energy_weighted = torch.sum(g * energy_filtered, dim=-1)
     gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10
 
@@ -1586,7 +1596,17 @@ def loudness(waveform: Tensor, sample_rate: int):
     gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
     gated_blocks = gated_blocks.unsqueeze(-2)
 
-    energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
+    # Compute numerator and denominator
+    sum_gated_energy = torch.sum(gated_blocks * energy, dim=-1)
+    count_gated_blocks = torch.count_nonzero(gated_blocks, dim=-1)
+
+    # Use torch.where to avoid division by zero: if count is 0, set energy_filtered to 0
+    energy_filtered = torch.where(
+        count_gated_blocks > 0,
+        sum_gated_energy / count_gated_blocks,
+        torch.tensor(0.0, dtype=sum_gated_energy.dtype, device=sum_gated_energy.device)
+    )
+
     energy_weighted = torch.sum(g * energy_filtered, dim=-1)
     LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
     return LKFS