
Commit 15193c1

Port autograd parts of lfilter to python (#3954)
Co-authored-by: Sam Anklesaria <[email protected]>
Co-authored-by: Nicolas Hug <[email protected]>
1 parent b2a6973 commit 15193c1

File tree

4 files changed: +82 -249 lines changed

.github/scripts/unittest-linux/run_test.sh
.github/workflows/unittest-linux-gpu.yml
src/libtorchaudio/lfilter.cpp
src/torchaudio/functional/filtering.py

.github/scripts/unittest-linux/run_test.sh

Lines changed: 1 addition & 1 deletion
@@ -30,5 +30,5 @@ fi
 
 (
     cd test
-    pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt)"
+    pytest torchaudio_unittest -k "not backend and not /io/ and not prototype and not sox and not ffmpeg and not fairseq and not hdemucs and not (torchscript and rnnt) and not torchscript_consistency"
 )

.github/workflows/unittest-linux-gpu.yml

Lines changed: 1 addition & 1 deletion
@@ -117,7 +117,7 @@ jobs:
           '--cov=torchaudio'
           "--junitxml=${RUNNER_TEST_RESULTS_DIR}/junit.xml"
           '--durations' '100'
-          '-k' '(cuda or gpu) and not (torchscript and rnnt)'
+          '-k' '(cuda or gpu) and not (torchscript and rnnt) and not torchscript_consistency'
         )
 
         cd test

src/libtorchaudio/lfilter.cpp

Lines changed: 12 additions & 183 deletions
@@ -100,194 +100,23 @@ void lfilter_core_generic_loop(
   }
 }
 
-class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
- public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::Tensor& waveform,
-      const torch::Tensor& a_coeffs_normalized) {
-    auto device = waveform.device();
-    auto dtype = waveform.dtype();
-    int64_t n_batch = waveform.size(0);
-    int64_t n_channel = waveform.size(1);
-    int64_t n_sample = waveform.size(2);
-    int64_t n_order = a_coeffs_normalized.size(1);
-    int64_t n_sample_padded = n_sample + n_order - 1;
-
-    auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();
-
-    auto options = torch::TensorOptions().dtype(dtype).device(device);
-    auto padded_output_waveform =
-        torch::zeros({n_batch, n_channel, n_sample_padded}, options);
-
-    if (device.is_cpu()) {
-      cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
-    } else if (device.is_cuda()) {
-#ifdef USE_CUDA
-      cuda_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
-#else
-      lfilter_core_generic_loop(
-          waveform, a_coeff_flipped, padded_output_waveform);
-#endif
-    } else {
-      lfilter_core_generic_loop(
-          waveform, a_coeff_flipped, padded_output_waveform);
-    }
-
-    auto output = padded_output_waveform.index(
-        {torch::indexing::Slice(),
-         torch::indexing::Slice(),
-         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
-
-    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
-    return output;
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto x = saved[0];
-    auto a_coeffs_normalized = saved[1];
-    auto y = saved[2];
-
-    int64_t n_channel = x.size(1);
-    int64_t n_order = a_coeffs_normalized.size(1);
-
-    auto dx = torch::Tensor();
-    auto da = torch::Tensor();
-    auto dy = grad_outputs[0];
-
-    namespace F = torch::nn::functional;
-
-    auto tmp =
-        DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
-            .flip(2);
-
-    if (x.requires_grad()) {
-      dx = tmp;
-    }
-
-    if (a_coeffs_normalized.requires_grad()) {
-      da = -torch::matmul(
-                tmp.transpose(0, 1).reshape({n_channel, 1, -1}),
-                F::pad(y, F::PadFuncOptions({n_order - 1, 0}))
-                    .unfold(2, n_order, 1)
-                    .transpose(0, 1)
-                    .reshape({n_channel, -1, n_order}))
-                .squeeze(1)
-                .flip(1);
-    }
-    return {dx, da};
-  }
-};
-
-class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
- public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
-      const torch::Tensor& waveform,
-      const torch::Tensor& b_coeffs) {
-    int64_t n_order = b_coeffs.size(1);
-    int64_t n_channel = b_coeffs.size(0);
-
-    namespace F = torch::nn::functional;
-    auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
-    auto padded_waveform =
-        F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));
-
-    auto output = F::conv1d(
-        padded_waveform,
-        b_coeff_flipped.unsqueeze(1),
-        F::Conv1dFuncOptions().groups(n_channel));
-
-    ctx->save_for_backward({waveform, b_coeffs, output});
-    return output;
-  }
-
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
-    auto x = saved[0];
-    auto b_coeffs = saved[1];
-    auto y = saved[2];
-
-    int64_t n_batch = x.size(0);
-    int64_t n_channel = x.size(1);
-    int64_t n_order = b_coeffs.size(1);
-
-    auto dx = torch::Tensor();
-    auto db = torch::Tensor();
-    auto dy = grad_outputs[0];
-
-    namespace F = torch::nn::functional;
-
-    if (b_coeffs.requires_grad()) {
-      db = F::conv1d(
-               F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
-                   .view({1, n_batch * n_channel, -1}),
-               dy.view({n_batch * n_channel, 1, -1}),
-               F::Conv1dFuncOptions().groups(n_batch * n_channel))
-               .view({n_batch, n_channel, -1})
-               .sum(0)
-               .flip(1);
-    }
-
-    if (x.requires_grad()) {
-      dx = F::conv1d(
-          F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
-          b_coeffs.unsqueeze(1),
-          F::Conv1dFuncOptions().groups(n_channel));
-    }
-
-    return {dx, db};
-  }
-};
-
-torch::Tensor lfilter_core(
-    const torch::Tensor& waveform,
-    const torch::Tensor& a_coeffs,
-    const torch::Tensor& b_coeffs) {
-  TORCH_CHECK(waveform.device() == a_coeffs.device());
-  TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
-  TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());
-
-  TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
-  TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
-  TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));
-
-  int64_t n_order = b_coeffs.size(1);
-
-  TORCH_INTERNAL_ASSERT(n_order > 0);
-
-  auto filtered_waveform = DifferentiableFIR::apply(
-      waveform,
-      b_coeffs /
-          a_coeffs.index(
-              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
+} // namespace
 
-  auto output = DifferentiableIIR::apply(
-      filtered_waveform,
-      a_coeffs /
-          a_coeffs.index(
-              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
-  return output;
+TORCH_LIBRARY(torchaudio, m) {
+  m.def(
+      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
 }
 
-} // namespace
-
-// Note: We want to avoid using "catch-all" kernel.
-// The following registration should be replaced with CPU specific registration.
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
+TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
 }
 
-TORCH_LIBRARY(torchaudio, m) {
-  m.def(
-      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
+#ifdef USE_CUDA
+TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
 }
+#endif
 
-TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
 }

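After this change, the C++ extension only registers the low-level IIR recursion loop, torchaudio::_lfilter_core_loop, with backend-specific kernels (CPU, CUDA when built with USE_CUDA, and a generic fallback); the autograd wrapping and the _lfilter composition move to Python below. A minimal sketch of how the op is reached from Python, assuming the extension is built; the shapes and coefficient values are made up for illustration:

import torch

# Hypothetical shapes/values for illustration only.
n_batch, n_channel, n_sample, n_order = 2, 1, 16, 3
waveform = torch.randn(n_batch, n_channel, n_sample)
a_coeffs_normalized = torch.tensor([[1.0, -0.5, 0.25]])  # (n_channel, n_order), a[0] normalized to 1
a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()

# Schema registered above: (Tensor input_signal_windows, Tensor a_coeff_flipped,
# Tensor(a!) padded_output_waveform) -> (); the op fills padded_output_waveform in place.
padded_output_waveform = torch.zeros(n_batch, n_channel, n_sample + n_order - 1)
torch.ops.torchaudio._lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
output = padded_output_waveform[:, :, n_order - 1:]  # drop the left padding, as DifferentiableIIR.forward does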
src/torchaudio/functional/filtering.py

Lines changed: 68 additions & 64 deletions
@@ -4,6 +4,7 @@
 
 import torch
 from torch import Tensor
+import torch.nn.functional as F
 
 from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
 
@@ -932,70 +933,74 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T
 
 
 if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
 else:
-    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
-
-
-def _lfilter_core(
-    waveform: Tensor,
-    a_coeffs: Tensor,
-    b_coeffs: Tensor,
-) -> Tensor:
-
-    if a_coeffs.size() != b_coeffs.size():
-        raise ValueError(
-            "Expected coeffs to be the same size."
-            f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
-        )
-    if waveform.ndim != 3:
-        raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
-    if not (waveform.device == a_coeffs.device == b_coeffs.device):
-        raise ValueError(
-            "Expected waveform and coeffs to be on the same device."
-            f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
-            f"b_coeffs device: {b_coeffs.device}"
-        )
-
-    n_batch, n_channel, n_sample = waveform.size()
-    n_order = a_coeffs.size(1)
-    if n_order <= 0:
-        raise ValueError(f"Expected n_order to be positive. Found: {n_order}")
-
-    # Pad the input and create output
-
-    padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
-    padded_output_waveform = torch.zeros_like(padded_waveform)
-
-    # Set up the coefficients matrix
-    # Flip coefficients' order
-    a_coeffs_flipped = a_coeffs.flip(1)
-    b_coeffs_flipped = b_coeffs.flip(1)
-
-    # calculate windowed_input_signal in parallel using convolution
-    input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)
-
-    input_signal_windows.div_(a_coeffs[:, :1])
-    a_coeffs_flipped.div_(a_coeffs[:, :1])
-
-    if (
-        input_signal_windows.device == torch.device("cpu")
-        and a_coeffs_flipped.device == torch.device("cpu")
-        and padded_output_waveform.device == torch.device("cpu")
-    ):
-        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-    else:
-        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-
-    output = padded_output_waveform[:, :, n_order - 1 :]
-    return output
-
-
-if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter = torch.ops.torchaudio._lfilter
-else:
-    _lfilter = _lfilter_core
-
+    _lfilter_core_loop = _lfilter_core_generic_loop
+
+
+class DifferentiableFIR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, waveform, b_coeffs):
+        n_order = b_coeffs.size(1)
+        n_channel = b_coeffs.size(0)
+        b_coeff_flipped = b_coeffs.flip(1).contiguous()
+        padded_waveform = F.pad(waveform, (n_order - 1, 0))
+        output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
+        ctx.save_for_backward(waveform, b_coeffs, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, b_coeffs, y = ctx.saved_tensors
+        n_batch = x.size(0)
+        n_channel = x.size(1)
+        n_order = b_coeffs.size(1)
+        db = F.conv1d(
+            F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
+            dy.view(n_batch * n_channel, 1, -1),
+            groups=n_batch * n_channel
+        ).view(
+            n_batch, n_channel, -1
+        ).sum(0).flip(1) if b_coeffs.requires_grad else None
+        dx = F.conv1d(
+            F.pad(dy, (0, n_order - 1)),
+            b_coeffs.unsqueeze(1),
+            groups=n_channel
+        ) if x.requires_grad else None
+        return (dx, db)
+
+
+class DifferentiableIIR(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, waveform, a_coeffs_normalized):
+        n_batch, n_channel, n_sample = waveform.shape
+        n_order = a_coeffs_normalized.size(1)
+        n_sample_padded = n_sample + n_order - 1
+
+        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()
+        padded_output_waveform = torch.zeros(n_batch, n_channel, n_sample_padded,
+                                             device=waveform.device, dtype=waveform.dtype)
+        _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
+        output = padded_output_waveform[:, :, n_order - 1:]
+        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
+        return output
+
+    @staticmethod
+    def backward(ctx, dy):
+        x, a_coeffs_normalized, y = ctx.saved_tensors
+        n_channel = x.size(1)
+        n_order = a_coeffs_normalized.size(1)
+        tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2)
+        dx = tmp if x.requires_grad else None
+        da = -(tmp.transpose(0, 1).reshape(n_channel, 1, -1) @
+               F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0, 1)
+               .reshape(n_channel, -1, n_order)
+               ).squeeze(1).flip(1) if a_coeffs_normalized.requires_grad else None
+        return (dx, da)
+
+
+def _lfilter(waveform, a_coeffs, b_coeffs):
+    n_order = b_coeffs.size(1)
+    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
+    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
 
 def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
     r"""Perform an IIR filter by evaluating difference equation, using differentiable implementation
@@ -1066,7 +1071,6 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =
 
     return output
 
-
 def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
     r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.

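With the autograd logic ported to Python, the backward formulas in DifferentiableFIR and DifferentiableIIR can be exercised directly with torch.autograd.gradcheck. A rough sketch, not part of the commit, assuming a torchaudio build that includes this change; the coefficient values are made up for illustration:

import torch
from torch.autograd import gradcheck
import torchaudio.functional as F_audio

# Double precision and clamp=False keep the comparison against finite differences meaningful.
waveform = torch.randn(1, 1, 32, dtype=torch.float64, requires_grad=True)
a_coeffs = torch.tensor([[1.0, -0.3, 0.1]], dtype=torch.float64, requires_grad=True)
b_coeffs = torch.tensor([[0.4, 0.2, 0.1]], dtype=torch.float64, requires_grad=True)

# Gradients flow to the waveform and to both coefficient tensors through the Python _lfilter path.
gradcheck(lambda x, a, b: F_audio.lfilter(x, a, b, clamp=False), (waveform, a_coeffs, b_coeffs))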