#include <torch/script.h>

namespace {

// Applies the recursive (feedback) part of an IIR filter, channel by
// channel, writing into a zero-padded output buffer.
template <typename scalar_t>
void host_lfilter_core_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  int64_t n_channel = input_signal_windows.size(0);
  int64_t n_samples_input = input_signal_windows.size(1);
  int64_t n_samples_output = padded_output_waveform.size(1);
  int64_t n_order = a_coeff_flipped.size(0);
  scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
  const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
  const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
  for (int64_t i_channel = 0; i_channel < n_channel; i_channel++) {
    // Row offsets depend only on the channel, so hoist them out of the
    // per-sample loop.
    int64_t offset_input = i_channel * n_samples_input;
    int64_t offset_output = i_channel * n_samples_output;
    for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
      scalar_t a0 = input_data[offset_input + i_sample];
      // Dot product of the flipped feedback coefficients with the window of
      // previously computed outputs. The last slot read is the one written
      // below; provided the caller zero-initializes the padding, that term
      // contributes nothing.
      for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
        a0 -= output_data[offset_output + i_sample + i_coeff] *
            a_coeff_flipped_data[i_coeff];
      }
      output_data[offset_output + i_sample + n_order - 1] = a0;
    }
  }
}
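
// In terms of the filter difference equation, with m = i_sample + n_order - 1
// as the absolute output index, the loop above computes (a sketch of the
// intended math, assuming the caller zero-initializes the padding and
// normalizes the coefficients so that a[0] == 1):
//
//   y[m] = x[i_sample] - sum_{k=1}^{n_order-1} a[k] * y[m - k]
//
// i.e. the standard recursive half of lfilter; the feedforward b-coefficient
// half is applied to the input before this loop runs.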

void cpu_lfilter_core_loop(
    const torch::Tensor& input_signal_windows,
    const torch::Tensor& a_coeff_flipped,
    torch::Tensor& padded_output_waveform) {
  TORCH_CHECK(
      input_signal_windows.device().is_cpu() &&
          a_coeff_flipped.device().is_cpu() &&
          padded_output_waveform.device().is_cpu(),
      "all tensors must reside on CPU");

  TORCH_CHECK(
      input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
          padded_output_waveform.is_contiguous(),
      "all tensors must be contiguous");

  TORCH_CHECK(
      (input_signal_windows.dtype() == torch::kFloat32 ||
       input_signal_windows.dtype() == torch::kFloat64) &&
          (a_coeff_flipped.dtype() == torch::kFloat32 ||
           a_coeff_flipped.dtype() == torch::kFloat64) &&
          (padded_output_waveform.dtype() == torch::kFloat32 ||
           padded_output_waveform.dtype() == torch::kFloat64),
      "all tensors must be float32 or float64");

  TORCH_CHECK(
      input_signal_windows.size(0) == padded_output_waveform.size(0),
      "input and output must have the same number of channels");

  TORCH_CHECK(
      input_signal_windows.size(1) + a_coeff_flipped.size(0) - 1 ==
          padded_output_waveform.size(1),
      "output must be padded to n_samples_input + n_order - 1");

  AT_DISPATCH_FLOATING_TYPES(
      input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
        host_lfilter_core_loop<scalar_t>(
            input_signal_windows, a_coeff_flipped, padded_output_waveform);
      });
}

} // namespace

// Note: we want to avoid using a "catch-all" kernel. The following
// registration should be replaced with a CPU-specific registration.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
}
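
// A CPU-specific registration would separate the schema definition from the
// kernel binding. A minimal sketch (assuming the schema string below matches
// the signature inferred from cpu_lfilter_core_loop):
//
//   TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
//     m.def(
//         "_lfilter_core_loop(Tensor input_signal_windows, "
//         "Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
//   }
//   TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
//     m.impl("_lfilter_core_loop", &cpu_lfilter_core_loop);
//   }
//
// Once the library is loaded, the op can be resolved through the dispatcher
// and invoked with a boxed call; an illustrative example with made-up shapes:
//
//   auto windows = torch::randn({2, 100});        // (n_channel, n_samples_input)
//   auto a_flipped = torch::tensor({0.5f, 1.0f}); // n_order coeffs, a[0] last
//   auto out = torch::zeros({2, 100 + 2 - 1});    // n_samples_input + n_order - 1
//   auto op = c10::Dispatcher::singleton().findSchemaOrThrow(
//       "torchaudio::_lfilter_core_loop", "");
//   std::vector<c10::IValue> stack{windows, a_flipped, out};
//   op.callBoxed(&stack); // `out` now holds the filtered result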