Skip to content

Commit 8bde9ce

Browse files
committed
Collect a torch::stable wishlist in src/libtorchaudio/stable
1 parent ee07051 commit 8bde9ce

File tree

4 files changed

+72
-147
lines changed

4 files changed

+72
-147
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
#include <libtorchaudio/stable/dispatch.h>
2+
#include <libtorchaudio/stable/ops.h>
13
#include <libtorchaudio/utils.h>
24
#include <torch/csrc/stable/library.h>
35

46
namespace torchaudio {
57
namespace alignment {
68
namespace cpu {
79

8-
using torch::stable::Tensor;
910
using torch::headeronly::ScalarType;
11+
using torch::stable::Tensor;
1012

1113
// Inspired from
1214
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
@@ -34,17 +36,16 @@ void forced_align_impl(
3436
for (int i = 0; i < T * S; i++) {
3537
backPtr_a[i] = -1;
3638
}
37-
38-
auto logProbs_a = logProbs.accessor<scalar_t, 3>();
39-
auto targets_a = targets.accessor<target_t, 2>();
40-
auto paths_a = paths.accessor<target_t, 2>();
39+
auto logProbs_a = torchaudio::stable::accessor<scalar_t, 3>(logProbs);
40+
auto targets_a = torchaudio::stable::accessor<target_t, 2>(targets);
41+
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
4142
auto R = 0;
4243
for (auto i = 1; i < L; i++) {
4344
if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
4445
++R;
4546
}
4647
}
47-
TORCH_CHECK(
48+
STD_TORCH_CHECK(
4849
T >= L + R,
4950
"targets length is too long for CTC. Found log_probs length: ",
5051
T,
@@ -145,14 +146,16 @@ std::tuple<Tensor, Tensor> compute(
145146
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
146147
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
147148
STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
148-
STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
149+
STD_TORCH_CHECK(
150+
targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
149151
STD_TORCH_CHECK(
150152
logProbs.scalar_type() == ScalarType::Double ||
151153
logProbs.scalar_type() == ScalarType::Float ||
152154
logProbs.scalar_type() == ScalarType::Half,
153155
"log_probs must be float64, float32 or float16 (half) type");
154156
STD_TORCH_CHECK(
155-
targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long,
157+
targets.scalar_type() == ScalarType::Int ||
158+
targets.scalar_type() == ScalarType::Long,
156159
"targets must be int32 or int64 type");
157160
STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
158161
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -184,57 +187,33 @@ std::tuple<Tensor, Tensor> compute(
184187

185188
const auto B = logProbs.size(0);
186189
const auto T = logProbs.size(1);
187-
Tensor paths = torch::stable::new_empty(targets, {B, T});
188-
torch::stable::zero_(paths);
189-
190-
switch (logProbs.scalar_type()) {
191-
case ScalarType::Double: {
192-
if (targets.scalar_type() == ScalarType::Long) {
193-
forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
194-
} else if (targets.scalar_type() == ScalarType::Int) {
195-
forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
196-
} else {
197-
STD_TORCH_CHECK(false, "unreachable");
198-
}
199-
break;
200-
}
201-
case ScalarType::Float: {
202-
if (targets.scalar_type() == ScalarType::Long) {
203-
forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
204-
} else if (targets.scalar_type() == ScalarType::Int) {
205-
forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
206-
} else {
207-
STD_TORCH_CHECK(false, "unreachable");
208-
}
209-
break;
210-
}
211-
case ScalarType::Half: {
212-
if (targets.scalar_type() == ScalarType::Long) {
213-
forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
214-
} else if (targets.scalar_type() == ScalarType::Int) {
215-
forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
216-
} else {
217-
STD_TORCH_CHECK(false, "unreachable");
218-
}
219-
break;
220-
}
221-
default: {
222-
STD_TORCH_CHECK(false, "unreachable");
223-
}
224-
};
225-
190+
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T});
191+
192+
STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
193+
logProbs.scalar_type(), "forced_align_impl", [&] {
194+
if (targets.scalar_type() == ScalarType::Long) {
195+
forced_align_impl<scalar_t, ScalarType::Long>(
196+
logProbs, targets, blank, paths);
197+
} else {
198+
forced_align_impl<scalar_t, ScalarType::Int>(
199+
logProbs, targets, blank, paths);
200+
}
201+
});
226202
return std::make_tuple(paths, logProbs);
227203
}
228204

229-
void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
205+
void boxed_forced_align_cpu(
206+
StableIValue* stack,
207+
uint64_t num_args,
208+
uint64_t num_outputs) {
230209
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
231210
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
232211
std::tuple<Tensor, Tensor> res = compute(
233-
/*logProbs*/to<Tensor>(stack[0]),
234-
/*targets*/to<Tensor>(stack[1]),
235-
/*logit_lengths*/to<Tensor>(stack[2]),
236-
/*target_lengths*/to<Tensor>(stack[3]),
237-
/*blank*/float(to<int64_t>(stack[4])));
212+
/*logProbs*/ to<Tensor>(stack[0]),
213+
/*targets*/ to<Tensor>(stack[1]),
214+
/*logit_lengths*/ to<Tensor>(stack[2]),
215+
/*target_lengths*/ to<Tensor>(stack[3]),
216+
/*blank*/ float(to<int64_t>(stack[4])));
238217
stack[0] = from(std::get<0>(res));
239218
stack[1] = from(std::get<1>(res));
240219
}

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 30 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include <libtorchaudio/utils.h>
2+
#include <libtorchaudio/stable/TensorAccessor.h>
3+
#include <libtorchaudio/stable/dispatch.h>
24
#include <torch/csrc/stable/library.h>
35

46
#include <cub/cub.cuh>
@@ -20,9 +22,9 @@ using torch::headeronly::ScalarType;
2022

2123
template <typename scalar_t, typename target_t>
2224
__global__ void falign_cuda_step_kernel(
23-
const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits>
25+
const torchaudio::stable::PackedTensorAccessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>
2426
logProbs_a,
25-
const at::PackedTensorAccessor32<target_t, 2, at::RestrictPtrTraits>
27+
const torchaudio::stable::PackedTensorAccessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>
2628
targets_a,
2729
const int T,
2830
const int L,
@@ -33,9 +35,9 @@ __global__ void falign_cuda_step_kernel(
3335
int start,
3436
int end,
3537
int backPtrBufferLen,
36-
at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
38+
torchaudio::stable::PackedTensorAccessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>
3739
alphas_a,
38-
at::PackedTensorAccessor32<int8_t, 2, at::RestrictPtrTraits>
40+
torchaudio::stable::PackedTensorAccessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>
3941
backPtrBuffer_a) {
4042
scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
4143
const int batchIndex =
@@ -122,15 +124,15 @@ void forced_align_impl(
122124
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
123125
using target_t = typename std::
124126
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
125-
auto paths_a = paths.accessor<target_t, 2>();
127+
auto paths_a = torchaudio::stable::accessor<target_t, 2>(paths);
126128
const int batchIndex =
127129
0; // TODO: support batch version and use the real batch index
128130
const int T = logProbs.size(1); // num frames
129131
const int N = logProbs.size(2); // alphabet size
130132
const int L = targets.size(1); // label length
131133
const int S = 2 * L + 1;
132134

133-
auto targetsCpu = torch::stable::cpu(targets);
135+
auto targetsCpu = torchaudio::stable::cpu(targets);
134136
// backPtrBuffer stores the index offset of the best path at current position
135137
// We copy the values to CPU after running every kBackPtrBufferSize of
136138
// frames.
@@ -147,8 +149,8 @@ void forced_align_impl(
147149
torch::stable::fill_(alphas, kNegInfinity);
148150

149151
// CPU accessors
150-
auto targetsCpu_a = targetsCpu.accessor<target_t, 2>();
151-
auto backPtrCpu_a = backPtrCpu.accessor<int8_t, 2>();
152+
auto targetsCpu_a = torchaudio::stable::accessor<target_t, 2>(targetsCpu);
153+
auto backPtrCpu_a = torchaudio::stable::accessor<int8_t, 2>(backPtrCpu);
152154
// count the number of repeats in label
153155
int R = 0;
154156
for (int i = 1; i < L; ++i) {
@@ -189,8 +191,8 @@ void forced_align_impl(
189191
}
190192
falign_cuda_step_kernel<scalar_t, target_t>
191193
<<<1, kNumThreads, 0, defaultStream>>>(
192-
logProbs.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(),
193-
targets.packed_accessor32<target_t, 2, at::RestrictPtrTraits>(),
194+
torchaudio::stable::packed_accessor32<scalar_t, 3, torchaudio::stable::RestrictPtrTraits>(logProbs),
195+
torchaudio::stable::packed_accessor32<target_t, 2, torchaudio::stable::RestrictPtrTraits>(targets),
194196
T,
195197
L,
196198
N,
@@ -200,15 +202,14 @@ void forced_align_impl(
200202
start,
201203
end,
202204
backPtrBufferLen,
203-
alphas.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(),
204-
backPtrBuffer
205-
.packed_accessor32<int8_t, 2, at::RestrictPtrTraits>());
205+
torchaudio::stable::packed_accessor32<scalar_t, 2, torchaudio::stable::RestrictPtrTraits>(alphas),
206+
torchaudio::stable::packed_accessor32<int8_t, 2, torchaudio::stable::RestrictPtrTraits>(backPtrBuffer));
206207
C10_CUDA_KERNEL_LAUNCH_CHECK();
207208
++backPtrBufferLen;
208209
if (backPtrBufferLen == kBackPtrBufferSize || t == T - 1) {
209210
cpuDataTranferStream.synchronize();
210211
// GPU -> GPU copy
211-
bufferCopy = backPtrBuffer.clone();
212+
bufferCopy = torchaudio::stable::clone(backPtrBuffer);
212213
STD_TORCH_CHECK(bufferCopy.is_contiguous(), "unexpected fail, need to implement stable::Tensor::contiguous()")
213214
defaultStream.synchronize();
214215
at::cuda::setCurrentCUDAStream(cpuDataTranferStream);
@@ -227,8 +228,8 @@ void forced_align_impl(
227228
}
228229
cpuDataTranferStream.synchronize();
229230

230-
auto alphasCpu = torch::stable::cpu(alphas);
231-
auto alphasCpu_a = alphasCpu.accessor<scalar_t, 2>();
231+
auto alphasCpu = torchaudio::stable::cpu(alphas);
232+
auto alphasCpu_a = torchaudio::stable::accessor<scalar_t, 2>(alphasCpu);
232233
int curIdxOffset = ((T - 1) % 2);
233234
int ltrIdx =
234235
alphasCpu_a[curIdxOffset][S - 1] > alphasCpu_a[curIdxOffset][S - 2]
@@ -294,50 +295,20 @@ std::tuple<Tensor, Tensor> compute(
294295
auto B = logProbs.size(0);
295296
auto T = logProbs.size(1); // num frames
296297

297-
Tensor paths = torch::stable::new_empty(targets, {B, T}, std::nullopt, aoti_torch_device_type_cpu());
298-
torch::stable::zero_(paths);
298+
Tensor paths = torchaudio::stable::new_zeros(targets, {B, T}, /*dtype=*/std::nullopt, /*layout=*/std::nullopt, /*device=*/torchaudio::stable::cpu_device());
299299

300-
switch (logProbs.scalar_type()) {
301-
case ScalarType::Double: {
302-
if (targets.scalar_type() == ScalarType::Long) {
303-
forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
304-
} else if (targets.scalar_type() == ScalarType::Int) {
305-
forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
306-
} else {
307-
STD_TORCH_CHECK(false, "unreachable");
308-
}
309-
break;
310-
}
311-
case ScalarType::Float: {
312-
if (targets.scalar_type() == ScalarType::Long) {
313-
forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
314-
} else if (targets.scalar_type() == ScalarType::Int) {
315-
forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
316-
} else {
317-
STD_TORCH_CHECK(false, "unreachable");
318-
}
319-
break;
320-
}
321-
case ScalarType::Half: {
322-
if (targets.scalar_type() == ScalarType::Long) {
323-
forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
324-
} else if (targets.scalar_type() == ScalarType::Int) {
325-
forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
326-
} else {
327-
STD_TORCH_CHECK(false, "unreachable");
328-
}
329-
break;
330-
}
331-
default: {
332-
STD_TORCH_CHECK(false, "unreachable");
333-
}
334-
};
335-
Tensor pathsCuda = torch::stable::new_empty(paths,
336-
torchaudio::util::sizes(paths),
337-
std::nullopt,
338-
aoti_torch_device_type_cuda(),
339-
logProbs.get_device_index());
340-
torch::stable::copy_(pathsCuda, paths);
300+
STABLE_DISPATCH_FLOATING_TYPES_AND_HALF(
301+
logProbs.scalar_type(), "forced_align_impl", [&] {
302+
if (targets.scalar_type() == ScalarType::Long) {
303+
forced_align_impl<scalar_t, ScalarType::Long>(
304+
logProbs, targets, blank, paths);
305+
} else {
306+
forced_align_impl<scalar_t, ScalarType::Int>(
307+
logProbs, targets, blank, paths);
308+
}
309+
});
310+
311+
Tensor pathsCuda = torchaudio::stable::cuda(paths, logProbs.get_device_index());
341312
return std::make_tuple(pathsCuda, logProbs);
342313
}
343314

src/libtorchaudio/utils.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#include <ATen/DynamicLibrary.h>
2-
#include <libtorchaudio/utils.h>
3-
42
#include <torch/csrc/stable/tensor.h>
3+
#include <libtorchaudio/utils.h>
54

65
#ifdef USE_CUDA
76
#include <cuda.h>

src/libtorchaudio/utils.h

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,21 @@
11
#pragma once
2-
#include <torch/csrc/stable/tensor_struct.h>
3-
#include <torch/csrc/stable/ops.h>
42

5-
#ifdef USE_CUDA
6-
#include <ATen/cuda/CUDAContext.h>
7-
#include <c10/cuda/CUDAException.h>
8-
#endif
3+
// TODO: replace the include libtorchaudio/stable/ops.h with
4+
// torch/stable/ops.h when torch::stable provides all required
5+
// features (torch::stable::item<T> or similar):
6+
#include <libtorchaudio/stable/ops.h>
97

108
namespace torchaudio {
119

1210
namespace util {
13-
inline std::vector<int64_t> sizes(const torch::stable::Tensor& t) {
14-
auto sizes_ = t.sizes();
15-
std::vector<int64_t> sizes(sizes_.data(), sizes_.data() + t.dim());
16-
return sizes;
17-
}
18-
19-
template <typename T>
20-
T item(const torch::stable::Tensor& t) {
21-
STD_TORCH_CHECK(t.numel() == 1, "item requires single element tensor input");
22-
if (t.is_cpu()) {
23-
return t.const_data_ptr<T>()[0];
24-
#ifdef USE_CUDA
25-
} else if (t.is_cuda()) {
26-
T value;
27-
C10_CUDA_CHECK(cudaMemcpyAsync(&value, t.data_ptr(), sizeof(T), cudaMemcpyDeviceToHost));
28-
return value;
29-
#endif
30-
} else {
31-
STD_TORCH_CHECK(false, "unreachable");
32-
}
33-
}
34-
35-
template <typename T>
36-
T max(const torch::stable::Tensor& t) {
37-
// TODO: eliminate const_cast after pytorch/pytorch#161826 is fixed
38-
return item<T>(torch::stable::amax(const_cast<torch::stable::Tensor&>(t), {}));
39-
}
11+
template <typename T>
12+
T max(const torch::stable::Tensor& t) {
13+
return torchaudio::stable::item<T>(torch::stable::amax(t, {}));
4014
}
15+
} // namespace util
4116

4217
bool is_rir_available();
4318
bool is_align_available();
4419
std::optional<int64_t> cuda_version();
20+
4521
} // namespace torchaudio

0 commit comments

Comments
 (0)