diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 85bc227cd6..c7813d1222 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -22,19 +22,13 @@ if(BUILD_RNNT) list( APPEND sources - rnnt/cpu/compute_alphas.cpp - rnnt/cpu/compute_betas.cpp rnnt/cpu/compute.cpp - rnnt/compute_alphas.cpp - rnnt/compute_betas.cpp rnnt/compute.cpp ) if (USE_CUDA) list( APPEND sources - rnnt/gpu/compute_alphas.cu - rnnt/gpu/compute_betas.cu rnnt/gpu/compute.cu ) endif() diff --git a/src/libtorchaudio/forced_align/cpu/compute.cpp b/src/libtorchaudio/forced_align/cpu/compute.cpp index 0ddd21b126..4b87b20e5f 100644 --- a/src/libtorchaudio/forced_align/cpu/compute.cpp +++ b/src/libtorchaudio/forced_align/cpu/compute.cpp @@ -1,8 +1,5 @@ #include #include -#include -#include -#include #include using namespace std; diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp index 5aba334cee..f5831922f2 100644 --- a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -1,34 +1,12 @@ -#include +#include -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - static auto op = torch::Dispatcher::singleton() - .findSchemaOrThrow("torchaudio::rnnt_loss", "") - .typed(); - return op.call( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); -} - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( - "rnnt_loss(Tensor logits," + "rnnt_loss_forward(Tensor logits," "Tensor targets," "Tensor logit_lengths," "Tensor target_lengths," "int blank," "float clamp," - "bool fused_log_softmax) -> (Tensor, Tensor?)"); - m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); + "bool fused_log_softmax) -> (Tensor, Tensor)"); } diff --git a/src/libtorchaudio/rnnt/compute.h b/src/libtorchaudio/rnnt/compute.h deleted file mode 100644 index ed2dd0c37e..0000000000 --- a/src/libtorchaudio/rnnt/compute.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax); diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp deleted file mode 100644 index adbcc1c8e7..0000000000 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_alphas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/compute_betas.cpp b/src/libtorchaudio/rnnt/compute_betas.cpp deleted file mode 100644 index 7728838137..0000000000 --- a/src/libtorchaudio/rnnt/compute_betas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_betas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index e5e13042cd..1ef7df05c1 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -1,76 +1,89 @@ #include -#include +#include +#include namespace torchaudio { namespace rnnt { namespace cpu { +using torch::stable::Tensor; +using torch::headeronly::ScalarType; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const Tensor& logits, + const Tensor& targets, + const Tensor& logit_lengths, + const Tensor& target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), + STD_TORCH_CHECK(logits.is_cpu(), "logits must be on CPU"); + + STD_TORCH_CHECK( + targets.is_cpu(), "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), + STD_TORCH_CHECK( + logit_lengths.is_cpu(), "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), + STD_TORCH_CHECK( + target_lengths.is_cpu(), "logits and target_lengths must be on the same device"); - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, + STD_TORCH_CHECK( + logits.scalar_type() == ScalarType::Float || logits.scalar_type() == ScalarType::Half, "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, + + STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "targets must be int32 type"); + + STD_TORCH_CHECK( + logit_lengths.scalar_type() == ScalarType::Int, "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, + STD_TORCH_CHECK( + target_lengths.scalar_type() == ScalarType::Int, "target_lengths must be int32 type"); - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + STD_TORCH_CHECK( logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK( target_lengths.is_contiguous(), "target_lengths must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK( logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); + STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); + STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - TORCH_CHECK( + STD_TORCH_CHECK( logit_lengths.size(0) == logits.size(0), "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( + STD_TORCH_CHECK( target_lengths.size(0) == logits.size(0), "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.size(0) == logits.size(0), "batch dimension mismatch between logits and targets"); - TORCH_CHECK( + STD_TORCH_CHECK( blank >= 0 && blank < logits.size(-1), "blank must be within [0, logits.shape[-1])"); - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), + auto max_ivalue = [](const Tensor& t) { + // TODO: eliminate const_cast after pytorch/pytorch#161826 is fixed + return reinterpret_cast(torch::stable::amax(const_cast(t), {}).data_ptr())[0]; + }; + + STD_TORCH_CHECK( + logits.size(1) == max_ivalue(logit_lengths), "input length mismatch"); - TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, + STD_TORCH_CHECK( + logits.size(2) == max_ivalue(target_lengths) + 1, "output length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(target_lengths).item().toInt(), + STD_TORCH_CHECK( + targets.size(1) + 1 == logits.size(2), "target length mismatch"); Options options; @@ -82,67 +95,70 @@ std::tuple> compute( options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); options.device_ = CPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + Tensor costs = torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_}); + Tensor gradients = torch::stable::empty_like(logits); + torch::stable::fill_(gradients, 0.0); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + Tensor int_workspace = torch::stable::new_empty(logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int); + Tensor float_workspace = torch::stable::new_empty(logits, {DtypeWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Float); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_data=*/reinterpret_cast(float_workspace.data_ptr()), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), + /*int_data=*/reinterpret_cast(int_workspace.data_ptr()), /*int_size=*/int_workspace.numel()); switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + case ScalarType::Float: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); break; } - case torch::ScalarType::Half: { + case ScalarType::Half: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); break; } default: { - break; + STD_TORCH_CHECK(false, "unreachable"); } }; return std::make_tuple(costs, gradients); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss", &compute); +void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + STD_TORCH_CHECK(num_args == 7, "num_args must be 7"); + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); + std::tuple res = compute( + /*logits*/to(stack[0]), + /*targets*/to(stack[1]), + /*logit_lengths*/to(stack[2]), + /*target_lengths*/to(stack[3]), + /*blank*/float(to(stack[4])), + /*clamp*/to(stack[5]), + /*fused_log_softmax*/to(stack[6])); + stack[0] = from(std::get<0>(res)); + stack[1] = from(std::get<1>(res)); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss_forward", &boxed_rnnt_loss); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp deleted file mode 100644 index f0d2e009fe..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); - options.device_ = CPU; - - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp deleted file mode 100644 index 025ca35f60..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); - options.device_ = CPU; - - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_betas", &compute_betas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 43dae68027..eac0759ab7 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,77 +1,94 @@ -#include #include -#include + +#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { +using torch::stable::Tensor; +using torch::headeronly::ScalarType; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const Tensor& logits, + const Tensor& targets, + const Tensor& logit_lengths, + const Tensor& target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), + STD_TORCH_CHECK(logits.is_cuda(), "logits must be on CUDA"); + + STD_TORCH_CHECK( + targets.is_cuda() && targets.get_device_index() == logits.get_device_index(), "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), + STD_TORCH_CHECK( + logit_lengths.is_cuda() && logit_lengths.get_device_index() == logits.get_device_index(), "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), + STD_TORCH_CHECK( + target_lengths.is_cuda() && target_lengths.get_device_index() == logits.get_device_index(), "logits and target_lengths must be on the same device"); - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, + STD_TORCH_CHECK( + logits.scalar_type() == ScalarType::Float || logits.scalar_type() == ScalarType::Half, "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, + + STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "targets must be int32 type"); + + STD_TORCH_CHECK( + logit_lengths.scalar_type() == ScalarType::Int, "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, + STD_TORCH_CHECK( + target_lengths.scalar_type() == ScalarType::Int, "target_lengths must be int32 type"); - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); + STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); + STD_TORCH_CHECK( logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK( target_lengths.is_contiguous(), "target_lengths must be contiguous"); - TORCH_CHECK( + STD_TORCH_CHECK( logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); + STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); + STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - TORCH_CHECK( + STD_TORCH_CHECK( logit_lengths.size(0) == logits.size(0), "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( + STD_TORCH_CHECK( target_lengths.size(0) == logits.size(0), "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( + STD_TORCH_CHECK( targets.size(0) == logits.size(0), "batch dimension mismatch between logits and targets"); - TORCH_CHECK( + STD_TORCH_CHECK( blank >= 0 && blank < logits.size(-1), "blank must be within [0, logits.shape[-1])"); - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), + auto max_ivalue = [](const Tensor& t) { + // TODO: eliminate const_cast after pytorch/pytorch#161826 is fixed + int32_t value; + C10_CUDA_CHECK(cudaMemcpy(&value, torch::stable::amax(const_cast(t), {}).data_ptr(), sizeof(int32_t), cudaMemcpyDeviceToHost)); + return value; + }; + + STD_TORCH_CHECK( + logits.size(1) == max_ivalue(logit_lengths), "input length mismatch"); - TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, + STD_TORCH_CHECK( + logits.size(2) == max_ivalue(target_lengths) + 1, "output length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(target_lengths).item().toInt(), + STD_TORCH_CHECK( + targets.size(1) + 1 == logits.size(2), "target length mismatch"); Options options; @@ -83,69 +100,72 @@ std::tuple> compute( options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); options.stream_ = at::cuda::getCurrentCUDAStream(); cudaSetDevice(logits.get_device()); options.device_ = GPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + Tensor costs = torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_}); + Tensor gradients = torch::stable::empty_like(logits); + torch::stable::fill_(gradients, 0.0); - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + Tensor int_workspace = torch::stable::new_empty(logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int); + Tensor float_workspace = torch::stable::new_empty(logits, {DtypeWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Float); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), + /*dtype_data=*/reinterpret_cast(float_workspace.data_ptr()), /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), + /*int_data=*/reinterpret_cast(int_workspace.data_ptr()), /*int_size=*/int_workspace.numel()); switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + case ScalarType::Float: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); break; } - case torch::ScalarType::Half: { + case ScalarType::Half: { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); + /*logits=*/reinterpret_cast(logits.data_ptr()), + /*targets=*/reinterpret_cast(targets.data_ptr()), + /*srcLengths=*/reinterpret_cast(logit_lengths.data_ptr()), + /*tgtLengths=*/reinterpret_cast(target_lengths.data_ptr()), + /*costs=*/reinterpret_cast(costs.data_ptr()), + /*gradients=*/reinterpret_cast(gradients.data_ptr())); break; } default: { - break; + STD_TORCH_CHECK(false, "unreachable"); } }; return std::make_tuple(costs, gradients); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss", &compute); +void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + STD_TORCH_CHECK(num_args == 7, "num_args must be 7"); + STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2"); + std::tuple res = compute( + /*logits*/to(stack[0]), + /*targets*/to(stack[1]), + /*logit_lengths*/to(stack[2]), + /*target_lengths*/to(stack[3]), + /*blank*/float(to(stack[4])), + /*clamp*/to(stack[5]), + /*fused_log_softmax*/to(stack[6])); + stack[0] = from(std::get<0>(res)); + stack[1] = from(std::get<1>(res)); +} + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss_forward", &boxed_rnnt_loss); } } // namespace gpu diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu deleted file mode 100644 index bde40daa9f..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ /dev/null @@ -1,73 +0,0 @@ -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); - options.device_ = GPU; - - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu deleted file mode 100644 index 18857c4388..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ /dev/null @@ -1,78 +0,0 @@ -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); - options.device_ = GPU; - - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_betas", &compute_betas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/utils.cpp b/src/libtorchaudio/utils.cpp index 4789f0c0a8..5d0ab12d81 100644 --- a/src/libtorchaudio/utils.cpp +++ b/src/libtorchaudio/utils.cpp @@ -5,6 +5,11 @@ #include #endif +// Include of tensor.h defines the implementation of +// torch::stable::scalar_type function. No other source file should +// include torch/csrc/stable/tensor.h! +#include + namespace torchaudio { bool is_rir_available() {