Skip to content

Commit 8115d76

Browse files
committed
Rebase against main
1 parent 81bd50f commit 8115d76

File tree

2 files changed: +17 lines added, -35 lines removed

src/torchaudio/functional/filtering.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -957,18 +957,19 @@ def backward(ctx, dy):
957957
n_channel = x.size(1)
958958
n_order = b_coeffs.size(1)
959959

960-
db = F.conv1d(
961-
F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
962-
dy.view(n_batch * n_channel, 1, -1),
963-
groups=n_batch * n_channel
964-
).view(
965-
n_batch, n_channel, -1
966-
).sum(0).flip(1) if b_coeffs.requires_grad else None
967-
dx = F.conv1d(
968-
F.pad(dy, (0, n_order - 1)),
969-
b_coeffs.unsqueeze(1),
970-
groups=n_channel
971-
) if x.requires_grad else None
960+
db = (
961+
F.conv1d(
962+
F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
963+
dy.view(n_batch * n_channel, 1, -1),
964+
groups=n_batch * n_channel,
965+
)
966+
.view(n_batch, n_channel, -1)
967+
.sum(0)
968+
.flip(1)
969+
if b_coeffs.requires_grad
970+
else None
971+
)
972+
dx = F.conv1d(F.pad(dy, (0, n_order - 1)), b_coeffs.unsqueeze(1), groups=n_channel) if x.requires_grad else None
972973
return (dx, db)
973974

974975
@staticmethod

src/torchaudio/functional/functional.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,33 +1735,14 @@ def backward(ctx, dy):
17351735
return (result, None, None, None, None, None, None, None)
17361736

17371737
@staticmethod
1738-
def ts_apply(
1739-
logits,
1740-
targets,
1741-
logit_lengths,
1742-
target_lengths,
1743-
blank: int,
1744-
clamp: float,
1745-
fused_log_softmax: bool):
1738+
def ts_apply(logits, targets, logit_lengths, target_lengths, blank: int, clamp: float, fused_log_softmax: bool):
17461739
if torch.jit.is_scripting():
17471740
output, saved = torch.ops.torchaudio.rnnt_loss_forward(
1748-
logits,
1749-
targets,
1750-
logit_lengths,
1751-
target_lengths,
1752-
blank,
1753-
clamp,
1754-
fused_log_softmax)
1741+
logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax
1742+
)
17551743
return output
17561744
else:
1757-
return RnntLoss.apply(
1758-
logits,
1759-
targets,
1760-
logit_lengths,
1761-
target_lengths,
1762-
blank,
1763-
clamp,
1764-
fused_log_softmax)
1745+
return RnntLoss.apply(logits, targets, logit_lengths, target_lengths, blank, clamp, fused_log_softmax)
17651746

17661747

17671748
def _rnnt_loss(

0 commit comments

Comments (0)