Skip to content

Commit 32c80da

Browse files
committed
Use stable ABI for cuda version of compute_alphas
1 parent 696202e commit 32c80da

File tree

1 file changed

+95
-41
lines changed

1 file changed

+95
-41
lines changed

src/libtorchaudio/rnnt/gpu/compute_alphas.cu

Lines changed: 95 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,125 @@
11
#include <c10/cuda/CUDAStream.h>
22
#include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
3-
#include <torch/types.h>
3+
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
4+
#include <torch/csrc/inductor/aoti_runtime/utils.h>
5+
#include <torch/csrc/stable/library.h>
46

57
namespace torchaudio {
68
namespace rnnt {
79
namespace gpu {
810

9-
torch::Tensor compute_alphas(
10-
const torch::Tensor& logits,
11-
const torch::Tensor& targets,
12-
const torch::Tensor& logit_lengths,
13-
const torch::Tensor& target_lengths,
11+
using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;

// Computes the RNN-T forward variables (alphas) on CUDA through the stable
// AOTI C-shim ABI (no direct libtorch C++ ABI dependence).
//
// Args:
//   logits:         float CUDA tensor; sizes read as
//                   (dim0, maxSrcLen, maxTgtLen, numTargets) — dim0 presumably
//                   batch * nHypos, TODO confirm against caller.
//   targets:        target-token tensor (data consumed as int*).
//   logit_lengths:  per-sequence source lengths (data consumed as int*);
//                   size(0) defines batchSize_.
//   target_lengths: per-hypothesis target lengths (data consumed as int*);
//                   size(0) / batchSize_ defines nHypos_.
//   blank:          blank-label index.
//   clamp:          clamp value forwarded to Options.
//
// Returns: owned handle to a zero-initialized alphas tensor of shape
//   (batchSize_ * nHypos_, maxSrcLen_, maxTgtLen_), dtype/device of `logits`,
//   filled by the ComputeAlphas kernel.
//
// NOTE(review): return statuses of the aoti_torch_* shim calls are discarded,
// matching the committed code; consider checking them.
RAIIATH compute_alphas(
    const RAIIATH logits,
    const RAIIATH targets,
    const RAIIATH logit_lengths,
    const RAIIATH target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  int64_t tmp;
  aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
  options.batchSize_ = (int)tmp;
  aoti_torch_get_size(target_lengths.get(), 0, &tmp);
  options.nHypos_ = (int)tmp;
  options.nHypos_ /= options.batchSize_;
  aoti_torch_get_size(logits.get(), 1, &tmp);
  options.maxSrcLen_ = (int)tmp;
  aoti_torch_get_size(logits.get(), 2, &tmp);
  options.maxTgtLen_ = (int)tmp;
  aoti_torch_get_size(logits.get(), 3, &tmp);
  options.numTargets_ = (int)tmp;
  options.blank_ = blank;
  options.clamp_ = clamp;

  int32_t logits_device_type;
  aoti_torch_get_device_type(logits.get(), &logits_device_type);
  AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda());

  int32_t logits_device_index;
  aoti_torch_get_device_index(logits.get(), &logits_device_index);
  int32_t logits_dtype;
  aoti_torch_get_dtype(logits.get(), &logits_dtype);

  aoti_torch_get_current_cuda_stream(logits_device_index, &options.stream_);
  // FIX: the committed code called cudaSetDevice(logits_device) — passing the
  // device *type* (and missing the semicolon). cudaSetDevice takes the
  // ordinal device index.
  cudaSetDevice(logits_device_index);
  options.device_ = GPU;

  // Contiguous (row-major) layout for the alphas output.
  int64_t param_sizes[3] = {
      options.batchSize_ * options.nHypos_,
      options.maxSrcLen_,
      options.maxTgtLen_};
  int64_t param_strides[3] = {
      options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};

  AtenTensorHandle alphas;
  aoti_torch_empty_strided(
      3,
      param_sizes,
      param_strides,
      logits_dtype,
      logits_device_type,
      logits_device_index,
      &alphas);
  aoti_torch_zero_(alphas);

  int64_t strides[1] = {1};

  AtenTensorHandle int_workspace;
  int64_t int_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
  aoti_torch_empty_strided(
      1,
      int_sizes,
      strides,
      aoti_torch_dtype_int32(),
      logits_device_type,
      logits_device_index,
      &int_workspace);

  // FIX: the committed code reused the int-workspace size for the float
  // workspace; it must be sized by DtypeWorkspace<float> (the two sizes
  // differ), as in the pre-refactor implementation.
  AtenTensorHandle float_workspace;
  int64_t float_sizes[1] = {DtypeWorkspace<float>::ComputeSizeFromOptions(options)};
  aoti_torch_empty_strided(
      1,
      float_sizes,
      strides,
      aoti_torch_dtype_float32(),
      logits_device_type,
      logits_device_index,
      &float_workspace);

  int64_t float_numel;
  aoti_torch_get_numel(float_workspace, &float_numel);
  int64_t int_numel;
  aoti_torch_get_numel(int_workspace, &int_numel);
  void* int_workspace_ptr;
  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
  void* float_workspace_ptr;
  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/(float*)float_workspace_ptr,
      /*dtype_size=*/float_numel,
      /*int_data=*/(int*)int_workspace_ptr,
      /*int_size=*/int_numel);

  void* logit_ptr;
  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);

  void* target_ptr;
  aoti_torch_get_data_ptr(targets.get(), &target_ptr);

  void* logit_len_ptr;
  aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);

  void* target_len_ptr;
  aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);

  void* alpha_ptr;
  aoti_torch_get_data_ptr(alphas, &alpha_ptr);

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/(float*)logit_ptr,
      /*targets=*/(int*)target_ptr,
      /*logit_lengths=*/(int*)logit_len_ptr,
      /*target_lengths=*/(int*)target_len_ptr,
      /*alphas=*/(float*)alpha_ptr);
  return RAIIATH(alphas);
}
108+
109+
// Boxed-calling-convention adapter for the stable-ABI dispatcher: unpacks
// the six arguments from `stack`, invokes compute_alphas, and writes the
// single result back into stack[0].
void boxed_compute_alphas(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
  RAIIATH logits(to<AtenTensorHandle>(stack[0]));
  RAIIATH targets(to<AtenTensorHandle>(stack[1]));
  RAIIATH logit_lengths(to<AtenTensorHandle>(stack[2]));
  RAIIATH target_lengths(to<AtenTensorHandle>(stack[3]));
  const int64_t blank = to<int64_t>(stack[4]);
  const double clamp = to<double>(stack[5]);

  RAIIATH alphas = compute_alphas(
      std::move(logits),
      std::move(targets),
      std::move(logit_lengths),
      std::move(target_lengths),
      blank,
      clamp);

  // Ownership of the result handle transfers to the stack slot.
  stack[0] = from(alphas.release());
}
66120

67-
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
68-
m.impl("rnnt_loss_alphas", &compute_alphas);
121+
// Register the CUDA kernel for torchaudio::rnnt_loss_alphas through the
// stable-ABI boxed registration API (the op schema is declared elsewhere).
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &boxed_compute_alphas);
}
70124

71125
} // namespace gpu

0 commit comments

Comments
 (0)