Commit d4dd7bd

Convert cpu/compute_alphas to stable API
1 parent 9b9dc25 commit d4dd7bd

File tree

src/libtorchaudio/rnnt/compute_alphas.cpp
src/libtorchaudio/rnnt/cpu/compute_alphas.cpp

2 files changed: +97 -39 lines changed

src/libtorchaudio/rnnt/compute_alphas.cpp
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 #include <torch/script.h>
+#include <torch/csrc/stable/library.h>
 
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
   m.def(
       "rnnt_loss_alphas(Tensor logits,"
       "Tensor targets,"

src/libtorchaudio/rnnt/cpu/compute_alphas.cpp
Lines changed: 95 additions & 38 deletions

@@ -1,68 +1,125 @@
 #include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
 #include <torch/script.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/stable/library.h>
+
+// TODO:
+// Are the StableIValue AtenTensorHandles reference counted at all?
+// Why do we call release() on returned arguments?
 
 namespace torchaudio {
 namespace rnnt {
 namespace cpu {
 
-torch::Tensor compute_alphas(
-    const torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
+using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
+
+RAIIATH compute_alphas(
+    const RAIIATH logits,
+    const RAIIATH targets,
+    const RAIIATH logit_lengths,
+    const RAIIATH target_lengths,
     int64_t blank,
     double clamp) {
   Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
+  int64_t tmp;
+  aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
+  options.batchSize_ = (int)tmp;
+  aoti_torch_get_size(target_lengths.get(), 0, &tmp);
+  options.nHypos_ = (int)tmp;
+  options.nHypos_ /= options.batchSize_;
+  aoti_torch_get_size(logits.get(), 1, &tmp);
+  options.maxSrcLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 2, &tmp);
+  options.maxTgtLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 3, &tmp);
+  options.numTargets_ = (int)tmp;
   options.blank_ = blank;
   options.clamp_ = clamp;
 
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
+  // TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
   options.device_ = CPU;
 
-  torch::Tensor alphas = torch::zeros(
-      {options.batchSize_ * options.nHypos_,
-       options.maxSrcLen_,
-       options.maxTgtLen_},
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
 
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
+  int32_t logits_device;
+  aoti_torch_get_device_type(logits.get(), &logits_device);
+  int32_t logits_device_index;
+  aoti_torch_get_device_index(logits.get(), &logits_device_index);
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+
+  int64_t param_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_};
+  int64_t param_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};
+
+  AtenTensorHandle alphas;
+  aoti_torch_empty_strided(3, param_sizes, param_strides, logits_dtype, logits_device, logits_device_index, &alphas);
+  aoti_torch_zero_(alphas);
 
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  AtenTensorHandle int_workspace;
+  int64_t sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
+  int64_t strides[1] = {1};
+  aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
+
+  AtenTensorHandle float_workspace;
+  aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
+
+  int64_t float_numel;
+  aoti_torch_get_numel(float_workspace, &float_numel);
+  void *int_workspace_ptr;
+  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
+  void *float_workspace_ptr;
+  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
+  int64_t int_numel;
+  aoti_torch_get_numel(int_workspace, &int_numel);
 
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
-      /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
-      /*int_size=*/int_workspace.numel());
+      /*dtype_data=*/(float*)float_workspace_ptr,
+      /*dtype_size=*/float_numel,
+      /*int_data=*/(int*)int_workspace_ptr,
+      /*int_size=*/int_numel);
+
+  void *logit_ptr;
+  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
+
+  void *target_ptr;
+  aoti_torch_get_data_ptr(targets.get(), &target_ptr);
+
+  void *logit_len_ptr;
+  aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
+
+  void *target_len_ptr;
+  aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
+
+  void *alpha_ptr;
+  aoti_torch_get_data_ptr(alphas, &alpha_ptr);
 
   // Only support float, this is mainly to enable easy
   // unit-testing
   ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
       /*workspace=*/workspace,
-      /*logits=*/logits.data_ptr<float>(),
-      /*targets=*/targets.data_ptr<int>(),
-      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-      /*target_lengths=*/target_lengths.data_ptr<int>(),
-      /*alphas=*/alphas.data_ptr<float>());
-  return alphas;
+      /*logits=*/(float*)logit_ptr,
+      /*targets=*/(int*)target_ptr,
+      /*logit_lengths=*/(int*)logit_len_ptr,
+      /*target_lengths=*/(int*)target_len_ptr,
+      /*alphas=*/(float*)alpha_ptr);
+  return RAIIATH(alphas);
+}
+
+void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  RAIIATH t1(to<AtenTensorHandle>(stack[0]));
+  RAIIATH t2(to<AtenTensorHandle>(stack[1]));
+  RAIIATH t3(to<AtenTensorHandle>(stack[2]));
+  RAIIATH t4(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  double clamp = to<double>(stack[5]);
+  RAIIATH result = compute_alphas(std::move(t1), std::move(t2), std::move(t3), std::move(t4),
+      blank, clamp);
+  stack[0] = from(result.release());
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
-  m.impl("rnnt_loss_alphas", &compute_alphas);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("rnnt_loss_alphas", &boxed_compute_alphas);
 }
 
 } // namespace cpu
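
As a rough sketch of the boxed calling convention introduced above (illustrative only: the op my_ops::demo_op and its schema are placeholders pairing with the fragment sketch after the first file, and only shim calls that already appear in this diff are used): arguments arrive as StableIValues on the stack, tensor handles are wrapped in RAIIAtenTensorHandle so they are freed when the wrapper goes out of scope, and the owning output handle is released back into stack[0] because ownership passes to the caller.

#include <torch/csrc/inductor/aoti_runtime/utils.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/library.h>

using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

// Hypothetical boxed kernel following the same shape as boxed_compute_alphas.
void boxed_demo_op(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  // Take ownership of the input handle and unpack the scalar argument.
  RAIIATH x(to<AtenTensorHandle>(stack[0]));
  int64_t len = to<int64_t>(stack[1]);

  // Query dtype/device of the input through the stable C shim.
  int32_t dtype, device, device_index;
  aoti_torch_get_dtype(x.get(), &dtype);
  aoti_torch_get_device_type(x.get(), &device);
  aoti_torch_get_device_index(x.get(), &device_index);

  // Allocate a zeroed 1-D output of length len with the input's dtype/device,
  // the way compute_alphas allocates the alphas tensor.
  int64_t sizes[1] = {len};
  int64_t strides[1] = {1};
  AtenTensorHandle out;
  aoti_torch_empty_strided(1, sizes, strides, dtype, device, device_index, &out);
  aoti_torch_zero_(out);

  // (Real work would go through aoti_torch_get_data_ptr on out here.)

  // Hand the owning result handle back to the caller on the stack.
  RAIIATH result(out);
  stack[0] = from(result.release());
}

// Registration mirrors the STABLE_TORCH_LIBRARY_IMPL block above.
STABLE_TORCH_LIBRARY_IMPL(my_ops, CPU, m) {
  m.impl("demo_op", &boxed_demo_op);
}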
