Skip to content

Commit 5490cd3

Browse files
committed
Use stable ABI for cuda version of compute_betas
1 parent 32c80da commit 5490cd3

File tree

1 file changed

+106
-49
lines changed

1 file changed

+106
-49
lines changed

src/libtorchaudio/rnnt/gpu/compute_betas.cu

Lines changed: 106 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,133 @@
11
#include <c10/cuda/CUDAStream.h>
22
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
3-
#include <torch/types.h>
3+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
4+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
5+
#include <torch/csrc/stable/library.h>
46

57
namespace torchaudio {
68
namespace rnnt {
79
namespace gpu {
810

9-
torch::Tensor compute_betas(
10-
const torch::Tensor& logits,
11-
const torch::Tensor& targets,
12-
const torch::Tensor& logit_lengths,
13-
const torch::Tensor& target_lengths,
11+
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
12+
13+
14+
RAIIATH compute_betas(
15+
const RAIIATH logits,
16+
const RAIIATH targets,
17+
const RAIIATH logit_lengths,
18+
const RAIIATH target_lengths,
1419
int64_t blank,
1520
double clamp) {
16-
Options options;
17-
options.batchSize_ = logit_lengths.size(0);
18-
options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
19-
options.maxSrcLen_ = logits.size(1);
20-
options.maxTgtLen_ = logits.size(2);
21-
options.numTargets_ = logits.size(3);
22-
options.blank_ = blank;
23-
options.clamp_ = clamp;
24-
25-
TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
26-
options.stream_ = at::cuda::getCurrentCUDAStream();
27-
cudaSetDevice(logits.get_device());
21+
Options options;
22+
int64_t tmp;
23+
aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
24+
options.batchSize_ = (int)tmp;
25+
aoti_torch_get_size(target_lengths.get(), 0, &tmp);
26+
options.nHypos_ = (int)tmp;
27+
options.nHypos_ /= options.batchSize_;
28+
aoti_torch_get_size(logits.get(), 1, &tmp);
29+
options.maxSrcLen_ = (int)tmp;
30+
aoti_torch_get_size(logits.get(), 2, &tmp);
31+
options.maxTgtLen_ = (int)tmp;
32+
aoti_torch_get_size(logits.get(), 3, &tmp);
33+
options.numTargets_ = (int)tmp;
34+
options.blank_ = blank;
35+
options.clamp_ = clamp;
36+
37+
int32_t logits_device_type;
38+
aoti_torch_get_device_type(logits.get(), &logits_device_type);
39+
AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda());
40+
41+
42+
int32_t logits_device;
43+
aoti_torch_get_device_type(logits.get(), &logits_device);
44+
int32_t logits_device_index;
45+
aoti_torch_get_device_index(logits.get(), &logits_device_index);
46+
int32_t logits_dtype;
47+
aoti_torch_get_dtype(logits.get(), &logits_dtype);
48+
49+
aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_);
50+
cudaSetDevice(logits_device)
2851
options.device_ = GPU;
2952

30-
torch::Tensor costs = torch::empty(
31-
target_lengths.size(0),
32-
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
53+
int64_t cost_sizes[1] = {options.batchSize_};
54+
int64_t stride1[1] = {1};
55+
AtenTensorHandle costs;
56+
aoti_torch_empty_strided(1, cost_sizes, stride1, logits_dtype, logits_device, logits_device_index, &costs);
3357

34-
torch::Tensor betas = torch::zeros(
35-
{options.batchSize_ * options.nHypos_,
36-
options.maxSrcLen_,
37-
options.maxTgtLen_},
38-
torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
58+
int64_t betas_sizes[3] = {options.batchSize_ * options.nHypos_, options.maxSrcLen_, options.maxTgtLen_};
59+
int64_t betas_strides[3] = {options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};
60+
AtenTensorHandle betas;
61+
aoti_torch_empty_strided(3, betas_sizes, betas_strides, logits_dtype, logits_device, logits_device_index, &betas);
3962

40-
torch::Tensor int_workspace = torch::empty(
41-
IntWorkspace::ComputeSizeFromOptions(options),
42-
torch::TensorOptions()
43-
.device(logits.device())
44-
.dtype(torch::ScalarType::Int));
63+
AtenTensorHandle int_workspace;
64+
int64_t w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
65+
aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_int32(), logits_device, logits_device_index, &int_workspace);
4566

46-
torch::Tensor float_workspace = torch::empty(
47-
DtypeWorkspace<float>::ComputeSizeFromOptions(options),
48-
torch::TensorOptions()
49-
.device(logits.device())
50-
.dtype(torch::ScalarType::Float));
67+
AtenTensorHandle float_workspace;
68+
aoti_torch_empty_strided(1, w_sizes, stride1, aoti_torch_dtype_float32(), logits_device, logits_device_index, &float_workspace);
69+
70+
int64_t float_numel;
71+
aoti_torch_get_numel(float_workspace, &float_numel);
72+
void *int_workspace_ptr;
73+
aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
74+
void *float_workspace_ptr;
75+
aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
76+
int64_t int_numel;
77+
aoti_torch_get_numel(int_workspace, &int_numel);
5178

5279
Workspace<float> workspace(
5380
/*options=*/options,
54-
/*dtype_data=*/float_workspace.data_ptr<float>(),
55-
/*dtype_size=*/float_workspace.numel(),
56-
/*int_data=*/int_workspace.data_ptr<int>(),
57-
/*int_size=*/int_workspace.numel());
81+
/*dtype_data=*/(float*)float_workspace_ptr,
82+
/*dtype_size=*/float_numel,
83+
/*int_data=*/(int*)int_workspace_ptr,
84+
/*int_size=*/int_numel);
85+
86+
void *logit_ptr;
87+
aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
88+
89+
void *target_ptr;
90+
aoti_torch_get_data_ptr(targets.get(), &target_ptr);
91+
92+
void *logit_len_ptr;
93+
aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
94+
95+
void *target_len_ptr;
96+
aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
97+
98+
void *beta_ptr;
99+
aoti_torch_get_data_ptr(betas, &beta_ptr);
100+
101+
void *cost_ptr;
102+
aoti_torch_get_data_ptr(costs, &cost_ptr);
58103

59104
// Only support float, this is mainly to enable easy
60105
// unit-testing
61106
ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
62107
/*workspace=*/workspace,
63-
/*logits=*/logits.data_ptr<float>(),
64-
/*targets=*/targets.data_ptr<int>(),
65-
/*logit_lengths=*/logit_lengths.data_ptr<int>(),
66-
/*target_lengths=*/target_lengths.data_ptr<int>(),
67-
/*costs=*/costs.data_ptr<float>(),
68-
/*betas=*/betas.data_ptr<float>());
69-
return betas;
108+
/*logits=*/(float*)logit_ptr,
109+
/*targets=*/(int*)target_ptr,
110+
/*logit_lengths=*/(int*)logit_len_ptr,
111+
/*target_lengths=*/(int*)target_len_ptr,
112+
/*costs=*/(float*)cost_ptr,
113+
/*betas=*/(float*)beta_ptr);
114+
return RAIIATH(betas);
115+
}
116+
117+
void boxed_compute_betas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
118+
RAIIATH t1(to<AtenTensorHandle>(stack[0]));
119+
RAIIATH t2(to<AtenTensorHandle>(stack[1]));
120+
RAIIATH t3(to<AtenTensorHandle>(stack[2]));
121+
RAIIATH t4(to<AtenTensorHandle>(stack[3]));
122+
int64_t blank = to<int64_t>(stack[4]);
123+
double clamp = to<double>(stack[5]);
124+
RAIIATH result = compute_betas(std::move(t1), std::move(t2), std::move(t3), std::move(t4),
125+
blank, clamp);
126+
stack[0] = from(result.release());
70127
}
71128

72-
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
73-
m.impl("rnnt_loss_betas", &compute_betas);
129+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
130+
m.impl("rnnt_loss_betas", &boxed_compute_betas);
74131
}
75132

76133
} // namespace gpu

0 commit comments

Comments
 (0)