diff --git a/.github/scripts/unittest-linux/run_test.sh b/.github/scripts/unittest-linux/run_test.sh index 6cc935b444..5b235c772c 100755 --- a/.github/scripts/unittest-linux/run_test.sh +++ b/.github/scripts/unittest-linux/run_test.sh @@ -34,5 +34,5 @@ fi export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_inflect=true export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true cd test - pytest torchaudio_unittest -k "not torchscript and not fairseq and not demucs ${PYTEST_K_EXTRA}" + pytest torchaudio_unittest -k "not fairseq and not demucs ${PYTEST_K_EXTRA}" ) diff --git a/.github/scripts/unittest-windows/run_test.sh b/.github/scripts/unittest-windows/run_test.sh index 25d8e14196..9f6ffb1375 100644 --- a/.github/scripts/unittest-windows/run_test.sh +++ b/.github/scripts/unittest-windows/run_test.sh @@ -12,5 +12,5 @@ python -m torch.utils.collect_env env | grep TORCHAUDIO || true cd test -pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not torchscript and not fairseq and not demucs and not librosa" +pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa" coverage html diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 1a7aa3e37e..8f18b35de2 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs): b_coeff_flipped = b_coeffs.flip(1).contiguous() padded_waveform = F.pad(waveform, (n_order - 1, 0)) output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel) - ctx.save_for_backward(waveform, b_coeffs, output) + if not torch.jit.is_scripting(): + ctx.save_for_backward(waveform, b_coeffs, output) return output @staticmethod @@ -955,6 +956,7 @@ def backward(ctx, dy): n_batch = x.size(0) n_channel = x.size(1) n_order = b_coeffs.size(1) + db = ( F.conv1d( F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1), @@ -970,6 +972,13 @@ def backward(ctx, dy): dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None return (dx, db) + @staticmethod + def ts_apply(waveform, b_coeffs): + if torch.jit.is_scripting(): + return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs) + else: + return DifferentiableFIR.apply(waveform, b_coeffs) + class DifferentiableIIR(torch.autograd.Function): @staticmethod @@ -984,7 +993,8 @@ def forward(ctx, waveform, a_coeffs_normalized): ) _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform) output = padded_output_waveform[:, :, n_order - 1 :] - ctx.save_for_backward(waveform, a_coeffs_normalized, output) + if not torch.jit.is_scripting(): + ctx.save_for_backward(waveform, a_coeffs_normalized, output) return output @staticmethod @@ -1006,10 +1016,17 @@ def backward(ctx, dy): ) return (dx, da) + @staticmethod + def ts_apply(waveform, a_coeffs_normalized): + if torch.jit.is_scripting(): + return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized) + else: + return DifferentiableIIR.apply(waveform, a_coeffs_normalized) + def _lfilter(waveform, a_coeffs, b_coeffs): - filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1]) - return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1]) + filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1]) + return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1]) def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor: diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 4070141958..884beec1f7 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -848,7 +848,8 @@ def mask_along_axis_iid( if axis not in [dim - 2, dim - 1]: raise ValueError( - f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)." + "Only Frequency and Time masking are supported" + f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)." ) if not 0.0 <= p <= 1.0: @@ -920,7 +921,8 @@ def mask_along_axis( if axis not in [dim - 2, dim - 1]: raise ValueError( - f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)." + "Only Frequency and Time masking are supported" + f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)." ) if not 0.0 <= p <= 1.0: @@ -1732,6 +1734,16 @@ def backward(ctx, dy): result = grad * grad_out return (result, None, None, None, None, None, None, None) + @staticmethod + def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool): + if torch.jit.is_scripting(): + output, saved = torch.ops.torchaudio.rnnt_loss_forward( + logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax + ) + return output + else: + return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax) + def _rnnt_loss( logits: Tensor, @@ -1775,7 +1787,7 @@ def _rnnt_loss( if blank < 0: # reinterpret blank index if blank < 0. blank = logits.shape[-1] + blank - costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax) + costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax) if reduction == "mean": return costs.mean() diff --git a/src/torchaudio/transforms/_transforms.py b/src/torchaudio/transforms/_transforms.py index 08d2dcef11..7eb50da3f8 100644 --- a/src/torchaudio/transforms/_transforms.py +++ b/src/torchaudio/transforms/_transforms.py @@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0 specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p ) else: - return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p) + mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value + return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p) class FrequencyMasking(_AxisMasking):