
Commit 05bff83

Add C++ lfilter core loop for CPU (#1244)
1 parent c3cb201 commit 05bff83

File tree

3 files changed: +96 -6 lines changed


torchaudio/csrc/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ set(
   sox/effects.cpp
   sox/effects_chain.cpp
   sox/types.cpp
+  lfilter.cpp
 )

 if(BUILD_TRANSDUCER)
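
Adding lfilter.cpp to this source list is what compiles the new kernel into the torchaudio extension; if it is left out of the build, the operator is never registered and the Python side falls back to the generic loop. A quick way to check which path a given build will take (a sketch mirroring the lookup done in filtering.py below):

import torch
import torchaudio  # importing torchaudio loads the compiled extension and registers its ops

try:
    # Succeeds only if lfilter.cpp was compiled in and the op was registered.
    op = torch.ops.torchaudio._lfilter_core_loop
    print("C++ CPU core loop available")
except RuntimeError:
    print("falling back to the pure-Python generic loop")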

torchaudio/csrc/lfilter.cpp

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+#include <torch/script.h>
+
+namespace {
+
+template <typename scalar_t>
+void host_lfilter_core_loop(
+    const torch::Tensor& input_signal_windows,
+    const torch::Tensor& a_coeff_flipped,
+    torch::Tensor& padded_output_waveform) {
+  int64_t n_channel = input_signal_windows.size(0);
+  int64_t n_samples_input = input_signal_windows.size(1);
+  int64_t n_samples_output = padded_output_waveform.size(1);
+  int64_t n_order = a_coeff_flipped.size(0);
+  scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
+  const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
+  const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
+  for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
+    for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
+      int64_t offset_input = i_channel * n_samples_input;
+      int64_t offset_output = i_channel * n_samples_output;
+      scalar_t a0 = input_data[offset_input + i_sample];
+      for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
+        a0 -= output_data[offset_output + i_sample + i_coeff] *
+            a_coeff_flipped_data[i_coeff];
+      }
+      output_data[offset_output + i_sample + n_order - 1] = a0;
+    }
+  }
+}
+
+void cpu_lfilter_core_loop(
+    const torch::Tensor& input_signal_windows,
+    const torch::Tensor& a_coeff_flipped,
+    torch::Tensor& padded_output_waveform) {
+  TORCH_CHECK(
+      input_signal_windows.device().is_cpu() &&
+      a_coeff_flipped.device().is_cpu() &&
+      padded_output_waveform.device().is_cpu());
+
+  TORCH_CHECK(
+      input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
+      padded_output_waveform.is_contiguous());
+
+  TORCH_CHECK(
+      (input_signal_windows.dtype() == torch::kFloat32 ||
+       input_signal_windows.dtype() == torch::kFloat64) &&
+      (a_coeff_flipped.dtype() == torch::kFloat32 ||
+       a_coeff_flipped.dtype() == torch::kFloat64) &&
+      (padded_output_waveform.dtype() == torch::kFloat32 ||
+       padded_output_waveform.dtype() == torch::kFloat64));
+
+  TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
+
+  TORCH_CHECK(
+      input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
+      padded_output_waveform.size(1));
+
+  AT_DISPATCH_FLOATING_TYPES(
+      input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
+        host_lfilter_core_loop<scalar_t>(
+            input_signal_windows, a_coeff_flipped, padded_output_waveform);
+      });
+}
+
+} // namespace
+
+// Note: We want to avoid using "catch-all" kernel.
+// The following registration should be replaced with CPU specific registration.
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
+}
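
Once compiled, the TORCH_LIBRARY_FRAGMENT registration at the bottom exposes this kernel to Python as torch.ops.torchaudio._lfilter_core_loop. A minimal sketch of calling it directly (values are illustrative; the shapes follow the TORCH_CHECK contract above, i.e. contiguous CPU float tensors with output length = input length + n_order - 1):

import torch
import torchaudio  # loads the extension that registers torchaudio::_lfilter_core_loop

n_channel, n_samples, n_order = 2, 8, 3
input_windows = torch.randn(n_channel, n_samples)              # feedforward-filtered input, already divided by a[0]
a_flipped = torch.tensor([0.2, -0.5, 1.0])                     # a_coeffs reversed and normalized, so the last entry is 1
padded_out = torch.zeros(n_channel, n_samples + n_order - 1)   # zero padding supplies the initial conditions

# Runs the feedback recursion in place: for each sample i,
#   padded_out[:, i + n_order - 1] = input_windows[:, i] - padded_out[:, i:i + n_order] @ a_flipped
torch.ops.torchaudio._lfilter_core_loop(input_windows, a_flipped, padded_out)
output = padded_out[:, n_order - 1:]

The dispatch to this op, and the fallback used when the extension is not available, are wired up in torchaudio/functional/filtering.py below.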

torchaudio/functional/filtering.py

Lines changed: 24 additions & 6 deletions
@@ -808,6 +808,23 @@ def highpass_biquad(
     return biquad(waveform, b0, b1, b2, a0, a1, a2)


+def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor):
+    n_order = a_coeffs_flipped.size(0)
+    for i_sample, o0 in enumerate(input_signal_windows.t()):
+        windowed_output_signal = padded_output_waveform[
+            :, i_sample:i_sample + n_order
+        ]
+        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
+        padded_output_waveform[:, i_sample + n_order - 1] = o0
+
+
+try:
+    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+except RuntimeError as err:
+    assert str(err) == 'No such operator torchaudio::_lfilter_core_loop'
+    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
+
+
 def lfilter(
     waveform: Tensor,
     a_coeffs: Tensor,
@@ -877,12 +894,13 @@ def lfilter(

     input_signal_windows.div_(a_coeffs[0])
     a_coeffs_flipped.div_(a_coeffs[0])
-    for i_sample, o0 in enumerate(input_signal_windows.t()):
-        windowed_output_signal = padded_output_waveform[
-            :, i_sample:i_sample + n_order
-        ]
-        o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
-        padded_output_waveform[:, i_sample + n_order - 1] = o0
+
+    if input_signal_windows.device == torch.device('cpu') and\
+            a_coeffs_flipped.device == torch.device('cpu') and\
+            padded_output_waveform.device == torch.device('cpu'):
+        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
+    else:
+        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)

     output = padded_output_waveform[:, n_order - 1:]
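
With this dispatch in place, callers of lfilter do not need to change anything: CPU inputs now hit the C++ core loop, while tensors on other devices (or builds without the compiled extension) keep using _lfilter_core_generic_loop. A minimal usage sketch (coefficient values are illustrative):

import torch
import torchaudio.functional as F

waveform = torch.rand(1, 16000) * 2 - 1     # (n_channel, n_samples), values in [-1, 1]
a_coeffs = torch.tensor([1.0, -0.95])       # denominator (feedback) coefficients, a[0] first
b_coeffs = torch.tensor([0.05, 0.0])        # numerator (feedforward) coefficients

# On CPU this now runs torch.ops.torchaudio._lfilter_core_loop under the hood;
# moving the tensors to CUDA would take the pure-Python path instead.
filtered = F.lfilter(waveform, a_coeffs, b_coeffs)
print(filtered.shape)  # torch.Size([1, 16000])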
