Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/unittest-linux/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ fi
export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_inflect=true
export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true
cd test
pytest torchaudio_unittest -k "not torchscript and not fairseq and not demucs ${PYTEST_K_EXTRA}"
pytest torchaudio_unittest -k "not fairseq and not demucs ${PYTEST_K_EXTRA}"
)
2 changes: 1 addition & 1 deletion .github/scripts/unittest-windows/run_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ python -m torch.utils.collect_env
env | grep TORCHAUDIO || true

cd test
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not torchscript and not fairseq and not demucs and not librosa"
pytest --continue-on-collection-errors --cov=torchaudio --junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml -v --durations 20 torchaudio_unittest -k "not fairseq and not demucs and not librosa"
coverage html
25 changes: 21 additions & 4 deletions src/torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
b_coeff_flipped = b_coeffs.flip(1).contiguous()
padded_waveform = F.pad(waveform, (n_order - 1, 0))
output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
ctx.save_for_backward(waveform, b_coeffs, output)
if not torch.jit.is_scripting():
ctx.save_for_backward(waveform, b_coeffs, output)
return output

@staticmethod
Expand All @@ -955,6 +956,7 @@ def backward(ctx, dy):
n_batch = x.size(0)
n_channel = x.size(1)
n_order = b_coeffs.size(1)

db = (
F.conv1d(
F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
Expand All @@ -970,6 +972,13 @@ def backward(ctx, dy):
dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
return (dx, db)

@staticmethod
def ts_apply(waveform, b_coeffs):
    """TorchScript-compatible dispatch for the differentiable FIR filter.

    ``torch.autograd.Function.apply`` cannot be called from scripted code, so
    when running under TorchScript this calls ``forward`` directly with a dummy
    ``ctx`` (``torch.empty(0)``); gradient bookkeeping is skipped in that mode.
    In eager mode it goes through ``apply`` so autograd works as usual.

    Args:
        waveform: input signal passed through to the FIR stage.
        b_coeffs: feed-forward filter coefficients.

    Returns:
        The FIR-filtered waveform produced by ``DifferentiableFIR.forward``.
    """
    if torch.jit.is_scripting():
        # NOTE(review): forward() is expected to skip ctx.save_for_backward
        # when scripting, so passing an empty tensor as ctx is safe here.
        return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
    else:
        return DifferentiableFIR.apply(waveform, b_coeffs)


class DifferentiableIIR(torch.autograd.Function):
@staticmethod
Expand All @@ -984,7 +993,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
)
_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
output = padded_output_waveform[:, :, n_order - 1 :]
ctx.save_for_backward(waveform, a_coeffs_normalized, output)
if not torch.jit.is_scripting():
ctx.save_for_backward(waveform, a_coeffs_normalized, output)
return output

@staticmethod
Expand All @@ -1006,10 +1016,17 @@ def backward(ctx, dy):
)
return (dx, da)

@staticmethod
def ts_apply(waveform, a_coeffs_normalized):
    """TorchScript-compatible dispatch for the differentiable IIR filter.

    Custom ``autograd.Function.apply`` calls are not scriptable, so under
    ``torch.jit.is_scripting()`` this invokes ``forward`` directly with a
    dummy ``ctx`` (``torch.empty(0)``), bypassing autograd state. In eager
    mode it uses ``apply`` so gradients flow normally.

    Args:
        waveform: input signal (typically the FIR-stage output).
        a_coeffs_normalized: feedback coefficients, assumed normalized by a0.

    Returns:
        The IIR-filtered waveform produced by ``DifferentiableIIR.forward``.
    """
    if torch.jit.is_scripting():
        # NOTE(review): forward() is expected to skip ctx.save_for_backward
        # when scripting, so the empty-tensor ctx placeholder is never used.
        return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
    else:
        return DifferentiableIIR.apply(waveform, a_coeffs_normalized)


def _lfilter(waveform, a_coeffs, b_coeffs):
    """Apply an lfilter as an FIR stage followed by an IIR stage.

    Both coefficient sets are normalized by ``a_coeffs[:, 0:1]`` (a0) so the
    IIR recursion can assume a0 == 1. Dispatch goes through ``ts_apply`` so
    this path works both in eager mode (with autograd) and under TorchScript.

    Args:
        waveform: input signal to filter.
        a_coeffs: denominator (feedback) coefficients, a0 in column 0.
        b_coeffs: numerator (feed-forward) coefficients.

    Returns:
        The filtered waveform.
    """
    # Fix: the pasted diff retained the superseded `.apply` lines above the
    # `.ts_apply` ones, leaving unreachable duplicate code after the first
    # return; only the TorchScript-compatible implementation is kept.
    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])


def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
Expand Down
18 changes: 15 additions & 3 deletions src/torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ def mask_along_axis_iid(

if axis not in [dim - 2, dim - 1]:
raise ValueError(
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
"Only Frequency and Time masking are supported"
f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
)

if not 0.0 <= p <= 1.0:
Expand Down Expand Up @@ -920,7 +921,8 @@ def mask_along_axis(

if axis not in [dim - 2, dim - 1]:
raise ValueError(
f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
"Only Frequency and Time masking are supported"
f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
)

if not 0.0 <= p <= 1.0:
Expand Down Expand Up @@ -1732,6 +1734,16 @@ def backward(ctx, dy):
result = grad * grad_out
return (result, None, None, None, None, None, None, None)

@staticmethod
def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool):
    """TorchScript-compatible dispatch for the RNN-T loss.

    Under ``torch.jit.is_scripting()`` the custom autograd ``apply`` cannot be
    used, so the C++ op ``torch.ops.torchaudio.rnnt_loss_forward`` is called
    directly; its second return value (saved state for backward) is discarded
    since no gradient is tracked in scripted mode. In eager mode dispatch goes
    through ``RnntLoss.apply`` so autograd works normally.

    Args:
        logits: network output passed to the loss kernel.
        targets: target label sequences.
        logit_lengths: valid lengths of each logit sequence.
        target_lengths: valid lengths of each target sequence.
        blank: index of the blank label.
        clamp: gradient clamping value forwarded to the kernel.
        fused_log_softmax: whether the kernel applies log-softmax internally.

    Returns:
        The per-sequence RNN-T loss values.
    """
    if torch.jit.is_scripting():
        output, saved = torch.ops.torchaudio.rnnt_loss_forward(
            logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
        )
        return output
    else:
        return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)


def _rnnt_loss(
logits: Tensor,
Expand Down Expand Up @@ -1775,7 +1787,7 @@ def _rnnt_loss(
if blank < 0: # reinterpret blank index if blank < 0.
blank = logits.shape[-1] + blank

costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)

if reduction == "mean":
return costs.mean()
Expand Down
3 changes: 2 additions & 1 deletion src/torchaudio/transforms/_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
)
else:
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)


class FrequencyMasking(_AxisMasking):
Expand Down
Loading