
Commit 81bd50f

Fix TorchScript-related test failures.
1 parent c1afcb9 commit 81bd50f

3 files changed: +69, -21 lines

src/torchaudio/functional/filtering.py

Lines changed: 33 additions & 17 deletions
@@ -946,7 +946,8 @@ def forward(ctx, waveform, b_coeffs):
         b_coeff_flipped = b_coeffs.flip(1).contiguous()
         padded_waveform = F.pad(waveform, (n_order - 1, 0))
         output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
-        ctx.save_for_backward(waveform, b_coeffs, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, b_coeffs, output)
         return output
 
     @staticmethod
@@ -955,21 +956,28 @@ def backward(ctx, dy):
         n_batch = x.size(0)
         n_channel = x.size(1)
         n_order = b_coeffs.size(1)
-        db = (
-            F.conv1d(
-                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
-                dy.view(n_batch * n_channel, 1, -1),
-                groups=n_batch * n_channel,
-            )
-            .view(n_batch, n_channel, -1)
-            .sum(0)
-            .flip(1)
-            if b_coeffs.requires_grad
-            else None
-        )
-        dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
+
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
         return (dx, db)
 
+    @staticmethod
+    def ts_apply(waveform, b_coeffs):
+        if torch.jit.is_scripting():
+            return DifferentiableFIR.forward(torch.empty(0), waveform, b_coeffs)
+        else:
+            return DifferentiableFIR.apply(waveform, b_coeffs)
+
 
 class DifferentiableIIR(torch.autograd.Function):
     @staticmethod
@@ -984,7 +992,8 @@ def forward(ctx, waveform, a_coeffs_normalized):
         )
         _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
         output = padded_output_waveform[:, :, n_order - 1 :]
-        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        if not torch.jit.is_scripting():
+            ctx.save_for_backward(waveform, a_coeffs_normalized, output)
         return output
 
     @staticmethod
@@ -1006,10 +1015,17 @@ def backward(ctx, dy):
         )
         return (dx, da)
 
+    @staticmethod
+    def ts_apply(waveform, a_coeffs_normalized):
+        if torch.jit.is_scripting():
+            return DifferentiableIIR.forward(torch.empty(0), waveform, a_coeffs_normalized)
+        else:
+            return DifferentiableIIR.apply(waveform, a_coeffs_normalized)
+
 
 def _lfilter(waveform, a_coeffs, b_coeffs):
-    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
-    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
+    filtered_waveform = DifferentiableFIR.ts_apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.ts_apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
 
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
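
Why the ts_apply indirection works: torch.autograd.Function.apply is not scriptable, so any torch.jit.script'ed path that reached DifferentiableFIR.apply or DifferentiableIIR.apply failed to compile. TorchScript special-cases torch.jit.is_scripting() conditionals and compiles only the branch that will actually run, so the scripted path can call forward() directly, passing a dummy tensor where ctx would go. The cost is that the custom backward is bypassed under scripting, which is also why forward now skips ctx.save_for_backward when scripting: there is no real ctx on that path. Below is a minimal sketch of the same dispatch pattern using a hypothetical Scale function, not anything from torchaudio:

import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, k):
        out = x * k
        if not torch.jit.is_scripting():
            # ctx only exists on the eager autograd path.
            ctx.save_for_backward(k)
        return out

    @staticmethod
    def backward(ctx, dy):
        (k,) = ctx.saved_tensors
        # d(x * k)/dx = k; treat k as non-differentiable here.
        return dy * k, None

    @staticmethod
    def ts_apply(x, k):
        if torch.jit.is_scripting():
            # Scripted path: call forward directly; the empty tensor
            # stands in for ctx and is never used.
            return Scale.forward(torch.empty(0), x, k)
        else:
            return Scale.apply(x, k)

def scale(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    return Scale.ts_apply(x, k)

scripted = torch.jit.script(scale)  # compiles; scripting Scale.apply would not
print(scripted(torch.ones(3), torch.tensor(2.0)))  # tensor([2., 2., 2.])

In eager mode the dispatch still goes through apply, so the hand-written backward keeps working; only scripted graphs lose it.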

src/torchaudio/functional/functional.py

Lines changed: 34 additions & 3 deletions
@@ -848,7 +848,8 @@ def mask_along_axis_iid(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -920,7 +921,8 @@ def mask_along_axis(
 
     if axis not in [dim - 2, dim - 1]:
         raise ValueError(
-            f"Only Frequency and Time masking are supported (axis {dim-2} and axis {dim-1} supported; {axis} given)."
+            "Only Frequency and Time masking are supported"
+            f" (axis {dim - 2} and axis {dim - 1} supported; {axis} given)."
         )
 
     if not 0.0 <= p <= 1.0:
@@ -1732,6 +1734,35 @@ def backward(ctx, dy):
         result = grad * grad_out
         return (result, None, None, None, None, None, None, None)
 
+    @staticmethod
+    def ts_apply(
+        logits,
+        targets,
+        logit_lengths,
+        target_lengths,
+        blank: int,
+        clamp: float,
+        fused_log_softmax: bool):
+        if torch.jit.is_scripting():
+            output, saved = torch.ops.torchaudio.rnnt_loss_forward(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+            return output
+        else:
+            return RnntLoss.apply(
+                logits,
+                targets,
+                logit_lengths,
+                target_lengths,
+                blank,
+                clamp,
+                fused_log_softmax)
+
 
 def _rnnt_loss(
     logits: Tensor,
@@ -1775,7 +1806,7 @@ def _rnnt_loss(
     if blank < 0: # reinterpret blank index if blank < 0.
         blank = logits.shape[-1] + blank
 
-    costs = RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
+    costs = RnntLoss.ts_apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
 
     if reduction == "mean":
         return costs.mean()
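
The mask_along_axis/mask_along_axis_iid error messages are rebuilt as a plain string concatenated with an f-string, presumably to keep the interpolated expressions palatable to the TorchScript parser. The RnntLoss change mirrors the filtering one: under scripting, ts_apply invokes the registered torch.ops.torchaudio.rnnt_loss_forward op directly and discards the saved workspace tensor, since the autograd.Function machinery is unavailable there. A smoke test of the resulting behavior; shapes, dtypes, and label values below are illustrative choices, not something the commit mandates:

import torch
import torchaudio.functional as F

def loss_fn(
    logits: torch.Tensor,
    targets: torch.Tensor,
    logit_lengths: torch.Tensor,
    target_lengths: torch.Tensor,
) -> torch.Tensor:
    return F.rnnt_loss(logits, targets, logit_lengths, target_lengths, blank=0)

scripted = torch.jit.script(loss_fn)  # previously failed on RnntLoss.apply

logits = torch.randn(2, 10, 5, 4)  # (batch, time, target_len + 1, classes)
targets = torch.randint(1, 4, (2, 4), dtype=torch.int32)  # labels != blank
logit_lengths = torch.full((2,), 10, dtype=torch.int32)
target_lengths = torch.full((2,), 4, dtype=torch.int32)
print(scripted(logits, targets, logit_lengths, target_lengths))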

src/torchaudio/transforms/_transforms.py

Lines changed: 2 additions & 1 deletion
@@ -1202,7 +1202,8 @@ def forward(self, specgram: Tensor, mask_value: Union[float, torch.Tensor] = 0.0
                 specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p
             )
         else:
-            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis + specgram.dim() - 3, p=self.p)
+            mask_value_ = float(mask_value) if isinstance(mask_value, Tensor) else mask_value
+            return F.mask_along_axis(specgram, self.mask_param, mask_value_, self.axis + specgram.dim() - 3, p=self.p)
 
 
 class FrequencyMasking(_AxisMasking):
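
The transforms fix addresses a typing constraint rather than apply: F.mask_along_axis annotates mask_value as float, so a scripted _AxisMasking.forward handing it a Tensor would fail TorchScript's type check; converting with float() first (which TorchScript supports for 0-dim tensors) keeps the call valid. A toy illustration of the same constraint, with hypothetical names:

import torch

def fill(x: torch.Tensor, value: float) -> torch.Tensor:
    # Stand-in for a callee that, like F.mask_along_axis, annotates float.
    return torch.full_like(x, value)

def apply_mask_value(x: torch.Tensor, mask_value: torch.Tensor) -> torch.Tensor:
    # Passing mask_value directly would be a type error under scripting;
    # float() on a 0-dim tensor satisfies the annotation.
    return fill(x, float(mask_value))

scripted = torch.jit.script(apply_mask_value)
print(scripted(torch.zeros(3), torch.tensor(7.0)))  # tensor([7., 7., 7.])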
