Skip to content

[STABLE ABI PORT] Use stable ABI for RNNT #3977

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
77d4421
Port autograd code for rnnt
samanklesaria Jul 11, 2025
725c74e
Correct rnnt calling arguments
samanklesaria Jul 11, 2025
9717651
Disable torchscript checks
samanklesaria Jul 11, 2025
2b88250
Restrict disabling of torchscript tests
samanklesaria Jul 11, 2025
116de6f
Remove leftover line
samanklesaria Jul 11, 2025
003b3a9
Remove unnecessary backward code
samanklesaria Jul 14, 2025
7727ad7
Move rnnt_loss_forward to compute.cpp
samanklesaria Jul 14, 2025
9b9dc25
Remove autograd rnnt in cmakelists
samanklesaria Jul 14, 2025
d4dd7bd
Convert cpu/compute_alphas to stable API
samanklesaria Jul 15, 2025
0ff3b57
Add back device type check
samanklesaria Jul 15, 2025
696202e
Use stable ABI for compute_betas
samanklesaria Jul 15, 2025
32c80da
Use stable ABI for cuda version of compute_alphas
samanklesaria Jul 15, 2025
5490cd3
Use stable ABI for cuda version of compute_betas
samanklesaria Jul 15, 2025
89d480b
Add missing semicolon
samanklesaria Jul 17, 2025
7f11d1d
Cast to void pointer pointer
samanklesaria Jul 17, 2025
cc592e0
Use stable ABI for compute
samanklesaria Jul 17, 2025
180a393
Attempt to fix stable ABI calls
samanklesaria Jul 17, 2025
577ff0c
Use stable Tensor interface
samanklesaria Jul 31, 2025
926ca7d
Correct use of stable Tensor
samanklesaria Jul 31, 2025
9f75cdf
WIP
samanklesaria Jul 31, 2025
526e74d
Remove mytest
samanklesaria Jul 31, 2025
1d5f9ef
Fix float size calculation
samanklesaria Jul 31, 2025
bad7309
Remove debugging printfs
samanklesaria Jul 31, 2025
af1c91d
Merge branch 'main' into stable_rnnt
samanklesaria Jul 31, 2025
317b964
Fix size bug for rnnt gpu
samanklesaria Jul 31, 2025
8bdac23
Remove alphas and betas for rnnt
samanklesaria Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/libtorchaudio/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,13 @@ if(BUILD_RNNT)
list(
APPEND
sources
rnnt/cpu/compute_alphas.cpp
rnnt/cpu/compute_betas.cpp
rnnt/cpu/compute.cpp
rnnt/compute_alphas.cpp
rnnt/compute_betas.cpp
rnnt/compute.cpp
)
if (USE_CUDA)
list(
APPEND
sources
rnnt/gpu/compute_alphas.cu
rnnt/gpu/compute_betas.cu
rnnt/gpu/compute.cu
)
endif()
Expand Down
28 changes: 4 additions & 24 deletions src/libtorchaudio/rnnt/compute.cpp
Original file line number Diff line number Diff line change
@@ -1,34 +1,14 @@
#include <libtorchaudio/rnnt/compute.h>
#include <torch/csrc/stable/library.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
torch::Tensor& logits,
const torch::Tensor& targets,
const torch::Tensor& logit_lengths,
const torch::Tensor& target_lengths,
int64_t blank,
double clamp,
bool fused_log_softmax = true) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("torchaudio::rnnt_loss", "")
.typed<decltype(rnnt_loss)>();
return op.call(
logits,
targets,
logit_lengths,
target_lengths,
blank,
clamp,
fused_log_softmax);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"rnnt_loss(Tensor logits,"
"Tensor targets,"
"Tensor logit_lengths,"
"Tensor target_lengths,"
"int blank,"
"float clamp,"
"bool fused_log_softmax) -> (Tensor, Tensor?)");
m.def("torchaudio::rnnt_loss_forward", &rnnt_loss);
"bool fused_log_softmax) -> (Tensor, Tensor)");
}
12 changes: 0 additions & 12 deletions src/libtorchaudio/rnnt/compute.h

This file was deleted.

11 changes: 0 additions & 11 deletions src/libtorchaudio/rnnt/compute_alphas.cpp

This file was deleted.

11 changes: 0 additions & 11 deletions src/libtorchaudio/rnnt/compute_betas.cpp

This file was deleted.

Loading
Loading