diff --git a/src/libtorchaudio/overdrive.cpp b/src/libtorchaudio/overdrive.cpp index 4954271e41..387e4a7f78 100644 --- a/src/libtorchaudio/overdrive.cpp +++ b/src/libtorchaudio/overdrive.cpp @@ -1,52 +1,98 @@ #include #include +#include +#include +#include +#include +#include +#include -namespace { +using namespace std; + +namespace torchaudio { + +using torch::stable::Tensor; template void overdrive_cpu_kernel( - at::TensorAccessor waveform_accessor, - at::TensorAccessor temp_accessor, - at::TensorAccessor last_in_accessor, - at::TensorAccessor last_out_accessor, - at::TensorAccessor output_waveform_accessor) { + Accessor<2, scalar_t> waveform_accessor, + Accessor<2, scalar_t> temp_accessor, + Accessor<1, scalar_t, false> last_in_accessor, + Accessor<1, scalar_t, false> last_out_accessor, + Accessor<2, scalar_t, false> output_waveform_accessor) { int64_t n_frames = waveform_accessor.size(1); int64_t n_channels = waveform_accessor.size(0); at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) { for (int64_t i_channel = begin; i_channel < end; ++i_channel) { for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { - last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; - last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; - output_waveform_accessor[i_channel][i_frame] = - waveform_accessor[i_channel][i_frame] * 0.5 + - last_out_accessor[i_channel] * 0.75; + last_out_accessor.set_index( + temp_accessor.index(i_channel, i_frame) - + last_in_accessor.index(i_channel) + 0.995 * last_out_accessor.index(i_channel), + i_channel); + last_in_accessor.set_index(temp_accessor.index(i_channel, i_frame), i_channel); + output_waveform_accessor.set_index( + waveform_accessor.index(i_channel, i_frame) * 0.5 + + last_out_accessor.index(i_channel) * 0.75, + i_channel, i_frame); } } }); } void overdrive_core_loop_cpu( - at::Tensor& waveform, - at::Tensor& temp, - at::Tensor& last_in, - at::Tensor& last_out, - at::Tensor& output_waveform) { - AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { - overdrive_cpu_kernel( - waveform.accessor(), - temp.accessor(), - last_in.accessor(), - last_out.accessor(), - output_waveform.accessor()); - })); + const Tensor waveform, + const Tensor temp, + Tensor last_in, + Tensor last_out, + Tensor output_waveform) { + int32_t dtype; + aoti_torch_get_dtype(waveform.get(), &dtype); + if (dtype == aoti_torch_dtype_float64()) { + overdrive_cpu_kernel( + Accessor<2, double>(waveform), + Accessor<2, double>(temp), + Accessor<1, double, false>(last_in), + Accessor<1, double, false>(last_out), + Accessor<2, double, false>(output_waveform)); + } else if (dtype == aoti_torch_dtype_float32()) { + overdrive_cpu_kernel( + Accessor<2, float>(waveform), + Accessor<2, float>(temp), + Accessor<1, float, false>(last_in), + Accessor<1, float, false>(last_out), + Accessor<2, float, false>(output_waveform)); + } else if (dtype == aoti_torch_dtype_float16()) { + overdrive_cpu_kernel( + Accessor<2, c10::Half>(waveform), + Accessor<2, c10::Half>(temp), + Accessor<1, c10::Half, false>(last_in), + Accessor<1, c10::Half, false>(last_out), + Accessor<2, c10::Half, false>(output_waveform)); + } } -} // namespace -// Note: We want to avoid using "catch-all" kernel. -// The following registration should be replaced with CPU specific registration. -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu); + +void boxed_overdrive_core_loop(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); + Tensor t5(to(stack[4])); + overdrive_core_loop_cpu( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), std::move(t5)); } + +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "overdrive_core_loop(Tensor waveform," + "Tensor temp, Tensor last_in, Tensor last_out," + "Tensor output_waveform) -> ()"); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("overdrive_core_loop", &boxed_overdrive_core_loop); +} + +} // namespace diff --git a/src/torchaudio/functional/filtering.py b/src/torchaudio/functional/filtering.py index 76deb04a96..662e59cb95 100644 --- a/src/torchaudio/functional/filtering.py +++ b/src/torchaudio/functional/filtering.py @@ -1114,7 +1114,7 @@ def _overdrive_core_loop_generic( if _IS_TORCHAUDIO_EXT_AVAILABLE: - _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop + _overdrive_core_loop_cpu = torch.ops.torchaudio.overdrive_core_loop.default else: _overdrive_core_loop_cpu = _overdrive_core_loop_generic