Skip to content

Commit 0c22347

Browse files
Use narrow instead of indexing by slices (#4014)
* Use narrow instead of indexing by slices * Use index_select instead of select * Move fancy indexing to python instead of c++ --------- Co-authored-by: Sam Anklesaria <[email protected]>
1 parent 6fbc710 commit 0c22347

File tree

4 files changed

+9
-19
lines changed

4 files changed

+9
-19
lines changed

src/libtorchaudio/forced_align/cpu/compute.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,13 +187,13 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
187187
});
188188
return std::make_tuple(
189189
paths,
190-
logProbs.index(
191-
{torch::indexing::Slice(),
192-
torch::linspace(
193-
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
194-
paths.index({0})}));
190+
logProbs
191+
);
195192
}
196193

194+
195+
196+
197197
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
198198
m.impl("forced_align", &compute);
199199
}

src/libtorchaudio/forced_align/gpu/compute.cu

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -306,11 +306,7 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
306306
});
307307
return std::make_tuple(
308308
paths.to(logProbs.device()),
309-
logProbs.index(
310-
{torch::indexing::Slice(),
311-
torch::linspace(
312-
0, T - 1, T, torch::TensorOptions().dtype(paths.dtype())),
313-
paths.index({0})}));
309+
logProbs);
314310
}
315311

316312
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {

src/libtorchaudio/lfilter.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,9 @@ void lfilter_core_generic_loop(
8282
auto coeff = a_coeff_flipped.unsqueeze(2);
8383
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
8484
auto windowed_output_signal =
85-
padded_output_waveform
86-
.index(
87-
{torch::indexing::Slice(),
88-
torch::indexing::Slice(),
89-
torch::indexing::Slice(i_sample, i_sample + n_order)})
90-
.transpose(0, 1);
85+
torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order).transpose(0, 1);
9186
auto o0 =
92-
input_signal_windows.index(
93-
{torch::indexing::Slice(), torch::indexing::Slice(), i_sample}) -
87+
torch::select(input_signal_windows, 2, i_sample) -
9488
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
9589
padded_output_waveform.index_put_(
9690
{torch::indexing::Slice(),

src/torchaudio/functional/_alignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def forced_align(
7070
assert target_lengths is not None
7171

7272
paths, scores = torch.ops.torchaudio.forced_align(log_probs, targets, input_lengths, target_lengths, blank)
73-
return paths, scores
73+
return paths, scores[:, torch.arange(scores.shape[1]), paths[0]]
7474

7575

7676
@dataclass

0 commit comments

Comments
 (0)