Skip to content

Commit 696202e

Browse files
committed
Use stable ABI for compute_betas
1 parent 0ff3b57 commit 696202e

File tree

4 files changed

+102
-47
lines changed

4 files changed

+102
-47
lines changed

src/libtorchaudio/rnnt/compute_alphas.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include <torch/script.h>
21
#include <torch/csrc/stable/library.h>
32

43
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {

src/libtorchaudio/rnnt/compute_betas.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#include <torch/script.h>
1+
#include <torch/csrc/stable/library.h>
22

3-
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
3+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
44
m.def(
55
"rnnt_loss_betas(Tensor logits,"
66
"Tensor targets,"

src/libtorchaudio/rnnt/cpu/compute_alphas.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
2-
#include <torch/script.h>
32
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
43
#include <torch/csrc/inductor/aoti_runtime/utils.h>
54
#include <torch/csrc/stable/library.h>
@@ -63,7 +62,7 @@ RAIIATH compute_alphas(
6362
aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
6463

6564
AtenTensorHandle float_workspace;
66-
aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &int_workspace);
65+
aoti_torch_empty_strided(1, sizes, strides, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
6766

6867
int64_t float_numel;
6968
aoti_torch_get_numel(float_workspace, &float_numel);

src/libtorchaudio/rnnt/cpu/compute_betas.cpp

Lines changed: 99 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,130 @@
11
#include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
22
#include <torch/script.h>
3+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
4+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
5+
#include <torch/csrc/stable/library.h>
36

47
namespace torchaudio {
58
namespace rnnt {
69
namespace cpu {
710

8-
torch::Tensor compute_betas(
9-
const torch::Tensor& logits,
10-
const torch::Tensor& targets,
11-
const torch::Tensor& logit_lengths,
12-
const torch::Tensor& target_lengths,
11+
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
12+
13+
RAIIATH compute_betas(
14+
const RAIIATH logits,
15+
const RAIIATH targets,
16+
const RAIIATH logit_lengths,
17+
const RAIIATH target_lengths,
1318
int64_t blank,
1419
double clamp) {
1520
Options options;
16-
options.batchSize_ = logit_lengths.size(0);
17-
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
18-
options.maxSrcLen_ = logits.size(1);
19-
options.maxTgtLen_ = logits.size(2);
20-
options.numTargets_ = logits.size(3);
21+
int64_t tmp;
22+
aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
23+
options.batchSize_ = (int)tmp;
24+
aoti_torch_get_size(target_lengths.get(), 0, &tmp);
25+
options.nHypos_ = (int)tmp;
26+
options.nHypos_ /= options.batchSize_;
27+
aoti_torch_get_size(logits.get(), 1, &tmp);
28+
options.maxSrcLen_ = (int)tmp;
29+
aoti_torch_get_size(logits.get(), 2, &tmp);
30+
options.maxTgtLen_ = (int)tmp;
31+
aoti_torch_get_size(logits.get(), 3, &tmp);
32+
options.numTargets_ = (int)tmp;
2133
options.blank_ = blank;
2234
options.clamp_ = clamp;
2335

24-
TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
36+
int32_t logits_device_type;
37+
aoti_torch_get_device_type(logits.get(), &logits_device_type);
38+
AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cpu());
39+
2540
options.device_ = CPU;
2641

27-
torch::Tensor costs = torch::empty(
28-
target_lengths.size(0),
29-
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
42+
int32_t logits_device;
43+
aoti_torch_get_device_type(logits.get(), &logits_device);
44+
int32_t logits_device_index;
45+
aoti_torch_get_device_index(logits.get(), &logits_device_index);
46+
int32_t logits_dtype;
47+
aoti_torch_get_dtype(logits.get(), &logits_dtype);
48+
49+
int64_t cost_sizes[1] = {options.batchSize_};
50+
int64_t stride1[1] = {1};
51+
AtenTensorHandle costs;
52+
aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs);
53+
54+
int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_};
55+
int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};
56+
AtenTensorHandle betas;
57+
aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas);
3058

31-
torch::Tensor betas = torch::zeros(
32-
{options.batchSize_ * options.nHypos_,
33-
options.maxSrcLen_,
34-
options.maxTgtLen_},
35-
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
59+
AtenTensorHandle int_workspace;
60+
int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
61+
aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
3662

37-
torch::Tensor int_workspace = torch::empty(
38-
IntWorkspace::ComputeSizeFromOptions(options),
39-
torch::TensorOptions()
40-
.device(logits.device())
41-
.dtype(torch::ScalarType::Int));
63+
AtenTensorHandle float_workspace;
64+
aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
4265

43-
torch::Tensor float_workspace = torch::empty(
44-
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
45-
torch::TensorOptions()
46-
.device(logits.device())
47-
.dtype(torch::ScalarType::Float));
66+
int64_t float_numel;
67+
aoti_torch_get_numel(float_workspace, &float_numel);
68+
void *int_workspace_ptr;
69+
aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
70+
void *float_workspace_ptr;
71+
aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
72+
int64_t int_numel;
73+
aoti_torch_get_numel(int_workspace, &int_numel);
4874

4975
Workspace<float> workspace(
5076
/*options=*/options,
51-
/*dtype_data=*/float_workspace.data_ptr<float>(),
52-
/*dtype_size=*/float_workspace.numel(),
53-
/*int_data=*/int_workspace.data_ptr<int>(),
54-
/*int_size=*/int_workspace.numel());
77+
/*dtype_data=*/(float*)float_workspace_ptr,
78+
/*dtype_size=*/float_numel,
79+
/*int_data=*/(int*)int_workspace_ptr,
80+
/*int_size=*/int_numel);
81+
82+
void *logit_ptr;
83+
aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
84+
85+
void *target_ptr;
86+
aoti_torch_get_data_ptr(targets.get(), &target_ptr);
87+
88+
void *logit_len_ptr;
89+
aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
90+
91+
void *target_len_ptr;
92+
aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
93+
94+
void *beta_ptr;
95+
aoti_torch_get_data_ptr(betas, &beta_ptr);
96+
97+
void *cost_ptr;
98+
aoti_torch_get_data_ptr(costs, &cost_ptr);
5599

56100
// Only support float, this is mainly to enable easy
57101
// unit-testing
58102
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
59103
/*workspace=*/workspace,
60-
/*logits=*/logits.data_ptr<float>(),
61-
/*targets=*/targets.data_ptr<int>(),
62-
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
63-
/*target_lengths=*/target_lengths.data_ptr<int>(),
64-
/*costs=*/costs.data_ptr<float>(),
65-
/*betas=*/betas.data_ptr<float>());
66-
return betas;
104+
/*logits=*/(float*)logit_ptr,
105+
/*targets=*/(int*)target_ptr,
106+
/*logit_lengths=*/(int*)logit_len_ptr,
107+
/*target_lengths=*/(int*)target_len_ptr,
108+
/*costs=*/(float*)cost_ptr,
109+
/*betas=*/(float*)beta_ptr);
110+
return RAIIATH(betas);
111+
}
112+
113+
114+
void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
115+
RAIIATH t1(to<AtenTensorHandle>(stack[0]));
116+
RAIIATH t2(to<AtenTensorHandle>(stack[1]));
117+
RAIIATH t3(to<AtenTensorHandle>(stack[2]));
118+
RAIIATH t4(to<AtenTensorHandle>(stack[3]));
119+
int64_t blank = to<int64_t>(stack[4]);
120+
double clamp = to<double>(stack[5]);
121+
RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4),
122+
blank, clamp);
123+
stack[0] = from(result.release());
67124
}
68125

69-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
70-
m.impl("rnnt_loss_betas", &compute_betas);
126+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
127+
m.impl("rnnt_loss_betas", &boxed_compute_betas);
71128
}
72129

73130
} // namespace cpu

0 commit comments

Comments
 (0)