
Commit 990e8be

Fix torchscript related test failures. Fix flake8.
1 parent 1eba300 commit 990e8be

File tree

4 files changed: +82 -27 lines changed

.github/scripts/unittest-linux/run_test.sh
src/torchaudio/functional/filtering.py
src/torchaudio/functional/functional.py
src/torchaudio/transforms/_transforms.py

.github/scripts/unittest-linux/run_test.sh

Lines changed: 1 addition & 1 deletion
@@ -34,5 +34,5 @@ fi
 export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_inflect=true
 export TORCHAUDIO_TEST_ALLOW_SKIP_IF_NO_MOD_pytorch_lightning=true
 cd test
-pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt) and not torchscript_consistency"
+pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not ffmpeg and not fairseq and not hdemucs"
 )

src/torchaudio/functional/filtering.py

Lines changed: 40 additions & 21 deletions
@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output
 
     @staticmethod
@@ -956,32 +957,41 @@ def backward(ctx, dy):
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
         db = F.conv1d(
-            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-            dy.view(n_batch * n_channel, 1, -1),
-            groups=n_batch * n_channel
-        ).view(
-            n_batch, n_channel, -1
-        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+                dy.view(n_batch * n_channel, 1, -1),
+                groups=n_batch * n_channel
+            ).view(
+                n_batch, n_channel, -1
+            ).sum(0).flip(1) if b_coeffs.requires_grad else None
         dx = F.conv1d(
-            F.pad(dy, (0, n_order - 1)),
-            b_coeffs.unsqueeze(1),
-            groups=n_channel
-        ) if x.requires_grad else None
+                F.pad(dy, (0, n_order - 1)),
+                b_coeffs.unsqueeze(1),
+                groups=n_channel
+            ) if x.requires_grad else None
         return (dx, db)
 
+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+
+
 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
     def forward(ctx, waveform, a_coeffs_normalized):
         n_batch, n_channel, n_sample = waveform.shape
         n_order = a_coeffs_normalized.size(1)
         n_sample_padded = n_sample + n_order - 1
 
-        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
+        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()
         padded_output_waveform = torch.zeros(n_batch, n_channel, n_sample_padded,
-            device=waveform.device, dtype=waveform.dtype)
+                                             device=waveform.device, dtype=waveform.dtype)
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
-        output = padded_output_waveform[:,:,n_order - 1:]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        output = padded_output_waveform[:, :, n_order - 1:]
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output
 
     @staticmethod
@@ -992,15 +1002,23 @@ def backward(ctx, dy):
         tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2)
         dx = tmp if x.requires_grad else None
         da = -(tmp.transpose(0, 1).reshape(n_channel, 1, -1) @
-               F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0,1)
-               .reshape(n_channel, -1, n_order)
-               ).squeeze(1).flip(1) if a_coeffs_normalized.requires_grad else None
+               F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0, 1)
+               .reshape(n_channel, -1, n_order)
+               ).squeeze(1).flip(1) if a_coeffs_normalized.requires_grad else None
         return (dx, da)
 
+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+
+
 def _lfilter(waveform, a_coeffs, b_coeffs):
-    n_order = b_coeffs.size(1)
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
     r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
@@ -1071,6 +1089,7 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
 
     return output
 
+
 def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
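Both ts_apply helpers above implement the same workaround: torch.autograd.Function.apply cannot be called from TorchScript-compiled code, so when torch.jit.is_scripting() is true the helper calls forward() directly with a placeholder ctx tensor, and forward() in turn skips ctx.save_for_backward. A minimal stand-alone sketch of the pattern, using a toy Square function that is not part of torchaudio:

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = x * x
        # Under scripting, ctx is just a placeholder tensor, so don't save into it.
        if not torch.jit.is_scripting():
            ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        return 2.0 * x * dy

    @staticmethod
    def ts_apply(x):
        # Same dispatch as DifferentiableFIR/DifferentiableIIR: scripted code
        # calls forward() directly, eager code keeps the autograd machinery.
        if torch.jit.is_scripting():
            return Square.forward(torch.empty(0), x)
        else:
            return Square.apply(x)


def square(x: torch.Tensor) -> torch.Tensor:
    return Square.ts_apply(x)


x = torch.randn(3, requires_grad=True)
square(x).sum().backward()           # eager: uses the custom backward
scripted = torch.jit.script(square)  # scripting takes only the forward path
print(scripted(torch.randn(3)))

As in the commit, the sketch assumes TorchScript resolves torch.jit.is_scripting() at compile time, so the branch containing Function.apply is never compiled; in the scripted path, gradients come from the ops executed inside forward() rather than from the hand-written backward.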

src/torchaudio/functional/functional.py

Lines changed: 39 additions & 4 deletions
@@ -851,7 +851,8 @@ def mask_along_axis_iid(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -923,7 +924,8 @@ def mask_along_axis(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -1765,6 +1767,7 @@ def _fix_waveform_shape(
     waveform_shift = waveform_shift.view(shape[:-1] + waveform_shift.shape[-1:])
     return waveform_shift
 
+
 class RnntLoss(torch.autograd.Function):
     @staticmethod
     def forward(ctx, *args):
@@ -1776,9 +1779,39 @@ def forward(ctx, *args):
     def backward(ctx, dy):
         grad = ctx.saved_tensors[0]
         grad_out = dy.view((-1, 1, 1, 1))
-        result = grad * grad_out;
+        result = grad * grad_out
         return (result, None, None, None, None, None, None, None)
 
+    @staticmethod
+    def ts_apply(
+            logits,
+            targets,
+            logit_lengths,
+            target_lengths,
+            blank: int,
+            clamp: float,
+            fused_log_softmax: bool):
+        if torch.jit.is_scripting():
+            output, saved = torch.ops.torchaudio.rnnt_loss_forward(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+            return output
+        else:
+            return RnntLoss.apply(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+
+
 def _rnnt_loss(
     logits: Tensor,
     targets: Tensor,
@@ -1821,7 +1854,7 @@ def _rnnt_loss(
     if blank < 0:  # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
-    costs = RnntLoss.apply(
+    costs = RnntLoss.ts_apply(
         logits,
         targets,
         logit_lengths,
@@ -1883,10 +1916,12 @@ def psd(
         psd = psd.sum(dim=-3)
     return psd
 
+
 # Expose both deprecated wrapper as well as original because torchscript breaks on
 # wrapped functions.
 rnnt_loss = dropping_support(_rnnt_loss)
 
+
 def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
     r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
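RnntLoss gets the same kind of shim: under scripting, ts_apply calls torch.ops.torchaudio.rnnt_loss_forward directly and returns only the cost tensor, while eager code keeps going through RnntLoss.apply. All three shims in this commit pivot on torch.jit.is_scripting(), which reports whether the surrounding code is running eagerly or inside a scripted graph; a tiny illustration (where_am_i is a hypothetical function, not torchaudio API):

import torch


def where_am_i() -> str:
    # torch.jit.is_scripting() is False when run eagerly and True inside
    # code compiled by torch.jit.script.
    if torch.jit.is_scripting():
        return "scripted"
    return "eager"


print(where_am_i())                    # eager
print(torch.jit.script(where_am_i)())  # scripted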

src/torchaudio/transforms/_transforms.py

Lines changed: 2 additions & 1 deletion
@@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
                 specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
             )
         else:
-            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
+            mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
+            return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)
 
 
 class FrequencyMasking(_AxisMasking):
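The extra line is needed because TorchScript is statically typed: mask_value is declared as Union[float, torch.Tensor], but F.mask_along_axis takes a plain float, so the Tensor case has to be narrowed explicitly before the call. A stand-alone sketch of the same narrowing, where normalize_mask_value is an illustrative helper rather than torchaudio API:

import torch
from typing import Union


def normalize_mask_value(mask_value: Union[float, torch.Tensor]) -> float:
    # isinstance refines the Union so each branch has a single static type,
    # letting TorchScript type-check both the conversion and the fall-through.
    if isinstance(mask_value, torch.Tensor):
        return float(mask_value)
    return mask_value


scripted = torch.jit.script(normalize_mask_value)
print(scripted(torch.tensor(3.0)))  # 3.0
print(scripted(0.5))                # 0.5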
