Skip to content

Commit e9a52a3

Browse files
authored
[STABLE ABI] Port overdrive (#4131)
* Port overdrive
* Use TORCH_BOX in rnnt
1 parent ee1a135 commit e9a52a3

File tree

2 files changed

+71
-55
lines changed

2 files changed

+71
-55
lines changed

src/libtorchaudio/overdrive.cpp

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,78 @@
1-
#include <torch/script.h>
2-
#include <torch/torch.h>
1+
#include <torch/csrc/stable/library.h>
2+
#include <torch/csrc/stable/ops.h>
3+
#include <torch/csrc/stable/tensor.h>
4+
#include <torch/headeronly/core/Dispatch_v2.h>
5+
#include <torch/headeronly/core/TensorAccessor.h>
36

47
namespace {
58

9+
using torch::stable::Tensor;
10+
11+
template <typename T, size_t N>
12+
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
13+
14+
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
15+
// after Tensor::accessor is supported in stable ABI
16+
template <typename T, size_t N>
17+
inline TensorAccessor<T, N> accessor(Tensor t) {
18+
return TensorAccessor<T, N>(
19+
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
20+
}
21+
622
template <typename scalar_t>
723
void overdrive_cpu_kernel(
8-
at::TensorAccessor<scalar_t, 2> waveform_accessor,
9-
at::TensorAccessor<scalar_t, 2> temp_accessor,
10-
at::TensorAccessor<scalar_t, 1> last_in_accessor,
11-
at::TensorAccessor<scalar_t, 1> last_out_accessor,
12-
at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
24+
TensorAccessor<scalar_t, 2> waveform_accessor,
25+
TensorAccessor<scalar_t, 2> temp_accessor,
26+
TensorAccessor<scalar_t, 1> last_in_accessor,
27+
TensorAccessor<scalar_t, 1> last_out_accessor,
28+
TensorAccessor<scalar_t, 2> output_waveform_accessor) {
1329
int64_t n_frames = waveform_accessor.size(1);
1430
int64_t n_channels = waveform_accessor.size(0);
1531

16-
at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
17-
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
18-
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
19-
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
20-
last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
21-
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
22-
output_waveform_accessor[i_channel][i_frame] =
23-
waveform_accessor[i_channel][i_frame] * 0.5 +
24-
last_out_accessor[i_channel] * 0.75;
25-
}
26-
}
27-
});
32+
torch::stable::parallel_for(
33+
0, n_channels, 1, [&](int64_t begin, int64_t end) {
34+
for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
35+
for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
36+
last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
37+
last_in_accessor[i_channel] +
38+
0.995 * last_out_accessor[i_channel];
39+
last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
40+
output_waveform_accessor[i_channel][i_frame] =
41+
waveform_accessor[i_channel][i_frame] * 0.5 +
42+
last_out_accessor[i_channel] * 0.75;
43+
}
44+
}
45+
});
2846
}
2947

30-
void overdrive_core_loop_cpu(
31-
at::Tensor& waveform,
32-
at::Tensor& temp,
33-
at::Tensor& last_in,
34-
at::Tensor& last_out,
35-
at::Tensor& output_waveform) {
36-
AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
37-
overdrive_cpu_kernel<scalar_t>(
38-
waveform.accessor<scalar_t, 2>(),
39-
temp.accessor<scalar_t, 2>(),
40-
last_in.accessor<scalar_t, 1>(),
41-
last_out.accessor<scalar_t, 1>(),
42-
output_waveform.accessor<scalar_t, 2>());
43-
}));
48+
// CPU entry point: dispatches on waveform's floating dtype and runs the
// overdrive kernel, writing into last_in, last_out and output_waveform.
// Returns the three mutated tensors so the op schema can alias them back
// to the caller (stable Tensors are shared handles, so this aliases the
// inputs rather than copying).
std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
    Tensor waveform,
    Tensor temp,
    Tensor last_in,
    Tensor last_out,
    Tensor output_waveform) {
  THO_DISPATCH_V2(
      waveform.scalar_type(),
      "overdrive_cpu",
      AT_WRAP([&] {
        overdrive_cpu_kernel<scalar_t>(
            accessor<scalar_t, 2>(waveform),
            accessor<scalar_t, 2>(temp),
            accessor<scalar_t, 1>(last_in),
            accessor<scalar_t, 1>(last_out),
            accessor<scalar_t, 2>(output_waveform));
      }),
      AT_FLOATING_TYPES);
  return {last_in, last_out, output_waveform};
}
4568

4669
} // namespace
4770

48-
// Note: We want to avoid using "catch-all" kernel.
49-
// The following registration should be replaced with CPU specific registration.
50-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
51-
m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
71+
// Schema declaration for the overdrive loop. The (a!)/(b!)/(c!) alias
// annotations record that last_in, last_out and output_waveform are
// mutated in place and returned as aliases of the inputs.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
}
75+
76+
// CPU-specific registration (deliberately not a catch-all kernel).
// TORCH_BOX adapts the typed overdrive_core_loop_cpu to the stable
// boxed calling convention.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("_overdrive_core_loop", TORCH_BOX(&overdrive_core_loop_cpu));
}

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ using torch::headeronly::ScalarType;
1414

1515
// Entry point into RNNT Loss
1616
std::tuple<Tensor, Tensor> compute(
17-
const Tensor& logits,
18-
const Tensor& targets,
19-
const Tensor& logit_lengths,
20-
const Tensor& target_lengths,
17+
Tensor logits,
18+
Tensor targets,
19+
Tensor logit_lengths,
20+
Tensor target_lengths,
2121
int64_t blank,
2222
double clamp,
2323
bool fused_log_softmax = true) {
@@ -148,23 +148,13 @@ std::tuple<Tensor, Tensor> compute(
148148
return std::make_tuple(costs, gradients);
149149
}
150150

151-
void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
152-
STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
153-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
154-
std::tuple<Tensor, Tensor> res = compute(
155-
/*logits*/torch::stable::detail::to<Tensor>(stack[0]),
156-
/*targets*/torch::stable::detail::to<Tensor>(stack[1]),
157-
/*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
158-
/*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
159-
/*blank*/float(torch::stable::detail::to<int64_t>(stack[4])),
160-
/*clamp*/torch::stable::detail::to<double>(stack[5]),
161-
/*fused_log_softmax*/torch::stable::detail::to<bool>(stack[6]));
162-
stack[0] = torch::stable::detail::from(std::get<0>(res));
163-
stack[1] = torch::stable::detail::from(std::get<1>(res));
151+
// Schema declaration for the RNNT loss forward op.
// NB: the schema language spells C++ double as "float" ("double" is not a
// valid schema type and fails to parse at registration time), so clamp is
// declared as float here while compute() receives a C++ double.
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, float clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
}
165155

166156
// CUDA registration. TORCH_BOX generates the boxed adapter for the typed
// compute() entry point, replacing the hand-written StableIValue shim
// that previously unpacked the stack manually.
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
}
169159

170160
} // namespace gpu

0 commit comments

Comments
 (0)