diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt
index 85bc227cd6..20ad792b32 100644
--- a/src/libtorchaudio/CMakeLists.txt
+++ b/src/libtorchaudio/CMakeLists.txt
@@ -6,6 +6,7 @@ set(
   lfilter.cpp
   overdrive.cpp
   utils.cpp
+  accessor_tests.cpp
 )
 
 set(
diff --git a/src/libtorchaudio/accessor.h b/src/libtorchaudio/accessor.h
new file mode 100644
index 0000000000..0fc23e978f
--- /dev/null
+++ b/src/libtorchaudio/accessor.h
@@ -0,0 +1,53 @@
+#pragma once
+
+#include <torch/csrc/stable/tensor.h>
+#include <cstdarg>
+#include <type_traits>
+
+using torch::stable::Tensor;
+
+// Minimal stable-ABI stand-in for at::TensorAccessor: k is the number of
+// dimensions, T the element type, and is_const selects whether the wrapped
+// tensor is read-only.
+template <unsigned int k, typename T, bool is_const = false>
+class Accessor {
+  int64_t strides[k];
+  T *data;
+
+public:
+  using tensor_type = typename std::conditional<is_const, const Tensor, Tensor>::type;
+
+  Accessor(tensor_type tensor) {
+    data = (T*)tensor.data_ptr();
+    for (unsigned int i = 0; i < k; i++) {
+      strides[i] = tensor.stride(i);
+    }
+  }
+
+  // The first index is a named parameter because va_start requires one;
+  // the remaining k - 1 indices are read from the va_list as ints.
+  T index(int first, ...) {
+    va_list args;
+    va_start(args, first);
+    int64_t ix = strides[0] * first;
+    for (unsigned int i = 1; i < k; i++) {
+      ix += strides[i] * va_arg(args, int);
+    }
+    va_end(args);
+    return data[ix];
+  }
+
+  // Writes are only enabled for non-const accessors.
+  template <bool c = is_const>
+  typename std::enable_if<!c, void>::type set_index(T value, ...) {
+    va_list args;
+    va_start(args, value);
+    int64_t ix = 0;
+    for (unsigned int i = 0; i < k; i++) {
+      ix += strides[i] * va_arg(args, int);
+    }
+    va_end(args);
+    data[ix] = value;
+  }
+};
diff --git a/src/libtorchaudio/accessor_tests.cpp b/src/libtorchaudio/accessor_tests.cpp
new file mode 100644
index 0000000000..62e9b23d5a
--- /dev/null
+++ b/src/libtorchaudio/accessor_tests.cpp
@@ -0,0 +1,48 @@
+#include <libtorchaudio/accessor.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/tensor.h>
+#include <cstdint>
+#include <utility>
+
+namespace torchaudio {
+
+namespace accessor_tests {
+
+using namespace std;
+using torch::stable::Tensor;
+
+// Walks a contiguous 3-D int64 tensor in memory order and checks that
+// Accessor::index recovers the same elements through the strides.
+bool test_accessor(const Tensor tensor) {
+  int64_t* data_ptr = (int64_t*)tensor.data_ptr();
+  auto accessor = Accessor<3, int64_t>(tensor);
+  for (unsigned int i = 0; i < tensor.size(0); i++) {
+    for (unsigned int j = 0; j < tensor.size(1); j++) {
+      for (unsigned int k = 0; k < tensor.size(2); k++) {
+        auto check = *(data_ptr++) == accessor.index(i, j, k);
+        if (!check) {
+          return false;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+void boxed_test_accessor(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor t1(to<AtenTensorHandle>(stack[0]));
+  auto result = test_accessor(std::move(t1));
+  stack[0] = from(result);
+}
+
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def(
+      "_test_accessor(Tensor tensor) -> bool");
+}
+
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_test_accessor", &boxed_test_accessor);
+}
+
+} // namespace accessor_tests
+} // namespace torchaudio
diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp
index 828ce4829a..6314433786 100644
--- a/src/libtorchaudio/forced_align/cpu/compute.cpp
+++ b/src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -1,42 +1,53 @@
 #include <torch/script.h>
 #include <limits>
+#include <libtorchaudio/accessor.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/tensor.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <cstdint>
+#include <tuple>
+#include <utility>
+
 using namespace std;
 
 namespace torchaudio {
 namespace alignment {
 namespace cpu {
+
+using torch::stable::Tensor;
+
 // Inspired from
 // https://github.com/flashlight/sequence/blob/main/flashlight/lib/sequence/criterion/cpu/ConnectionistTemporalClassificationCriterion.cpp
-template <typename scalar_t, at::ScalarType target_scalar_type>
+template <typename scalar_t, typename target_t>
 void forced_align_impl(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const int64_t blank,
-    torch::Tensor& paths) {
+    const Tensor logProbs,
+    const Tensor targets,
+    target_t blank,
+    Tensor paths) {
   const scalar_t kNegInfinity = -std::numeric_limits<scalar_t>::infinity();
-  using target_t = typename std::
-      conditional<target_scalar_type == torch::kInt, int, int64_t>::type;
   const auto batchIndex =
       0; // TODO: support batch version and use the real batch index
   const auto T = logProbs.size(1);
   const auto L = targets.size(1);
   const auto S = 2 * L + 1;
-  torch::Tensor alphas = torch::empty(
-      {2, S},
-      torch::TensorOptions()
-          .device(logProbs.device())
-          .dtype(logProbs.dtype()))
-      .fill_(kNegInfinity);
-  torch::Tensor backPtr = torch::empty({T, S}, torch::kInt8).fill_(-1);
-  auto logProbs_a = logProbs.accessor<scalar_t, 3>();
-  auto targets_a = targets.accessor<target_t, 2>();
-  auto paths_a = paths.accessor<target_t, 2>();
-  auto alphas_a = alphas.accessor<scalar_t, 2>();
-  auto backPtr_a = backPtr.accessor<int8_t, 2>();
+
+  auto alphas_a = new scalar_t[2 * S]; // scalar_t is just logProbs.dtype()
+  for (int i = 0; i < 2 * S; i++) {
+    alphas_a[i] = kNegInfinity;
+  }
+
+  auto backPtr_a = new int8_t[T * S];
+  for (int i = 0; i < T * S; i++) {
+    backPtr_a[i] = -1;
+  }
+
+  auto logProbs_a = Accessor<3, scalar_t, true>(logProbs);
+  auto targets_a = Accessor<2, target_t, true>(targets);
+  auto paths_a = Accessor<2, target_t, false>(paths);
   auto R = 0;
   for (auto i = 1; i < L; i++) {
-    if (targets_a[batchIndex][i] == targets_a[batchIndex][i - 1]) {
+    if (targets_a.index(batchIndex, i) == targets_a.index(batchIndex, i - 1)) {
       ++R;
     }
   }
@@ -51,22 +62,23 @@ void forced_align_impl(
   auto start = T - (L + R) > 0 ? 0 : 1;
   auto end = (S == 1) ? 1 : 2;
   for (auto i = start; i < end; i++) {
-    auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
-    alphas_a[0][i] = logProbs_a[batchIndex][0][labelIdx];
+    auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
+    alphas_a[i] = logProbs_a.index(batchIndex, 0, labelIdx);
+
   }
   for (auto t = 1; t < T; t++) {
     if (T - t <= L + R) {
       if ((start % 2 == 1) &&
-          targets_a[batchIndex][start / 2] !=
-              targets_a[batchIndex][start / 2 + 1]) {
+          targets_a.index(batchIndex, start / 2) !=
+              targets_a.index(batchIndex, start / 2 + 1)) {
         start = start + 1;
       }
       start = start + 1;
     }
     if (t <= L + R) {
       if (end % 2 == 0 && end < 2 * L &&
-          targets_a[batchIndex][end / 2 - 1] !=
-              targets_a[batchIndex][end / 2]) {
+          targets_a.index(batchIndex, end / 2 - 1) !=
+              targets_a.index(batchIndex, end / 2)) {
         end = end + 1;
       }
       end = end + 1;
@@ -75,72 +87,76 @@ void forced_align_impl(
     auto curIdxOffset = t % 2;
     auto prevIdxOffset = (t - 1) % 2;
     for (auto j = 0; j < S; ++j) {
-      alphas_a[curIdxOffset][j] = -std::numeric_limits<scalar_t>::infinity();
+      alphas_a[curIdxOffset * S + j] = -std::numeric_limits<scalar_t>::infinity(); // alphas_a[curIdxOffset][j]
     }
     if (start == 0) {
-      alphas_a[curIdxOffset][0] =
-          alphas_a[prevIdxOffset][0] + logProbs_a[batchIndex][t][blank];
-      backPtr_a[t][0] = 0;
+      alphas_a[curIdxOffset * S] =
+          alphas_a[prevIdxOffset * S] + logProbs_a.index(batchIndex, t, blank);
+      backPtr_a[S * t] = 0; // backPtr_a[t][0] = 0
       startloop += 1;
     }
     for (auto i = startloop; i < end; i++) {
-      auto x0 = alphas_a[prevIdxOffset][i];
-      auto x1 = alphas_a[prevIdxOffset][i - 1];
+      auto x0 = alphas_a[prevIdxOffset * S + i]; // alphas_a[prevIdxOffset][i];
+      auto x1 = alphas_a[prevIdxOffset * S + i - 1]; // alphas_a[prevIdxOffset][i - 1];
       auto x2 = -std::numeric_limits<scalar_t>::infinity();
-      auto labelIdx = (i % 2 == 0) ? blank : targets_a[batchIndex][i / 2];
+      auto labelIdx = (i % 2 == 0) ? blank : targets_a.index(batchIndex, i / 2);
       // In CTC, the optimal path may optionally choose to skip a blank label.
      // x2 represents skipping a letter, and can only happen if we're not
      // currently on a blank_label, and we're not on a repeat letter
      // (i != 1) just ensures we don't access targets[i - 2] if i < 2
      if (i % 2 != 0 && i != 1 &&
-          targets_a[batchIndex][i / 2] != targets_a[batchIndex][i / 2 - 1]) {
-        x2 = alphas_a[prevIdxOffset][i - 2];
+          targets_a.index(batchIndex, i / 2) != targets_a.index(batchIndex, i / 2 - 1)) {
+        x2 = alphas_a[prevIdxOffset * S + i - 2]; // alphas_a[prevIdxOffset][i - 2];
       }
       scalar_t result = 0.0;
       if (x2 > x1 && x2 > x0) {
         result = x2;
-        backPtr_a[t][i] = 2;
+        backPtr_a[t * S + i] = 2; // backPtr_a[t][i] = 2
       } else if (x1 > x0 && x1 > x2) {
         result = x1;
-        backPtr_a[t][i] = 1;
+        backPtr_a[t * S + i] = 1; // backPtr_a[t][i] = 1
       } else {
         result = x0;
-        backPtr_a[t][i] = 0;
+        backPtr_a[t * S + i] = 0; // backPtr_a[t][i] = 0
       }
-      alphas_a[curIdxOffset][i] = result + logProbs_a[batchIndex][t][labelIdx];
+
+      alphas_a[curIdxOffset * S + i] = result + logProbs_a.index(batchIndex, t, labelIdx); // alphas_a[curIdxOffset][i]
     }
   }
   auto idx1 = (T - 1) % 2;
-  auto ltrIdx = alphas_a[idx1][S - 1] > alphas_a[idx1][S - 2] ? S - 1 : S - 2;
+  auto ltrIdx = alphas_a[S * idx1 + S - 1] >
+      alphas_a[S * idx1 + S - 2] ? S - 1 : S - 2; // alphas_a[idx1][S - 1], alphas_a[idx1][S - 2]
+  delete[] alphas_a;
   // path stores the token index for each time step after force alignment.
   for (auto t = T - 1; t > -1; t--) {
-    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a[batchIndex][ltrIdx / 2];
-    paths_a[batchIndex][t] = lbl_idx;
-    ltrIdx -= backPtr_a[t][ltrIdx];
+    auto lbl_idx = ltrIdx % 2 == 0 ? blank : targets_a.index(batchIndex, ltrIdx / 2);
+    paths_a.set_index(lbl_idx, batchIndex, t);
+    ltrIdx -= backPtr_a[t * S + ltrIdx]; // backPtr_a[t][ltrIdx]
   }
+  delete[] backPtr_a;
 }
 
-std::tuple<torch::Tensor, torch::Tensor> compute(
-    const torch::Tensor& logProbs,
-    const torch::Tensor& targets,
-    const torch::Tensor& inputLengths,
-    const torch::Tensor& targetLengths,
+std::tuple<Tensor, Tensor> compute(
+    const Tensor& logProbs,
+    const Tensor& targets,
+    const Tensor& inputLengths,
+    const Tensor& targetLengths,
     const int64_t blank) {
   TORCH_CHECK(logProbs.is_cpu(), "log_probs must be a CPU tensor");
   TORCH_CHECK(targets.is_cpu(), "targets must be a CPU tensor");
   TORCH_CHECK(
-      logProbs.device() == targets.device(),
+      logProbs.get_device() == targets.get_device(),
       "log_probs and targets need to be on the same device");
   TORCH_CHECK(
-      logProbs.dtype() == torch::kFloat64 ||
-          logProbs.dtype() == torch::kFloat32 ||
-          logProbs.dtype() == torch::kFloat16,
+      logProbs.dtype() == aoti_torch_dtype_float64() ||
+          logProbs.dtype() == aoti_torch_dtype_float32() ||
+          logProbs.dtype() == aoti_torch_dtype_float16(),
       "log_probs must be float64, float32 or float16 (half) type");
   TORCH_CHECK(
-      targets.dtype() == torch::kInt32 || targets.dtype() == torch::kInt64,
+      targets.dtype() == aoti_torch_dtype_int32() || targets.dtype() == aoti_torch_dtype_int64(),
       "targets must be int32 or int64 type");
   TORCH_CHECK(logProbs.is_contiguous(), "log_probs must be contiguous");
   TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
@@ -163,28 +179,43 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
       blank >= 0 && blank < logProbs.size(-1),
       "blank must be within [0, num classes)");
-  TORCH_CHECK(
-      logProbs.size(1) == at::max(inputLengths).item().toInt(),
-      "input length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(targetLengths).item().toInt(),
-      "target length mismatch");
+  // TODO: Requires port of `max` and `item` operators.
+  // TORCH_CHECK(
+  //     logProbs.size(1) == at::max(inputLengths).item().toInt(),
+  //     "input length mismatch");
+  // TORCH_CHECK(
+  //     targets.size(1) == at::max(targetLengths).item().toInt(),
+  //     "target length mismatch");
 
   const auto B = logProbs.size(0);
   const auto T = logProbs.size(1);
-  auto paths = torch::zeros(
-      {B, T},
-      torch::TensorOptions().device(targets.device()).dtype(targets.dtype()));
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      logProbs.scalar_type(), "forced_align_impl", [&] {
-        if (targets.scalar_type() == torch::kInt64) {
-          forced_align_impl<scalar_t, torch::kInt64>(
-              logProbs, targets, blank, paths);
-        } else {
-          forced_align_impl<scalar_t, torch::kInt32>(
-              logProbs, targets, blank, paths);
-        }
-      });
+
+  int64_t paths_size[2] = {B, T};
+  int64_t paths_stride[2] = {T, 1};
+  AtenTensorHandle paths_h;
+  int32_t targets_device;
+  aoti_torch_get_device_type(targets.get(), &targets_device);
+  aoti_torch_empty_strided(2, paths_size, paths_stride, targets.dtype(), targets_device, targets.get_device(), &paths_h);
+  auto paths = Tensor(paths_h);
+
+  if (targets.dtype() == aoti_torch_dtype_int64()) {
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int64_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int64_t>(logProbs, targets, blank, paths);
+    }
+  } else if (targets.dtype() == aoti_torch_dtype_int32()) {
+    if (logProbs.dtype() == aoti_torch_dtype_float64()) {
+      forced_align_impl<double, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float32()) {
+      forced_align_impl<float, int32_t>(logProbs, targets, blank, paths);
+    } else if (logProbs.dtype() == aoti_torch_dtype_float16()) {
+      forced_align_impl<c10::Half, int32_t>(logProbs, targets, blank, paths);
+    }
+  }
   return std::make_tuple(
       paths,
       logProbs
@@ -192,10 +223,20 @@ std::tuple<torch::Tensor, torch::Tensor> compute(
 }
 
+void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  Tensor t1(to<AtenTensorHandle>(stack[0]));
+  Tensor t2(to<AtenTensorHandle>(stack[1]));
+  Tensor t3(to<AtenTensorHandle>(stack[2]));
+  Tensor t4(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  auto result = compute(
+      std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank);
+  stack[0] = from(std::get<0>(result));
+  stack[1] = from(std::get<1>(result));
+}
 
-
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("forced_align", &compute);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("forced_align", &boxed_compute);
 }
 
 } // namespace cpu
diff --git a/test/torchaudio_unittest/accessor_test.py b/test/torchaudio_unittest/accessor_test.py
new file mode 100644
index 0000000000..db14258dc6
--- /dev/null
+++ b/test/torchaudio_unittest/accessor_test.py
@@ -0,0 +1,7 @@
+import torch
+from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE
+
+if _IS_TORCHAUDIO_EXT_AVAILABLE:
+    def test_accessor():
+        tensor = torch.randint(1000, (5, 4, 3))
+        assert torch.ops.torchaudio._test_accessor(tensor)
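For an end-to-end check of the ported binding, a smoke test along these lines should work. This is a minimal sketch, assuming the op keeps its existing five-argument schema; the shapes and target values are made up, and the second return value (derived from `log_probs`; the tuple tail is truncated in the hunk above) is ignored here:

```python
import torch

# Hypothetical inputs: batch 1 (per the batchIndex TODO above),
# 10 frames, 5 classes, 3 target tokens, blank index 0.
log_probs = torch.randn(1, 10, 5).log_softmax(-1)
targets = torch.tensor([[1, 2, 3]], dtype=torch.int32)
input_lengths = torch.tensor([10])
target_lengths = torch.tensor([3])

paths, _ = torch.ops.torchaudio.forced_align(
    log_probs, targets, input_lengths, target_lengths, 0)

# paths holds the token index chosen at each of the 10 frames.
assert paths.shape == (1, 10)
```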