Skip to content

Commit 32ce8c0

Browse files
authored
Update forced_align tensor accessors to use headeronly. (#4134)
* Update forced_align tensor accessors to use headeronly.
* Use TORCH_BOX in rnnt/cpu.
* Use TORCH_BOX in forced_align/cpu.
* Eliminate libtorchaudio/stable/TensorAccessor.h
1 parent e9a52a3 commit 32ce8c0

File tree

10 files changed

+107
-527
lines changed

10 files changed

+107
-527
lines changed
Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,10 @@
1-
#include <libtorchaudio/forced_align/compute.h>
2-
#include <torch/script.h>
1+
#include <torch/csrc/stable/library.h>
32

4-
std::tuple<torch::Tensor, torch::Tensor> forced_align(
5-
const torch::Tensor& logProbs,
6-
const torch::Tensor& targets,
7-
const torch::Tensor& inputLengths,
8-
const torch::Tensor& targetLengths,
9-
const int64_t blank) {
10-
static auto op = torch::Dispatcher::singleton()
11-
.findSchemaOrThrow("torchaudio::forced_align", "")
12-
.typed<decltype(forced_align)>();
13-
return op.call(logProbs, targets, inputLengths, targetLengths, blank);
14-
}
15-
16-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
174
m.def(
18-
"forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
5+
"forced_align(Tensor log_probs,"
6+
"Tensor targets,"
7+
"Tensor input_lengths,"
8+
"Tensor target_lengths,"
9+
"int blank) -> (Tensor, Tensor)");
1910
}
Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1 @@
11
#pragma once
2-
3-
#include <torch/script.h>
4-
5-
std::tuple<torch::Tensor, torch::Tensor> forced_align(
6-
const torch::Tensor& logProbs,
7-
const torch::Tensor& targets,
8-
const torch::Tensor& inputLengths,
9-
const torch::Tensor& targetLengths,
10-
const int64_t blank);

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ void forced_align_impl(
3838
for (int i = 0; i < T * S; i++) {
3939
backPtr_a[i] = -1;
4040
}
41-
auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
42-
auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
43-
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
41+
auto logProbs_a = torchaudio::accessor<scalar_t, 3>(logProbs);
42+
auto targets_a = torchaudio::accessor<target_t, 2>(targets);
43+
auto paths_a = torchaudio::accessor<target_t, 2>(paths);
4444
auto R = 0;
4545
for (auto i = 1; i < L; i++) {
4646
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
@@ -147,10 +147,10 @@ template <typename scalar_t>
147147
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
148148

149149
std::tuple<Tensor, Tensor> compute(
150-
const Tensor& logProbs,
151-
const Tensor& targets,
152-
const Tensor& inputLengths,
153-
const Tensor& targetLengths,
150+
Tensor logProbs,
151+
Tensor targets,
152+
Tensor inputLengths,
153+
Tensor targetLengths,
154154
const int64_t blank) {
155155
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
156156
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
@@ -224,24 +224,8 @@ std::tuple<Tensor, Tensor> compute(
224224
return std::make_tuple(paths, logProbs);
225225
}
226226

227-
void boxed_forced_align_cpu(
228-
StableIValue* stack,
229-
uint64_t num_args,
230-
uint64_t num_outputs) {
231-
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
232-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
233-
std::tuple<Tensor, Tensor> res = compute(
234-
/*logProbs*/ torch::stable::detail::to<Tensor>(stack[0]),
235-
/*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
236-
/*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
237-
/*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
238-
/*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])));
239-
stack[0] = torch::stable::detail::from(std::get<0>(res));
240-
stack[1] = torch::stable::detail::from(std::get<1>(res));
241-
}
242-
243227
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
244-
m.impl("forced_align", &boxed_forced_align_cpu);
228+
m.impl("forced_align", TORCH_BOX(&compute));
245229
}
246230

247231
} // namespace cpu

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <libtorchaudio/utils.h>
2-
#include <libtorchaudio/stable/TensorAccessor.h>
32
#include <torch/csrc/stable/library.h>
43
#include <torch/headeronly/core/Dispatch_v2.h>
54
#include <torch/headeronly/core/ScalarType.h>
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;
2322

2423
template <typename scalar_t, typename target_t>
2524
__global__ void falign_cuda_step_kernel(
26-
const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
25+
const torchaudio::PackedTensorAccessor32<scalar_t, 3>
2726
logProbs_a,
28-
const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
27+
const torchaudio::PackedTensorAccessor32<target_t, 2>
2928
targets_a,
3029
const int T,
3130
const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
3635
int start,
3736
int end,
3837
int backPtrBufferLen,
39-
torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
38+
torchaudio::PackedTensorAccessor32<scalar_t, 2>
4039
alphas_a,
41-
torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
40+
torchaudio::PackedTensorAccessor32<int8_t, 2>
4241
backPtrBuffer_a) {
4342
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
4443
const int batchIndex =
@@ -125,7 +124,7 @@ void forced_align_impl(
125124
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
126125
using target_t = typename std::
127126
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
128-
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
127+
auto paths_a = torchaudio::accessor<target_t, 2>(paths);
129128
const int batchIndex =
130129
0; // TODO: support batch version and use the real batch index
131130
const int T = logProbs.size(1); // num frames
@@ -150,8 +149,8 @@ void forced_align_impl(
150149
torch::stable::fill_(alphas, kNegInfinity);
151150

152151
// CPU accessors
153-
auto targetsCpu_a = torchaudio::stable::accessor<target_t, 2>(targetsCpu);
154-
auto backPtrCpu_a = torchaudio::stable::accessor<int8_t, 2>(backPtrCpu);
152+
auto targetsCpu_a = torchaudio::accessor<target_t, 2>(targetsCpu);
153+
auto backPtrCpu_a = torchaudio::accessor<int8_t, 2>(backPtrCpu);
155154
// count the number of repeats in label
156155
int R = 0;
157156
for (int i = 1; i < L; ++i) {
@@ -192,8 +191,8 @@ void forced_align_impl(
192191
}
193192
falign_cuda_step_kernel<scalar_t, target_t>
194193
<<<1, kNumThreads, 0, defaultStream>>>(
195-
torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
196-
torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
194+
torchaudio::packed_accessor32<scalar_t, 3>(logProbs),
195+
torchaudio::packed_accessor32<target_t, 2>(targets),
197196
T,
198197
L,
199198
N,
@@ -203,8 +202,8 @@ void forced_align_impl(
203202
start,
204203
end,
205204
backPtrBufferLen,
206-
torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
207-
torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
205+
torchaudio::packed_accessor32<scalar_t, 2>(alphas),
206+
torchaudio::packed_accessor32<int8_t, 2>(backPtrBuffer));
208207
C10_CUDA_KERNEL_LAUNCH_CHECK();
209208
++backPtrBufferLen;
210209
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
@@ -228,9 +227,8 @@ void forced_align_impl(
228227
}
229228
}
230229
cpuDataTranferStream.synchronize();
231-
232230
auto alphasCpu = torchaudio::stable::cpu(alphas);
233-
auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
231+
auto alphasCpu_a = torchaudio::accessor<scalar_t, 2>(alphasCpu);
234232
int curIdxOffset = ((T - 1) % 2);
235233
int ltrIdx =
236234
alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
@@ -244,18 +242,11 @@ void forced_align_impl(
244242
}
245243
}
246244

247-
template <typename scalar_t>
248-
const auto forced_align_long_impl =
249-
forced_align_impl<scalar_t, ScalarType::Long>;
250-
251-
template <typename scalar_t>
252-
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
253-
254245
std::tuple<Tensor, Tensor> compute(
255-
const Tensor& logProbs,
256-
const Tensor& targets,
257-
const Tensor& inputLengths,
258-
const Tensor& targetLengths,
246+
Tensor logProbs,
247+
Tensor targets,
248+
Tensor inputLengths,
249+
Tensor targetLengths,
259250
const int64_t blank) {
260251

261252
STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
@@ -307,31 +298,17 @@ std::tuple<Tensor, Tensor> compute(
307298

308299
THO_DISPATCH_V2(logProbs.scalar_type(), "forced_align_impl", AT_WRAP([&] {
309300
if (targets.scalar_type() == ScalarType::Long) {
310-
forced_align_long_impl<scalar_t>(logProbs, targets, blank, paths);
301+
(forced_align_impl<scalar_t, ScalarType::Long>(logProbs, targets, blank, paths));
311302
} else {
312-
forced_align_int_impl<scalar_t>(logProbs, targets, blank, paths);
313-
}
303+
(forced_align_impl<scalar_t, ScalarType::Int>(logProbs, targets, blank, paths));
304+
}
314305
}), AT_EXPAND(AT_FLOATING_TYPES), ScalarType::Half);
315-
316306
Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
317307
return std::make_tuple(pathsCuda, logProbs);
318308
}
319309

320-
void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
321-
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
322-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
323-
std::tuple<Tensor, Tensor> res = compute(
324-
/*logProbs*/torch::stable::detail::to<Tensor>(stack[0]),
325-
/*targets*/torch::stable::detail::to<Tensor>(stack[1]),
326-
/*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
327-
/*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
328-
/*blank*/float(torch::stable::detail::to<int64_t>(stack[4])));
329-
stack[0] = torch::stable::detail::from(std::get<0>(res));
330-
stack[1] = torch::stable::detail::from(std::get<1>(res));
331-
}
332-
333310
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
334-
m.impl("forced_align", &boxed_forced_align_gpu);
311+
m.impl("forced_align", TORCH_BOX(&compute));
335312
}
336313

337314
} // namespace gpu

src/libtorchaudio/overdrive.cpp

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,20 @@
1+
#include <libtorchaudio/utils.h>
12
#include <torch/csrc/stable/library.h>
23
#include <torch/csrc/stable/ops.h>
34
#include <torch/csrc/stable/tensor.h>
45
#include <torch/headeronly/core/Dispatch_v2.h>
56
#include <torch/headeronly/core/TensorAccessor.h>
67

78
namespace {
8-
99
using torch::stable::Tensor;
1010

11-
template <typename T, size_t N>
12-
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
13-
14-
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
15-
// after Tensor::accessor is supported in stable ABI
16-
template <typename T, size_t N>
17-
inline TensorAccessor<T, N> accessor(Tensor t) {
18-
return TensorAccessor<T, N>(
19-
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
20-
}
21-
2211
template <typename scalar_t>
2312
void overdrive_cpu_kernel(
24-
TensorAccessor<scalar_t, 2> waveform_accessor,
25-
TensorAccessor<scalar_t, 2> temp_accessor,
26-
TensorAccessor<scalar_t, 1> last_in_accessor,
27-
TensorAccessor<scalar_t, 1> last_out_accessor,
28-
TensorAccessor<scalar_t, 2> output_waveform_accessor) {
13+
torchaudio::TensorAccessor<scalar_t, 2> waveform_accessor,
14+
torchaudio::TensorAccessor<scalar_t, 2> temp_accessor,
15+
torchaudio::TensorAccessor<scalar_t, 1> last_in_accessor,
16+
torchaudio::TensorAccessor<scalar_t, 1> last_out_accessor,
17+
torchaudio::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
2918
int64_t n_frames = waveform_accessor.size(1);
3019
int64_t n_channels = waveform_accessor.size(0);
3120

@@ -56,11 +45,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
5645
"overdrive_cpu",
5746
AT_WRAP([&] {
5847
overdrive_cpu_kernel<scalar_t>(
59-
accessor<scalar_t, 2>(waveform),
60-
accessor<scalar_t, 2>(temp),
61-
accessor<scalar_t, 1>(last_in),
62-
accessor<scalar_t, 1>(last_out),
63-
accessor<scalar_t, 2>(output_waveform));
48+
torchaudio::accessor<scalar_t, 2>(waveform),
49+
torchaudio::accessor<scalar_t, 2>(temp),
50+
torchaudio::accessor<scalar_t, 1>(last_in),
51+
torchaudio::accessor<scalar_t, 1>(last_out),
52+
torchaudio::accessor<scalar_t, 2>(output_waveform));
6453
}),
6554
AT_FLOATING_TYPES);
6655
return std::make_tuple(last_in, last_out, output_waveform);
@@ -70,7 +59,11 @@ std::tuple<Tensor, Tensor, Tensor> overdrive_core_loop_cpu(
7059

7160
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
7261
m.def(
73-
"_overdrive_core_loop(Tensor waveform, Tensor temp, Tensor(a!) last_in, Tensor(b!) last_out, Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
62+
"_overdrive_core_loop(Tensor waveform,"
63+
"Tensor temp,"
64+
"Tensor(a!) last_in,"
65+
"Tensor(b!) last_out,"
66+
"Tensor(c!) output_waveform) -> (Tensor(a!), Tensor(b!), Tensor(c!))");
7467
}
7568

7669
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {

0 commit comments

Comments (0)