Skip to content

Commit 2b6e8ee

Browse files
committed
Use TORCH_BOX in rnnt/cpu.
1 parent 18efb7b commit 2b6e8ee

File tree

2 files changed

+5
-28
lines changed

2 files changed

+5
-28
lines changed

src/libtorchaudio/rnnt/cpu/compute.cpp

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ using torch::stable::Tensor;
1212

1313
// Entry point into RNNT Loss
1414
std::tuple<Tensor, Tensor> compute(
15-
const Tensor& logits,
16-
const Tensor& targets,
17-
const Tensor& logit_lengths,
18-
const Tensor& target_lengths,
15+
Tensor logits,
16+
Tensor targets,
17+
Tensor logit_lengths,
18+
Tensor target_lengths,
1919
int64_t blank,
2020
double clamp,
2121
bool fused_log_softmax = true) {
@@ -147,26 +147,8 @@ std::tuple<Tensor, Tensor> compute(
147147
return std::make_tuple(costs, gradients);
148148
}
149149

150-
void boxed_rnnt_loss(
151-
StableIValue* stack,
152-
uint64_t num_args,
153-
uint64_t num_outputs) {
154-
STD_TORCH_CHECK(num_args == 7, "num_args must be 7");
155-
STD_TORCH_CHECK(num_outputs == 2, "num_outputs must be 2");
156-
std::tuple<Tensor, Tensor> res = compute(
157-
/*logits*/ torch::stable::detail::to<Tensor>(stack[0]),
158-
/*targets*/ torch::stable::detail::to<Tensor>(stack[1]),
159-
/*logit_lengths*/ torch::stable::detail::to<Tensor>(stack[2]),
160-
/*target_lengths*/ torch::stable::detail::to<Tensor>(stack[3]),
161-
/*blank*/ float(torch::stable::detail::to<int64_t>(stack[4])),
162-
/*clamp*/ torch::stable::detail::to<double>(stack[5]),
163-
/*fused_log_softmax*/ torch::stable::detail::to<bool>(stack[6]));
164-
stack[0] = torch::stable::detail::from(std::get<0>(res));
165-
stack[1] = torch::stable::detail::from(std::get<1>(res));
166-
}
167-
168150
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
169-
m.impl("rnnt_loss_forward", &boxed_rnnt_loss);
151+
m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
170152
}
171153

172154
} // namespace cpu

src/libtorchaudio/rnnt/gpu/compute.cu

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,6 @@ std::tuple<Tensor, Tensor> compute(
148148
return std::make_tuple(costs, gradients);
149149
}
150150

151-
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
152-
m.def(
153-
"rnnt_loss_forward(Tensor logits, Tensor targets, Tensor logit_lengths, Tensor target_lengths, int blank, double clamp, bool fused_log_softmax) -> (Tensor, Tensor)");
154-
}
155-
156151
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
157152
m.impl("rnnt_loss_forward", TORCH_BOX(&compute));
158153
}

0 commit comments

Comments
 (0)