6 changes: 0 additions & 6 deletions src/libtorchaudio/CMakeLists.txt
@@ -22,19 +22,13 @@ if(BUILD_RNNT)
   list(
     APPEND
     sources
-    rnnt/cpu/compute_alphas.cpp
-    rnnt/cpu/compute_betas.cpp
     rnnt/cpu/compute.cpp
-    rnnt/compute_alphas.cpp
-    rnnt/compute_betas.cpp
     rnnt/compute.cpp
   )
   if (USE_CUDA)
     list(
       APPEND
       sources
-      rnnt/gpu/compute_alphas.cu
-      rnnt/gpu/compute_betas.cu
       rnnt/gpu/compute.cu
     )
   endif()
3 changes: 0 additions & 3 deletions src/libtorchaudio/forced_align/cpu/compute.cpp
@@ -1,8 +1,5 @@
-#include <torch/script.h>
-#include <torch/torch.h>
 #include <torch/csrc/stable/library.h>
 #include <torch/csrc/stable/tensor.h>
 #include <torch/csrc/stable/ops.h>
-#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 using namespace std;
30 changes: 4 additions & 26 deletions src/libtorchaudio/rnnt/compute.cpp
@@ -1,34 +1,12 @@
-#include <libtorchaudio/rnnt/compute.h>
+#include <torch/csrc/stable/library.h>
 
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
-    int64_t blank,
-    double clamp,
-    bool fused_log_softmax = true) {
-  static auto op = torch::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
-                       .typed<decltype(rnnt_loss)>();
-  return op.call(
-      logits,
-      targets,
-      logit_lengths,
-      target_lengths,
-      blank,
-      clamp,
-      fused_log_softmax);
-}
-
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
-      "rnnt_loss(Tensor logits,"
+      "rnnt_loss_forward(Tensor logits,"
       "Tensor targets,"
      "Tensor logit_lengths,"
       "Tensor target_lengths,"
       "int blank,"
       "float clamp,"
-      "bool fused_log_softmax) -> (Tensor, Tensor?)");
-  m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
+      "bool fused_log_softmax) -> (Tensor, Tensor)");
 }
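
Note: the stable-ABI pattern adopted here splits an operator into a schema declared with STABLE_TORCH_LIBRARY_FRAGMENT and a boxed kernel registered with STABLE_TORCH_LIBRARY_IMPL; the CPU half of that pairing appears in rnnt/cpu/compute.cpp below. A minimal sketch of the same pattern, assuming only the torch/csrc/stable headers already used in this PR (my_identity is a hypothetical op, not part of torchaudio):

#include <torch/csrc/stable/library.h>
#include <torch/csrc/stable/tensor.h>

using torch::stable::Tensor;

// Hypothetical op used only to illustrate the registration pattern.
Tensor my_identity(const Tensor& x) {
  return x;
}

// Boxed wrapper: arguments arrive on a flat StableIValue stack and the
// outputs are written back over the leading stack slots.
void boxed_my_identity(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  STD_TORCH_CHECK(num_args == 1, "num_args must be 1");
  STD_TORCH_CHECK(num_outputs == 1, "num_outputs must be 1");
  stack[0] = from(my_identity(to<Tensor>(stack[0])));
}

STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def("my_identity(Tensor x) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("my_identity", &boxed_my_identity);
}
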
12 changes: 0 additions & 12 deletions src/libtorchaudio/rnnt/compute.h

This file was deleted.

11 changes: 0 additions & 11 deletions src/libtorchaudio/rnnt/compute_alphas.cpp

This file was deleted.

11 changes: 0 additions & 11 deletions src/libtorchaudio/rnnt/compute_betas.cpp

This file was deleted.

162 changes: 89 additions & 73 deletions src/libtorchaudio/rnnt/cpu/compute.cpp
@@ -1,76 +1,89 @@
 #include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
-#include <torch/script.h>
+#include <torch/csrc/stable/library.h>
+#include <torch/csrc/stable/ops.h>
 
 namespace torchaudio {
 namespace rnnt {
 namespace cpu {
 
+using torch::stable::Tensor;
+using torch::headeronly::ScalarType;
+
 // Entry point into RNNT Loss
-std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
-    torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
+std::tuple<Tensor, Tensor> compute(
+    const Tensor& logits,
+    const Tensor& targets,
+    const Tensor& logit_lengths,
+    const Tensor& target_lengths,
     int64_t blank,
     double clamp,
     bool fused_log_softmax = true) {
-  TORCH_CHECK(
-      logits.device().type() == targets.device().type(),
+  STD_TORCH_CHECK(logits.is_cpu(), "logits must be on CPU");
+
+  STD_TORCH_CHECK(
+      targets.is_cpu(),
       "logits and targets must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == logit_lengths.device().type(),
+  STD_TORCH_CHECK(
+      logit_lengths.is_cpu(),
       "logits and logit_lengths must be on the same device");
-  TORCH_CHECK(
-      logits.device().type() == target_lengths.device().type(),
+  STD_TORCH_CHECK(
+      target_lengths.is_cpu(),
       "logits and target_lengths must be on the same device");
 
-  TORCH_CHECK(
-      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
+  STD_TORCH_CHECK(
+      logits.scalar_type() == ScalarType::Float || logits.scalar_type() == ScalarType::Half,
       "logits must be float32 or float16 (half) type");
-  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
-  TORCH_CHECK(
-      logit_lengths.dtype() == torch::kInt32,
+
+  STD_TORCH_CHECK(targets.scalar_type() == ScalarType::Int, "targets must be int32 type");
+
+  STD_TORCH_CHECK(
+      logit_lengths.scalar_type() == ScalarType::Int,
       "logit_lengths must be int32 type");
-  TORCH_CHECK(
-      target_lengths.dtype() == torch::kInt32,
+  STD_TORCH_CHECK(
+      target_lengths.scalar_type() == ScalarType::Int,
       "target_lengths must be int32 type");
 
-  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
-  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
+  STD_TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
+  STD_TORCH_CHECK(
       logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       target_lengths.is_contiguous(), "target_lengths must be contiguous");
 
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       targets.dim() == 2, "targets must be 2-D (batch, max target length)");
-  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
-  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
+  STD_TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
+  STD_TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
 
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       logit_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and logit_lengths");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       target_lengths.size(0) == logits.size(0),
       "batch dimension mismatch between logits and target_lengths");
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       targets.size(0) == logits.size(0),
       "batch dimension mismatch between logits and targets");
 
-  TORCH_CHECK(
+  STD_TORCH_CHECK(
       blank >= 0 && blank < logits.size(-1),
       "blank must be within [0, logits.shape[-1])");
 
-  TORCH_CHECK(
-      logits.size(1) == at::max(logit_lengths).item().toInt(),
+  auto max_ivalue = [](const Tensor& t) {
+    // TODO: eliminate const_cast after pytorch/pytorch#161826 is fixed
+    return reinterpret_cast<int32_t*>(torch::stable::amax(const_cast<Tensor&>(t), {}).data_ptr())[0];
+  };
+
+  STD_TORCH_CHECK(
+      logits.size(1) == max_ivalue(logit_lengths),
       "input length mismatch");
-  TORCH_CHECK(
-      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
+  STD_TORCH_CHECK(
+      logits.size(2) == max_ivalue(target_lengths) + 1,
       "output length mismatch");
-  TORCH_CHECK(
-      targets.size(1) == at::max(target_lengths).item().toInt(),
+  STD_TORCH_CHECK(
+      targets.size(1) + 1 == logits.size(2),
       "target length mismatch");
 
   Options options;
@@ -82,67 +95,70 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
   options.blank_ = blank;
   options.clamp_ = clamp;
   options.fusedLogSmax_ = fused_log_softmax;
 
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
 
-  torch::Tensor costs = torch::empty(
-      options.batchSize_ * options.nHypos_,
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
-  std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
-
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
+  Tensor costs = torch::stable::new_empty(logits, {options.batchSize_ * options.nHypos_});
+  Tensor gradients = torch::stable::empty_like(logits);
+  torch::stable::fill_(gradients, 0.0);
 
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  Tensor int_workspace = torch::stable::new_empty(logits, {IntWorkspace::ComputeSizeFromOptions(options)}, ScalarType::Int);
+  Tensor float_workspace = torch::stable::new_empty(logits, {DtypeWorkspace<float>::ComputeSizeFromOptions(options)}, ScalarType::Float);
 
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
+      /*dtype_data=*/reinterpret_cast<float*>(float_workspace.data_ptr()),
       /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
+      /*int_data=*/reinterpret_cast<int*>(int_workspace.data_ptr()),
       /*int_size=*/int_workspace.numel());
 
   switch (logits.scalar_type()) {
-    case torch::ScalarType::Float: {
+    case ScalarType::Float: {
       Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
           /*workspace=*/workspace,
-          /*logits=*/logits.data_ptr<float>(),
-          /*targets=*/targets.data_ptr<int>(),
-          /*srcLengths=*/logit_lengths.data_ptr<int>(),
-          /*tgtLengths=*/target_lengths.data_ptr<int>(),
-          /*costs=*/costs.data_ptr<float>(),
-          /*gradients=*/gradients->data_ptr<float>());
+          /*logits=*/reinterpret_cast<float*>(logits.data_ptr()),
+          /*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
+          /*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
+          /*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
+          /*costs=*/reinterpret_cast<float*>(costs.data_ptr()),
+          /*gradients=*/reinterpret_cast<float*>(gradients.data_ptr()));
       break;
     }
-    case torch::ScalarType::Half: {
+    case ScalarType::Half: {
       Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
           /*workspace=*/workspace,
-          /*logits=*/logits.data_ptr<c10::Half>(),
-          /*targets=*/targets.data_ptr<int>(),
-          /*srcLengths=*/logit_lengths.data_ptr<int>(),
-          /*tgtLengths=*/target_lengths.data_ptr<int>(),
-          /*costs=*/costs.data_ptr<c10::Half>(),
-          /*gradients=*/gradients->data_ptr<c10::Half>());
+          /*logits=*/reinterpret_cast<c10::Half*>(logits.data_ptr()),
+          /*targets=*/reinterpret_cast<int*>(targets.data_ptr()),
+          /*srcLengths=*/reinterpret_cast<int*>(logit_lengths.data_ptr()),
+          /*tgtLengths=*/reinterpret_cast<int*>(target_lengths.data_ptr()),
+          /*costs=*/reinterpret_cast<c10::Half*>(costs.data_ptr()),
+          /*gradients=*/reinterpret_cast<c10::Half*>(gradients.data_ptr()));
       break;
     }
     default: {
-      break;
+      STD_TORCH_CHECK(false, "unreachable");
     }
   };
 
   return std::make_tuple(costs, gradients);
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("rnnt_loss", &compute);
+void boxed_rnnt_loss(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
+  STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
+  std::tuple<Tensor, Tensor> res = compute(
+      /*logits*/to<Tensor>(stack[0]),
+      /*targets*/to<Tensor>(stack[1]),
+      /*logit_lengths*/to<Tensor>(stack[2]),
+      /*target_lengths*/to<Tensor>(stack[3]),
+      /*blank*/to<int64_t>(stack[4]),
+      /*clamp*/to<double>(stack[5]),
+      /*fused_log_softmax*/to<bool>(stack[6]));
+  stack[0] = from(std::get<0>(res));
+  stack[1] = from(std::get<1>(res));
+}
 
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
 }
 
 } // namespace cpu
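
One consequence of the migration visible above: torch::stable::Tensor exposes only an untyped data_ptr(), so every pointer handed to Compute goes through a reinterpret_cast, and the dtype checks must run before the cast. A small helper could centralize that pattern; a sketch under the same headers (typed_ptr is a hypothetical name, not part of the stable API):

#include <torch/csrc/stable/tensor.h>

// Sketch only: the caller must ensure t.scalar_type() matches T, as the
// STD_TORCH_CHECK dtype guards above do before each cast.
template <typename T>
T* typed_ptr(const torch::stable::Tensor& t) {
  return reinterpret_cast<T*>(t.data_ptr());
}

// With this helper, the Float case above would read:
//   /*logits=*/typed_ptr<float>(logits),
//   /*targets=*/typed_ptr<int>(targets),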