Skip to content

Commit c2546fc

Browse files
committed
Update forced_align tensor accessors to use headeronly.
1 parent ee1a135 commit c2546fc

File tree

4 files changed: +50 additions, −52 deletions
Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,6 @@
1-
#include <libtorchaudio/forced_align/compute.h>
2-
#include <torch/script.h>
1+
#include <torch/csrc/stable/library.h>
32

4-
std::tuple<torch::Tensor, torch::Tensor> forced_align(
5-
const torch::Tensor& logProbs,
6-
const torch::Tensor& targets,
7-
const torch::Tensor& inputLengths,
8-
const torch::Tensor& targetLengths,
9-
const int64_t blank) {
10-
static auto op = torch::Dispatcher::singleton()
11-
.findSchemaOrThrow("torchaudio::forced_align", "")
12-
.typed<decltype(forced_align)>();
13-
return op.call(logProbs, targets, inputLengths, targetLengths, blank);
14-
}
15-
16-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
174
m.def(
185
"forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)");
196
}
Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1 @@
11
#pragma once
2-
3-
#include <torch/script.h>
4-
5-
std::tuple<torch::Tensor, torch::Tensor> forced_align(
6-
const torch::Tensor& logProbs,
7-
const torch::Tensor& targets,
8-
const torch::Tensor& inputLengths,
9-
const torch::Tensor& targetLengths,
10-
const int64_t blank);

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 15 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
#include <libtorchaudio/utils.h>
2-
#include <libtorchaudio/stable/TensorAccessor.h>
32
#include <torch/csrc/stable/library.h>
43
#include <torch/headeronly/core/Dispatch_v2.h>
54
#include <torch/headeronly/core/ScalarType.h>
@@ -23,9 +22,9 @@ using torch::headeronly::ScalarType;
2322

2423
template <typename scalar_t, typename target_t>
2524
__global__ void falign_cuda_step_kernel(
26-
const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
25+
const PackedTensorAccessor32<scalar_t, 3>
2726
logProbs_a,
28-
const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
27+
const PackedTensorAccessor32<target_t, 2>
2928
targets_a,
3029
const int T,
3130
const int L,
@@ -36,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
3635
int start,
3736
int end,
3837
int backPtrBufferLen,
39-
torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
38+
PackedTensorAccessor32<scalar_t, 2>
4039
alphas_a,
41-
torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
40+
PackedTensorAccessor32<int8_t, 2>
4241
backPtrBuffer_a) {
4342
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
4443
const int batchIndex =
@@ -125,7 +124,7 @@ void forced_align_impl(
125124
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
126125
using target_t = typename std::
127126
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
128-
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
127+
auto paths_a = accessor<target_t, 2>(paths);
129128
const int batchIndex =
130129
0; // TODO: support batch version and use the real batch index
131130
const int T = logProbs.size(1); // num frames
@@ -192,8 +191,8 @@ void forced_align_impl(
192191
}
193192
falign_cuda_step_kernel<scalar_t, target_t>
194193
<<<1, kNumThreads, 0, defaultStream>>>(
195-
torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
196-
torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
194+
packed_accessor32<scalar_t, 3>(logProbs),
195+
packed_accessor32<target_t, 2>(targets),
197196
T,
198197
L,
199198
N,
@@ -203,12 +202,13 @@ void forced_align_impl(
203202
start,
204203
end,
205204
backPtrBufferLen,
206-
torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
207-
torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
205+
packed_accessor32<scalar_t, 2>(alphas),
206+
packed_accessor32<int8_t, 2>(backPtrBuffer));
208207
C10_CUDA_KERNEL_LAUNCH_CHECK();
209208
++backPtrBufferLen;
210209
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
211210
cpuDataTranferStream.synchronize();
211+
212212
// GPU -> GPU copy
213213
bufferCopy = torch::stable::clone(backPtrBuffer);
214214
STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
@@ -228,7 +228,6 @@ void forced_align_impl(
228228
}
229229
}
230230
cpuDataTranferStream.synchronize();
231-
232231
auto alphasCpu = torchaudio::stable::cpu(alphas);
233232
auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
234233
int curIdxOffset = ((T - 1) % 2);
@@ -252,10 +251,10 @@ template <typename scalar_t>
252251
const auto forced_align_int_impl = forced_align_impl<scalar_t, ScalarType::Int>;
253252

254253
std::tuple<Tensor, Tensor> compute(
255-
const Tensor& logProbs,
256-
const Tensor& targets,
257-
const Tensor& inputLengths,
258-
const Tensor& targetLengths,
254+
Tensor logProbs,
255+
Tensor targets,
256+
Tensor inputLengths,
257+
Tensor targetLengths,
259258
const int64_t blank) {
260259

261260
STD_TORCH_CHECK(logProbs.is_cuda(), "log_probs must be a CUDA tensor");
@@ -317,21 +316,9 @@ std::tuple<Tensor, Tensor> compute(
317316
return std::make_tuple(pathsCuda, logProbs);
318317
}
319318

320-
void boxed_forced_align_gpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
321-
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
322-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
323-
std::tuple<Tensor, Tensor> res = compute(
324-
/*logProbs*/torch::stable::detail::to<Tensor>(stack[0]),
325-
/*targets*/torch::stable::detail::to<Tensor>(stack[1]),
326-
/*logit_lengths*/torch::stable::detail::to<Tensor>(stack[2]),
327-
/*target_lengths*/torch::stable::detail::to<Tensor>(stack[3]),
328-
/*blank*/float(torch::stable::detail::to<int64_t>(stack[4])));
329-
stack[0] = torch::stable::detail::from(std::get<0>(res));
330-
stack[1] = torch::stable::detail::from(std::get<1>(res));
331-
}
332319

333320
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
334-
m.impl("forced_align", &boxed_forced_align_gpu);
321+
m.impl("forced_align", TORCH_BOX(&compute));
335322
}
336323

337324
} // namespace gpu

src/libtorchaudio/utils.h

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include <torch/headeronly/core/TensorAccessor.h>
4+
35
// TODO: replace the include libtorchaudio/stable/ops.h with
46
// torch/stable/ops.h when torch::stable provides all required
57
// features (torch::stable::item<T> or similar):
@@ -17,4 +19,35 @@ T max(const torch::stable::Tensor& t) {
1719
bool is_align_available();
1820
std::optional<int64_t> cuda_version();
1921

22+
template <typename T, size_t N>
23+
using TensorAccessor = torch::headeronly::HeaderOnlyTensorAccessor<T, N>;
24+
25+
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
26+
// after Tensor::accessor is supported in stable ABI
27+
template <typename T, size_t N>
28+
inline TensorAccessor<T, N> accessor(Tensor t) {
29+
return TensorAccessor<T, N>(
30+
reinterpret_cast<T*>(t.data_ptr()), t.sizes().data(), t.strides().data());
31+
}
32+
33+
#if defined(__CUDACC__) || defined(__HIPCC__)
34+
template <typename T, size_t N>
35+
using PackedTensorAccessor32 =
36+
torch::headeronly::HeaderOnlyGenericPackedTensorAccessor<
37+
T,
38+
N,
39+
torch::headeronly::RestrictPtrTraits,
40+
int32_t>;
41+
42+
// TODO: eliminate accessor<T, N>(t) in favor of t.accessor<T, N>
43+
// after Tensor::accessor is supported in stable ABI
44+
template <typename T, size_t N>
45+
inline PackedTensorAccessor32<T, N> packed_accessor32(Tensor t) {
46+
return PackedTensorAccessor32<T, N>(
47+
static_cast<typename PackedTensorAccessor32<T, N>::PtrType>(t.data_ptr()),
48+
t.sizes().data(),
49+
t.strides().data());
50+
}
51+
#endif
52+
2053
} // namespace torchaudio

0 commit comments

Comments (0)