Skip to content

Commit 258ca00

Browse files
committed
Merge branch 'main' into forced_align_accessors
2 parents be13f64 + 0c22347 commit 258ca00

File tree

4 files changed

+7
-20
lines changed

4 files changed

+7
-20
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ std::tuple<Tensor, Tensor> compute(
191191
auto paths = Tensor(paths_h);
192192

193193

194-
if (targets.scalar_type() == aoti_torch_dtype_int64()) {
194+
if (targets.dtype() == aoti_torch_dtype_int64()) {
195195
if (logProbs.scalar_type() == aoti_torch_dtype_float64()) {
196196
forced_align_impl<float64, int64>(logProbs, targets, blank, paths);
197197
} else if (logProbs.scalar_type() == aoti_torch_dtype_float32()) {
@@ -210,11 +210,8 @@ std::tuple<Tensor, Tensor> compute(
210210
}
211211
return std::make_tuple(
212212
paths,
213-
logProbs.index(
214-
{torch::indexing::Slice(),
215-
torch::linspace(
216-
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
217-
paths.index({0})}));
213+
logProbs
214+
);
218215
}
219216

220217

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,7 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
306306
});
307307
return std::make_tuple(
308308
paths.to(logProbs.device()),
309-
logProbs.index(
310-
{torch::indexing::Slice(),
311-
torch::linspace(
312-
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
313-
paths.index({0})}));
309+
logProbs);
314310
}
315311

316312
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {

src/libtorchaudio/lfilter.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,9 @@ void lfilter_core_generic_loop(
8282
auto coeff = a_coeff_flipped.unsqueeze(2);
8383
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
8484
auto windowed_output_signal =
85-
padded_output_waveform
86-
.index(
87-
{torch::indexing::Slice(),
88-
torch::indexing::Slice(),
89-
torch::indexing::Slice(i_sample, i_sample + n_order)})
90-
.transpose(0, 1);
85+
// torch::narrow takes (dim, start, length), not (dim, start, end): the
// window replaces Slice(i_sample, i_sample + n_order), so its length is
// n_order.
torch::narrow(padded_output_waveform, 2, i_sample, n_order).transpose(0, 1);
9186
auto o0 =
92-
input_signal_windows.index(
93-
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
87+
torch::select(input_signal_windows, 2, i_sample) -
9488
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
9589
padded_output_waveform.index_put_(
9690
{torch::indexing::Slice(),

src/torchaudio/functional/_alignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def forced_align(
7070
assert target_lengths is not None
7171

7272
paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
73-
return paths, scores
73+
return paths, scores[:, torch.arange(scores.shape[1]), paths[0]]
7474

7575

7676
@dataclass

0 commit comments

Comments
 (0)