Skip to content

Commit 48081cf

Browse files
committed
[STABLE ABI] Port forced_align
1 parent 87ff22e commit 48081cf

File tree

4 files changed

+261
-145
lines changed

4 files changed

+261
-145
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 92 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,24 @@
1-
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
1+
#include <libtorchaudio/utils.h>
22
#include <torch/csrc/stable/library.h>
3-
#include <torch/csrc/stable/ops.h>
4-
#include <torch/csrc/stable/tensor.h>
5-
#include <torch/script.h>
6-
#include <torch/torch.h>
7-
8-
using namespace std;
93

104
namespace torchaudio {
115
namespace alignment {
126
namespace cpu {
7+
8+
using torch::stable::Tensor;
9+
using torch::headeronly::ScalarType;
10+
1311
// Inspired from
1412
// https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
15-
template <typename scalar_t, at::ScalarType target_scalar_type>
13+
template <typename scalar_t, ScalarType target_scalar_type>
1614
void forced_align_impl(
17-
const torch::Tensor& logProbs,
18-
const torch::Tensor& targets,
15+
const Tensor& logProbs,
16+
const Tensor& targets,
1917
const int64_t blank,
20-
torch::Tensor& paths) {
18+
Tensor& paths) {
2119
const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
2220
using target_t = typename std::
23-
conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
21+
conditional<target_scalar_type == ScalarType::Int, int, int64_t>::type;
2422
const auto batchIndex =
2523
0; // TODO: support batch version and use the real batch index
2624
const auto T = logProbs.size(1);
@@ -138,73 +136,111 @@ void forced_align_impl(
138136
delete[] backPtr_a;
139137
}
140138

141-
std::tuple<torch::Tensor, torch::Tensor> compute(
142-
const torch::Tensor& logProbs,
143-
const torch::Tensor& targets,
144-
const torch::Tensor& inputLengths,
145-
const torch::Tensor& targetLengths,
139+
std::tuple<Tensor, Tensor> compute(
140+
const Tensor& logProbs,
141+
const Tensor& targets,
142+
const Tensor& inputLengths,
143+
const Tensor& targetLengths,
146144
const int64_t blank) {
147-
TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
148-
TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
149-
TORCH_CHECK(
150-
logProbs.device() == targets.device(),
151-
"log_probs and targets need to be on the same device");
152-
TORCH_CHECK(
153-
logProbs.dtype() == torch::kFloat64 ||
154-
logProbs.dtype() == torch::kFloat32 ||
155-
logProbs.dtype() == torch::kFloat16,
145+
STD_TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
146+
STD_TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
147+
STD_TORCH_CHECK(inputLengths.is_cpu(), "input_lengths must be a CPU tensor");
148+
STD_TORCH_CHECK(targetLengths.is_cpu(), "target_lengths must be a CPU tensor");
149+
STD_TORCH_CHECK(
150+
logProbs.scalar_type() == ScalarType::Double ||
151+
logProbs.scalar_type() == ScalarType::Float ||
152+
logProbs.scalar_type() == ScalarType::Half,
156153
"log_probs must be float64, float32 or float16 (half) type");
157-
TORCH_CHECK(
158-
targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
154+
STD_TORCH_CHECK(
155+
targets.scalar_type() == ScalarType::Int || targets.scalar_type() == ScalarType::Long,
159156
"targets must be int32 or int64 type");
160-
TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
161-
TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
162-
TORCH_CHECK(
157+
STD_TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
158+
STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
159+
STD_TORCH_CHECK(
163160
logProbs.dim() == 3,
164161
"log_probs must be 3-D (batch_size, input length, num classes)");
165-
TORCH_CHECK(
162+
STD_TORCH_CHECK(
166163
targets.dim() == 2, "targets must be 2-D (batch_size, target length,)");
167-
TORCH_CHECK(
164+
STD_TORCH_CHECK(
168165
inputLengths.dim() == 1, "input_lengths must be 1-D (batch_size,)");
169-
TORCH_CHECK(
166+
STD_TORCH_CHECK(
170167
targetLengths.dim() == 1, "target_lengths must be 1-D (batch_size,)");
171-
TORCH_CHECK(
168+
STD_TORCH_CHECK(
172169
logProbs.size(0) == 1,
173170
"The batch dimension for log_probs must be 1 at the current version.")
174-
TORCH_CHECK(
171+
STD_TORCH_CHECK(
175172
targets.size(0) == 1,
176173
"The batch dimension for targets must be 1 at the current version.")
177-
TORCH_CHECK(
174+
STD_TORCH_CHECK(
178175
blank >= 0 && blank < logProbs.size(-1),
179176
"blank must be within [0, num classes)");
180177

181-
TORCH_CHECK(
182-
logProbs.size(1) == at::max(inputLengths).item().toInt(),
178+
STD_TORCH_CHECK(
179+
logProbs.size(1) == torchaudio::util::max<int>(inputLengths),
183180
"input length mismatch");
184-
TORCH_CHECK(
185-
targets.size(1) == at::max(targetLengths).item().toInt(),
181+
STD_TORCH_CHECK(
182+
targets.size(1) == torchaudio::util::max<int>(targetLengths),
186183
"target length mismatch");
187184

188185
const auto B = logProbs.size(0);
189186
const auto T = logProbs.size(1);
190-
auto paths = torch::zeros(
191-
{B, T},
192-
torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
193-
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
194-
logProbs.scalar_type(), "forced_align_impl", [&] {
195-
if (targets.scalar_type() == torch::kInt64) {
196-
forced_align_impl<scalar_t, torch::kInt64>(
197-
logProbs, targets, blank, paths);
198-
} else {
199-
forced_align_impl<scalar_t, torch::kInt32>(
200-
logProbs, targets, blank, paths);
201-
}
202-
});
187+
Tensor paths = torch::stable::new_empty(targets, {B, T});
188+
torch::stable::zero_(paths);
189+
190+
switch (logProbs.scalar_type()) {
191+
case ScalarType::Double: {
192+
if (targets.scalar_type() == ScalarType::Long) {
193+
forced_align_impl<double, ScalarType::Long>(logProbs, targets, blank, paths);
194+
} else if (targets.scalar_type() == ScalarType::Int) {
195+
forced_align_impl<double, ScalarType::Int>(logProbs, targets, blank, paths);
196+
} else {
197+
STD_TORCH_CHECK(false, "unreachable");
198+
}
199+
break;
200+
}
201+
case ScalarType::Float: {
202+
if (targets.scalar_type() == ScalarType::Long) {
203+
forced_align_impl<float, ScalarType::Long>(logProbs, targets, blank, paths);
204+
} else if (targets.scalar_type() == ScalarType::Int) {
205+
forced_align_impl<float, ScalarType::Int>(logProbs, targets, blank, paths);
206+
} else {
207+
STD_TORCH_CHECK(false, "unreachable");
208+
}
209+
break;
210+
}
211+
case ScalarType::Half: {
212+
if (targets.scalar_type() == ScalarType::Long) {
213+
forced_align_impl<c10::Half, ScalarType::Long>(logProbs, targets, blank, paths);
214+
} else if (targets.scalar_type() == ScalarType::Int) {
215+
forced_align_impl<c10::Half, ScalarType::Int>(logProbs, targets, blank, paths);
216+
} else {
217+
STD_TORCH_CHECK(false, "unreachable");
218+
}
219+
break;
220+
}
221+
default: {
222+
STD_TORCH_CHECK(false, "unreachable");
223+
}
224+
};
225+
203226
return std::make_tuple(paths, logProbs);
204227
}
205228

206-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
207-
m.impl("forced_align", &compute);
229+
void boxed_forced_align_cpu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
230+
STD_TORCH_CHECK(num_args == 5, "num_args must be 5");
231+
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
232+
std::tuple<Tensor, Tensor> res = compute(
233+
/*logProbs*/to<Tensor>(stack[0]),
234+
/*targets*/to<Tensor>(stack[1]),
235+
/*logit_lengths*/to<Tensor>(stack[2]),
236+
/*target_lengths*/to<Tensor>(stack[3]),
237+
/*blank*/float(to<int64_t>(stack[4])));
238+
stack[0] = from(std::get<0>(res));
239+
stack[1] = from(std::get<1>(res));
240+
}
241+
// Registers the boxed CPU kernel for torchaudio::forced_align via the
// stable-ABI library mechanism (replaces the old TORCH_LIBRARY_IMPL path).
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("forced_align", &boxed_forced_align_cpu);
}
209245

210246
} // namespace cpu

0 commit comments

Comments
 (0)