diff --git a/src/libtorchaudio/CMakeLists.txt b/src/libtorchaudio/CMakeLists.txt index 85bc227cd6..c7813d1222 100644 --- a/src/libtorchaudio/CMakeLists.txt +++ b/src/libtorchaudio/CMakeLists.txt @@ -22,19 +22,13 @@ if(BUILD_RNNT) list( APPEND sources - rnnt/cpu/compute_alphas.cpp - rnnt/cpu/compute_betas.cpp rnnt/cpu/compute.cpp - rnnt/compute_alphas.cpp - rnnt/compute_betas.cpp rnnt/compute.cpp ) if (USE_CUDA) list( APPEND sources - rnnt/gpu/compute_alphas.cu - rnnt/gpu/compute_betas.cu rnnt/gpu/compute.cu ) endif() diff --git a/src/libtorchaudio/rnnt/compute.cpp b/src/libtorchaudio/rnnt/compute.cpp index 5aba334cee..867542e4e7 100644 --- a/src/libtorchaudio/rnnt/compute.cpp +++ b/src/libtorchaudio/rnnt/compute.cpp @@ -1,27 +1,8 @@ -#include +#include +#include -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax = true) { - static auto op = torch::Dispatcher::singleton() - .findSchemaOrThrow("torchaudio::rnnt_loss", "") - .typed(); - return op.call( - logits, - targets, - logit_lengths, - target_lengths, - blank, - clamp, - fused_log_softmax); -} -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { m.def( "rnnt_loss(Tensor logits," "Tensor targets," @@ -29,6 +10,5 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { "Tensor target_lengths," "int blank," "float clamp," - "bool fused_log_softmax) -> (Tensor, Tensor?)"); - m.def("torchaudio::rnnt_loss_forward", &rnnt_loss); + "bool fused_log_softmax) -> (Tensor, Tensor)"); } diff --git a/src/libtorchaudio/rnnt/compute.h b/src/libtorchaudio/rnnt/compute.h deleted file mode 100644 index ed2dd0c37e..0000000000 --- a/src/libtorchaudio/rnnt/compute.h +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -#include - -std::tuple> rnnt_loss( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const 
torch::Tensor& target_lengths, - int64_t blank, - double clamp, - bool fused_log_softmax); diff --git a/src/libtorchaudio/rnnt/compute_alphas.cpp b/src/libtorchaudio/rnnt/compute_alphas.cpp deleted file mode 100644 index adbcc1c8e7..0000000000 --- a/src/libtorchaudio/rnnt/compute_alphas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_alphas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/compute_betas.cpp b/src/libtorchaudio/rnnt/compute_betas.cpp deleted file mode 100644 index 7728838137..0000000000 --- a/src/libtorchaudio/rnnt/compute_betas.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include - -TORCH_LIBRARY_FRAGMENT(torchaudio, m) { - m.def( - "rnnt_loss_betas(Tensor logits," - "Tensor targets," - "Tensor logit_lengths," - "Tensor target_lengths," - "int blank," - "float clamp) -> Tensor"); -} diff --git a/src/libtorchaudio/rnnt/cpu/compute.cpp b/src/libtorchaudio/rnnt/cpu/compute.cpp index e5e13042cd..4150ea98f1 100644 --- a/src/libtorchaudio/rnnt/cpu/compute.cpp +++ b/src/libtorchaudio/rnnt/cpu/compute.cpp @@ -1,148 +1,216 @@ #include -#include +#include +#include +#include + + +#include namespace torchaudio { namespace rnnt { namespace cpu { +using torch::stable::Tensor; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const Tensor logits, + const Tensor targets, + const Tensor logit_lengths, + const Tensor target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), - "logits and logit_lengths must be on the same 
device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - targets.size(0) == logits.size(0), - "batch dimension mismatch between logits and targets"); - - TORCH_CHECK( - blank >= 0 && blank < logits.size(-1), - "blank must be within [0, logits.shape[-1])"); - - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), - "input length mismatch"); - TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, - "output length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(target_lengths).item().toInt(), - "target length mismatch"); + + int32_t 
logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t targets_device; + aoti_torch_get_device_type(targets.get(), &targets_device); + int32_t logit_lengths_device; + aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device); + int32_t target_lengths_device; + aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device); + + /* All four inputs must report the same device type. */ + AOTI_TORCH_CHECK(logits_device == targets_device); + AOTI_TORCH_CHECK(logits_device == logit_lengths_device); + AOTI_TORCH_CHECK(logits_device == target_lengths_device); + + /* logits may be float32 or float16. */ + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() || + logits_dtype == aoti_torch_dtype_float16()); + + /* targets and both length tensors are read through int* further down, so only int32 is acceptable (the removed TORCH_CHECKs required kInt32). */ + int32_t targets_dtype; + aoti_torch_get_dtype(targets.get(), &targets_dtype); + AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32()); + + int32_t logit_lengths_dtype; + aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32()); + + int32_t target_lengths_dtype; + aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32()); + + /* Each input must be contiguous; bool_tmp is reused, so it is asserted after every query. */ + bool bool_tmp; + aoti_torch_is_contiguous(logits.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(targets.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + + /* Rank checks: logits 4-D, targets 2-D, length tensors 1-D. */ + int64_t int_tmp; + aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); +
aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + /* Dimension 0 (batch) must agree across all four inputs. */ + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + /* An empty batch would make the nHypos_ division below divide by zero. */ + AOTI_TORCH_CHECK(logit_lengths_size > 0); + + /* blank must be within [0, logits.shape[-1]); this restores the range check the removed TORCH_CHECK enforced. */ + aoti_torch_get_size(logits.get(), 3, &int_tmp); + AOTI_TORCH_CHECK(blank >= 0 && blank < int_tmp); + + /* The three length-consistency checks below remain disabled; they were written against at::max. */ + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); + // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)(target_lengths_size / options.batchSize_); + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, &int_tmp); + options.maxTgtLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; options.blank_ = blank; options.clamp_ = clamp; options.fusedLogSmax_ = fused_log_softmax; - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); + AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cpu()); options.device_ = CPU; - torch::Tensor costs = torch::empty( - options.batchSize_ *
options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + /* costs: 1-D tensor of batchSize_ * nHypos_ elements with the dtype and device of logits. */ + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + /* gradients: clone of logits that is then zeroed in place. */ + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + /* Hand the scratch handle to a stable Tensor so it is released when compute() returns; the raw AtenTensorHandle was otherwise never freed. */ + Tensor int_workspace_owner(int_workspace); + + AtenTensorHandle float_workspace; + int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; + aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + /* Same ownership hand-off for the float scratch buffer. */ + Tensor float_workspace_owner(float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), -
/*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + /* Fetch raw data pointers for the four inputs and the two outputs. */ + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + /* float32 logits take the first branch; the else branch reinterprets logits, costs and gradients as c10::Half. */ + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - }; - return std::make_tuple(costs, gradients); + /* Wrap the raw handles in stable Tensors for the caller. */ + return std::make_tuple(Tensor(costs), Tensor(gradients)); +} + +/* Boxed wrapper registered for rnnt_loss: reads the seven arguments from stack[0..6], calls compute(), and stores the two results in stack[0] and stack[1]. */ +void boxed_compute(StableIValue* stack, uint64_t num_args,
uint64_t num_outputs) { + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from(std::get<0>(result)); + stack[1] = from(std::get<1>(result)); } -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss", &compute); + +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { + m.impl("rnnt_loss", &boxed_compute); } } // namespace cpu diff --git a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp b/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp deleted file mode 100644 index f0d2e009fe..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_alphas.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); - options.device_ = CPU; - - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - 
DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp b/src/libtorchaudio/rnnt/cpu/compute_betas.cpp deleted file mode 100644 index 025ca35f60..0000000000 --- a/src/libtorchaudio/rnnt/cpu/compute_betas.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace cpu { - -torch::Tensor compute_betas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU); - options.device_ = CPU; - - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - 
options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*srcLengths=*/logit_lengths.data_ptr(), - /*tgtLengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { - m.impl("rnnt_loss_betas", &compute_betas); -} - -} // namespace cpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute.cu b/src/libtorchaudio/rnnt/gpu/compute.cu index 43dae68027..e56026515c 100644 --- a/src/libtorchaudio/rnnt/gpu/compute.cu +++ b/src/libtorchaudio/rnnt/gpu/compute.cu @@ -1,151 +1,217 @@ #include #include -#include +#include +#include +#include namespace torchaudio { namespace rnnt { namespace gpu { +using torch::stable::Tensor; + // Entry point into RNNT Loss -std::tuple> compute( - torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, +std::tuple compute( + const Tensor logits, + const Tensor targets, + const Tensor logit_lengths, + const Tensor target_lengths, int64_t blank, double clamp, bool fused_log_softmax = true) { - TORCH_CHECK( - 
logits.device().type() == targets.device().type(), - "logits and targets must be on the same device"); - TORCH_CHECK( - logits.device().type() == logit_lengths.device().type(), - "logits and logit_lengths must be on the same device"); - TORCH_CHECK( - logits.device().type() == target_lengths.device().type(), - "logits and target_lengths must be on the same device"); - - TORCH_CHECK( - logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16, - "logits must be float32 or float16 (half) type"); - TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type"); - TORCH_CHECK( - logit_lengths.dtype() == torch::kInt32, - "logit_lengths must be int32 type"); - TORCH_CHECK( - target_lengths.dtype() == torch::kInt32, - "target_lengths must be int32 type"); - - TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous"); - TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous"); - TORCH_CHECK( - logit_lengths.is_contiguous(), "logit_lengths must be contiguous"); - TORCH_CHECK( - target_lengths.is_contiguous(), "target_lengths must be contiguous"); - - TORCH_CHECK( - logits.dim() == 4, "logits must be 4-D (batch, time, target, class)"); - TORCH_CHECK( - targets.dim() == 2, "targets must be 2-D (batch, max target length)"); - TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D"); - TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D"); - - TORCH_CHECK( - logit_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and logit_lengths"); - TORCH_CHECK( - target_lengths.size(0) == logits.size(0), - "batch dimension mismatch between logits and target_lengths"); - TORCH_CHECK( - targets.size(0) == logits.size(0), - "batch dimension mismatch between logits and targets"); - - TORCH_CHECK( - blank >= 0 && blank < logits.size(-1), - "blank must be within [0, logits.shape[-1])"); - - TORCH_CHECK( - logits.size(1) == at::max(logit_lengths).item().toInt(), - "input length mismatch"); - 
TORCH_CHECK( - logits.size(2) == at::max(target_lengths).item().toInt() + 1, - "output length mismatch"); - TORCH_CHECK( - targets.size(1) == at::max(target_lengths).item().toInt(), - "target length mismatch"); - - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - options.fusedLogSmax_ = fused_log_softmax; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); - options.device_ = GPU; - torch::Tensor costs = torch::empty( - options.batchSize_ * options.nHypos_, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - std::optional gradients = torch::zeros_like(logits); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); + int32_t logits_device; + aoti_torch_get_device_type(logits.get(), &logits_device); + int32_t targets_device; + aoti_torch_get_device_type(targets.get(), &targets_device); + int32_t logit_lengths_device; + aoti_torch_get_device_type(logit_lengths.get(), &logit_lengths_device); + int32_t target_lengths_device; + aoti_torch_get_device_type(target_lengths.get(), &target_lengths_device); + + AOTI_TORCH_CHECK(logits_device == targets_device); + AOTI_TORCH_CHECK(logits_device == logit_lengths_device); + AOTI_TORCH_CHECK(logits_device == target_lengths_device); + + int32_t logits_dtype; + aoti_torch_get_dtype(logits.get(), &logits_dtype); + AOTI_TORCH_CHECK(logits_dtype == aoti_torch_dtype_float32() || + logits_dtype == aoti_torch_dtype_float16()); + + int32_t targets_dtype; + aoti_torch_get_dtype(targets.get(), &targets_dtype); + 
/* targets and both length tensors are read through int* further down, so only int32 is acceptable (the removed TORCH_CHECKs required kInt32). */ + AOTI_TORCH_CHECK(targets_dtype == aoti_torch_dtype_int32()); + + int32_t logit_lengths_dtype; + aoti_torch_get_dtype(logit_lengths.get(), &logit_lengths_dtype); + AOTI_TORCH_CHECK(logit_lengths_dtype == aoti_torch_dtype_int32()); + + int32_t target_lengths_dtype; + aoti_torch_get_dtype(target_lengths.get(), &target_lengths_dtype); + AOTI_TORCH_CHECK(target_lengths_dtype == aoti_torch_dtype_int32()); + + /* Each input must be contiguous; bool_tmp is reused, so it is asserted after every query. */ + bool bool_tmp; + aoti_torch_is_contiguous(logits.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(targets.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(logit_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + aoti_torch_is_contiguous(target_lengths.get(), &bool_tmp); + AOTI_TORCH_CHECK(bool_tmp); + + /* Rank checks: logits 4-D, targets 2-D, length tensors 1-D. */ + int64_t int_tmp; + aoti_torch_get_dim(logits.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 4); + aoti_torch_get_dim(targets.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 2); + aoti_torch_get_dim(logit_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + aoti_torch_get_dim(target_lengths.get(), &int_tmp); + AOTI_TORCH_CHECK(int_tmp == 1); + + /* Dimension 0 (batch) must agree across all four inputs. */ + int64_t logit_lengths_size; + aoti_torch_get_size(logit_lengths.get(), 0, &logit_lengths_size); + int64_t logits_size; + aoti_torch_get_size(logits.get(), 0, &logits_size); + AOTI_TORCH_CHECK(logit_lengths_size == logits_size); + int64_t target_lengths_size; + aoti_torch_get_size(target_lengths.get(), 0, &target_lengths_size); + AOTI_TORCH_CHECK(target_lengths_size == logits_size); + int64_t targets_size; + aoti_torch_get_size(targets.get(), 0, &targets_size); + AOTI_TORCH_CHECK(targets_size == logits_size); + + /* blank must be within [0, logits.shape[-1]); this restores the range check the removed TORCH_CHECK enforced. */ + aoti_torch_get_size(logits.get(), 3, &int_tmp); + AOTI_TORCH_CHECK(blank >= 0 && blank < int_tmp); + + /* The three length-consistency checks below remain disabled; they were written against at::max. */ + // TORCH_CHECK( + // logits.size(1) == at::max(logit_lengths).item().toInt(), + // "input length
mismatch"); + // TORCH_CHECK( + // logits.size(2) == at::max(target_lengths).item().toInt() + 1, + // "output length mismatch"); + // TORCH_CHECK( + // targets.size(1) == at::max(target_lengths).item().toInt(), + // "target length mismatch"); + + Options options; + /* An empty batch would make the nHypos_ division below divide by zero. */ + AOTI_TORCH_CHECK(logit_lengths_size > 0); + options.batchSize_ = (int)logit_lengths_size; + options.nHypos_ = (int)target_lengths_size; + options.nHypos_ /= options.batchSize_; + aoti_torch_get_size(logits.get(), 1, &int_tmp); + options.maxSrcLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 2, &int_tmp); + options.maxTgtLen_ = (int)int_tmp; + aoti_torch_get_size(logits.get(), 3, &int_tmp); + options.numTargets_ = (int)int_tmp; + options.blank_ = blank; + options.clamp_ = clamp; + options.fusedLogSmax_ = fused_log_softmax; + + int32_t logits_device_index; + aoti_torch_get_device_index(logits.get(), &logits_device_index); + + /* Use the same check macro as the rest of this function. */ + AOTI_TORCH_CHECK(logits_device == aoti_torch_device_type_cuda()); + aoti_torch_get_current_cuda_stream(logits_device_index, (void**)&options.stream_); + /* cudaSetDevice takes a device ordinal; logits_device holds the device type, so pass the index (the removed code passed logits.get_device()). */ + cudaSetDevice(logits_device_index); + options.device_ = GPU; - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); + /* costs: 1-D tensor of batchSize_ * nHypos_ elements with the dtype and device of logits. */ + int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_}; + int64_t stride1[1] = {1}; + AtenTensorHandle costs; + aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs); + + /* gradients: clone of logits that is then zeroed in place. */ + AtenTensorHandle gradients; + aoti_torch_clone(logits.get(), &gradients); + aoti_torch_zero_(gradients); + + AtenTensorHandle int_workspace; + int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)}; + int64_t strides[1] = {1}; + aoti_torch_empty_strided(1, int_sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace); + /* Hand the scratch handle to a stable Tensor so it is released when compute() returns; the raw AtenTensorHandle was otherwise never freed. */ + Tensor int_workspace_owner(int_workspace); + + AtenTensorHandle float_workspace; + int64_t float_sizes[1] = {DtypeWorkspace::ComputeSizeFromOptions(options)}; +
aoti_torch_empty_strided(1, float_sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace); + /* Same ownership hand-off for the float scratch buffer. */ + Tensor float_workspace_owner(float_workspace); + + int64_t float_numel; + aoti_torch_get_numel(float_workspace, &float_numel); + void *int_workspace_ptr; + aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr); + void *float_workspace_ptr; + aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr); + int64_t int_numel; + aoti_torch_get_numel(int_workspace, &int_numel); Workspace workspace( /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); + /*dtype_data=*/(float*)float_workspace_ptr, + /*dtype_size=*/float_numel, + /*int_data=*/(int*)int_workspace_ptr, + /*int_size=*/int_numel); + + void *logit_ptr; + aoti_torch_get_data_ptr(logits.get(), &logit_ptr); + + void *target_ptr; + aoti_torch_get_data_ptr(targets.get(), &target_ptr); + + void *logit_len_ptr; + aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr); - switch (logits.scalar_type()) { - case torch::ScalarType::Float: { + void *target_len_ptr; + aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr); + + void *costs_ptr; + aoti_torch_get_data_ptr(costs, &costs_ptr); + + void *grads_ptr; + aoti_torch_get_data_ptr(gradients, &grads_ptr); + + if (logits_dtype == aoti_torch_dtype_float32()) { Compute( /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; - } - case torch::ScalarType::Half: { + /*logits=*/(float*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(float*)costs_ptr, + /*gradients=*/(float*)grads_ptr); + } else { Compute( /*workspace=*/workspace, -
/*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*gradients=*/gradients->data_ptr()); - break; + /*logits=*/(c10::Half*)logit_ptr, + /*targets=*/(int*)target_ptr, + /*logit_lengths=*/(int*)logit_len_ptr, + /*target_lengths=*/(int*)target_len_ptr, + /*costs=*/(c10::Half*)costs_ptr, + /*gradients=*/(c10::Half*)grads_ptr); } - default: { - break; - } - }; - return std::make_tuple(costs, gradients); + /* Wrap the raw handles in stable Tensors for the caller. */ + return std::make_tuple(Tensor(costs), Tensor(gradients)); +} + +/* Boxed wrapper: reads the seven arguments from stack[0..6], calls compute(), and stores the two results in stack[0] and stack[1]. */ +void boxed_compute(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { + Tensor t1(to(stack[0])); + Tensor t2(to(stack[1])); + Tensor t3(to(stack[2])); + Tensor t4(to(stack[3])); + int64_t blank = to(stack[4]); + double clamp = to(stack[5]); + bool fused_log_softmax = to(stack[6]); + auto result = compute( + std::move(t1), std::move(t2), std::move(t3), std::move(t4), + blank, clamp, fused_log_softmax); + stack[0] = from(std::get<0>(result)); + stack[1] = from(std::get<1>(result)); } -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss", &compute); +/* Registers the boxed wrapper as the CUDA implementation of torchaudio::rnnt_loss. */ +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { + m.impl("rnnt_loss", &boxed_compute); } } // namespace gpu diff --git a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu b/src/libtorchaudio/rnnt/gpu/compute_alphas.cu deleted file mode 100644 index bde40daa9f..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_alphas.cu +++ /dev/null @@ -1,73 +0,0 @@ -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -torch::Tensor compute_alphas( - const torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ =
logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); - options.device_ = GPU; - - torch::Tensor alphas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeAlphas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*alphas=*/alphas.data_ptr()); - return alphas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - m.impl("rnnt_loss_alphas", &compute_alphas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/libtorchaudio/rnnt/gpu/compute_betas.cu b/src/libtorchaudio/rnnt/gpu/compute_betas.cu deleted file mode 100644 index 18857c4388..0000000000 --- a/src/libtorchaudio/rnnt/gpu/compute_betas.cu +++ /dev/null @@ -1,78 +0,0 @@ -#include -#include -#include - -namespace torchaudio { -namespace rnnt { -namespace gpu { - -torch::Tensor compute_betas( - const 
torch::Tensor& logits, - const torch::Tensor& targets, - const torch::Tensor& logit_lengths, - const torch::Tensor& target_lengths, - int64_t blank, - double clamp) { - Options options; - options.batchSize_ = logit_lengths.size(0); - options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0); - options.maxSrcLen_ = logits.size(1); - options.maxTgtLen_ = logits.size(2); - options.numTargets_ = logits.size(3); - options.blank_ = blank; - options.clamp_ = clamp; - - TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA); - options.stream_ = at::cuda::getCurrentCUDAStream(); - cudaSetDevice(logits.get_device()); - options.device_ = GPU; - - torch::Tensor costs = torch::empty( - target_lengths.size(0), - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor betas = torch::zeros( - {options.batchSize_ * options.nHypos_, - options.maxSrcLen_, - options.maxTgtLen_}, - torch::TensorOptions().device(logits.device()).dtype(logits.dtype())); - - torch::Tensor int_workspace = torch::empty( - IntWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Int)); - - torch::Tensor float_workspace = torch::empty( - DtypeWorkspace::ComputeSizeFromOptions(options), - torch::TensorOptions() - .device(logits.device()) - .dtype(torch::ScalarType::Float)); - - Workspace workspace( - /*options=*/options, - /*dtype_data=*/float_workspace.data_ptr(), - /*dtype_size=*/float_workspace.numel(), - /*int_data=*/int_workspace.data_ptr(), - /*int_size=*/int_workspace.numel()); - - // Only support float, this is mainly to enable easy - // unit-testing - ComputeBetas( - /*workspace=*/workspace, - /*logits=*/logits.data_ptr(), - /*targets=*/targets.data_ptr(), - /*logit_lengths=*/logit_lengths.data_ptr(), - /*target_lengths=*/target_lengths.data_ptr(), - /*costs=*/costs.data_ptr(), - /*betas=*/betas.data_ptr()); - return betas; -} - -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { - 
m.impl("rnnt_loss_betas", &compute_betas); -} - -} // namespace gpu -} // namespace rnnt -} // namespace torchaudio diff --git a/src/torchaudio/functional/functional.py b/src/torchaudio/functional/functional.py index 810d1f51fc..7a3b15ac48 100644 --- a/src/torchaudio/functional/functional.py +++ b/src/torchaudio/functional/functional.py @@ -1763,7 +1763,7 @@ def _fix_waveform_shape( class RnntLoss(torch.autograd.Function): @staticmethod def forward(ctx, *args): - output, saved = torch.ops.torchaudio.rnnt_loss_forward(*args) + output, saved = torch.ops.torchaudio.rnnt_loss.default(*args) ctx.save_for_backward(saved) return output