 #include <c10/cuda/CUDAStream.h>
 #include <libtorchaudio/rnnt/gpu/gpu_transducer.h>
-#include <torch/types.h>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
+#include <torch/csrc/inductor/aoti_runtime/utils.h>
+#include <torch/csrc/stable/library.h>
 
 namespace torchaudio {
 namespace rnnt {
 namespace gpu {
 
-torch::Tensor compute_betas(
-    const torch::Tensor& logits,
-    const torch::Tensor& targets,
-    const torch::Tensor& logit_lengths,
-    const torch::Tensor& target_lengths,
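+// Owning RAII wrapper over the C shim's AtenTensorHandle; it frees the handle
+// on scope exit, so early returns cannot leak tensors.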
+using RAIIATH = torch::aot_inductor::RAIIAtenTensorHandle;
+
+RAIIATH compute_betas(
+    const RAIIATH logits,
+    const RAIIATH targets,
+    const RAIIATH logit_lengths,
+    const RAIIATH target_lengths,
     int64_t blank,
     double clamp) {
-  Options options;
-  options.batchSize_ = logit_lengths.size(0);
-  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
-  options.maxSrcLen_ = logits.size(1);
-  options.maxTgtLen_ = logits.size(2);
-  options.numTargets_ = logits.size(3);
-  options.blank_ = blank;
-  options.clamp_ = clamp;
-
-  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
-  options.stream_ = at::cuda::getCurrentCUDAStream();
-  cudaSetDevice(logits.get_device());
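+  // The shim has no sizes() accessor; each dimension is queried through an
+  // int64_t out-parameter and narrowed to the int fields Options expects.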
+  Options options;
+  int64_t tmp;
+  aoti_torch_get_size(logit_lengths.get(), 0, &tmp);
+  options.batchSize_ = (int)tmp;
+  aoti_torch_get_size(target_lengths.get(), 0, &tmp);
+  options.nHypos_ = (int)tmp / options.batchSize_;
+  aoti_torch_get_size(logits.get(), 1, &tmp);
+  options.maxSrcLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 2, &tmp);
+  options.maxTgtLen_ = (int)tmp;
+  aoti_torch_get_size(logits.get(), 3, &tmp);
+  options.numTargets_ = (int)tmp;
+  options.blank_ = blank;
+  options.clamp_ = clamp;
+
+  int32_t logits_device_type;
+  aoti_torch_get_device_type(logits.get(), &logits_device_type);
+  AOTI_TORCH_CHECK(logits_device_type == aoti_torch_device_type_cuda());
+
+  int32_t logits_device_index;
+  aoti_torch_get_device_index(logits.get(), &logits_device_index);
+  int32_t logits_dtype;
+  aoti_torch_get_dtype(logits.get(), &logits_dtype);
+
+  // The shim hands the stream back as a void*; cast it to cudaStream_t.
+  void* stream = nullptr;
+  aoti_torch_get_current_cuda_stream(logits_device_index, &stream);
+  options.stream_ = static_cast<cudaStream_t>(stream);
+  cudaSetDevice(logits_device_index);
   options.device_ = GPU;
 
-  torch::Tensor costs = torch::empty(
-      target_lengths.size(0),
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
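+  // costs is scratch here (only betas is returned); it needs one entry per
+  // hypothesis, i.e. target_lengths.size(0) == batchSize_ * nHypos_.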
+  int64_t cost_sizes[1] = {options.batchSize_ * options.nHypos_};
+  int64_t stride1[1] = {1};
+  AtenTensorHandle costs;
+  aoti_torch_empty_strided(
+      1, cost_sizes, stride1, logits_dtype, logits_device_type,
+      logits_device_index, &costs);
 
-  torch::Tensor betas = torch::zeros(
-      {options.batchSize_ * options.nHypos_,
-       options.maxSrcLen_,
-       options.maxTgtLen_},
-      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
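+  // Contiguous row-major strides, spelled out because the shim API takes
+  // explicit strides rather than a contiguous-by-default constructor.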
+  int64_t betas_sizes[3] = {
+      options.batchSize_ * options.nHypos_,
+      options.maxSrcLen_,
+      options.maxTgtLen_};
+  int64_t betas_strides[3] = {
+      options.maxSrcLen_ * options.maxTgtLen_, options.maxTgtLen_, 1};
+  AtenTensorHandle betas;
+  aoti_torch_empty_strided(
+      3, betas_sizes, betas_strides, logits_dtype, logits_device_type,
+      logits_device_index, &betas);
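+
+  // The original allocated betas with torch::zeros; empty_strided leaves the
+  // buffer uninitialized, so zero it explicitly to keep padded cells at zero.
+  void* beta_ptr;
+  aoti_torch_get_data_ptr(betas, &beta_ptr);
+  cudaMemsetAsync(
+      beta_ptr, 0,
+      betas_sizes[0] * betas_strides[0] * sizeof(float), options.stream_);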
 
-  torch::Tensor int_workspace = torch::empty(
-      IntWorkspace::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Int));
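+  // The int and float workspaces have different sizes, so each gets its own
+  // size array; ComputeSizeFromOptions returns an element count, not bytes.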
+  AtenTensorHandle int_workspace;
+  int64_t int_w_sizes[1] = {IntWorkspace::ComputeSizeFromOptions(options)};
+  aoti_torch_empty_strided(
+      1, int_w_sizes, stride1, aoti_torch_dtype_int32(), logits_device_type,
+      logits_device_index, &int_workspace);
 
-  torch::Tensor float_workspace = torch::empty(
-      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
-      torch::TensorOptions()
-          .device(logits.device())
-          .dtype(torch::ScalarType::Float));
+  AtenTensorHandle float_workspace;
+  int64_t float_w_sizes[1] = {
+      DtypeWorkspace<float>::ComputeSizeFromOptions(options)};
+  aoti_torch_empty_strided(
+      1, float_w_sizes, stride1, aoti_torch_dtype_float32(), logits_device_type,
+      logits_device_index, &float_workspace);
+
+  void* int_workspace_ptr;
+  aoti_torch_get_data_ptr(int_workspace, &int_workspace_ptr);
+  int64_t int_numel;
+  aoti_torch_get_numel(int_workspace, &int_numel);
+  void* float_workspace_ptr;
+  aoti_torch_get_data_ptr(float_workspace, &float_workspace_ptr);
+  int64_t float_numel;
+  aoti_torch_get_numel(float_workspace, &float_numel);
 
   Workspace<float> workspace(
       /*options=*/options,
-      /*dtype_data=*/float_workspace.data_ptr<float>(),
-      /*dtype_size=*/float_workspace.numel(),
-      /*int_data=*/int_workspace.data_ptr<int>(),
-      /*int_size=*/int_workspace.numel());
+      /*dtype_data=*/(float*)float_workspace_ptr,
+      /*dtype_size=*/float_numel,
+      /*int_data=*/(int*)int_workspace_ptr,
+      /*int_size=*/int_numel);
+
+  // Raw device pointers for the kernel launch; the tensor handles keep the
+  // underlying storage alive. beta_ptr was already fetched above.
+  void* logit_ptr;
+  aoti_torch_get_data_ptr(logits.get(), &logit_ptr);
+
+  void* target_ptr;
+  aoti_torch_get_data_ptr(targets.get(), &target_ptr);
+
+  void* logit_len_ptr;
+  aoti_torch_get_data_ptr(logit_lengths.get(), &logit_len_ptr);
+
+  void* target_len_ptr;
+  aoti_torch_get_data_ptr(target_lengths.get(), &target_len_ptr);
+
+  void* cost_ptr;
+  aoti_torch_get_data_ptr(costs, &cost_ptr);
 
   // Only support float, this is mainly to enable easy
   // unit-testing
   ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
       /*workspace=*/workspace,
-      /*logits=*/logits.data_ptr<float>(),
-      /*targets=*/targets.data_ptr<int>(),
-      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
-      /*target_lengths=*/target_lengths.data_ptr<int>(),
-      /*costs=*/costs.data_ptr<float>(),
-      /*betas=*/betas.data_ptr<float>());
-  return betas;
+      /*logits=*/(float*)logit_ptr,
+      /*targets=*/(int*)target_ptr,
+      /*logit_lengths=*/(int*)logit_len_ptr,
+      /*target_lengths=*/(int*)target_len_ptr,
+      /*costs=*/(float*)cost_ptr,
+      /*betas=*/(float*)beta_ptr);
+
+  // costs and the two workspaces are temporaries held as raw handles, so they
+  // must be freed by hand; ownership of betas transfers to the caller.
+  aoti_torch_delete_tensor_object(costs);
+  aoti_torch_delete_tensor_object(int_workspace);
+  aoti_torch_delete_tensor_object(float_workspace);
+  return RAIIATH(betas);
+}
+
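+// Boxed adapter for the stable-ABI dispatcher: arguments arrive as
+// StableIValues on the stack, and the single output is written back to
+// stack[0] as a released, caller-owned handle.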
+void boxed_compute_betas(
+    StableIValue* stack,
+    uint64_t num_args,
+    uint64_t num_outputs) {
+  RAIIATH t1(to<AtenTensorHandle>(stack[0]));
+  RAIIATH t2(to<AtenTensorHandle>(stack[1]));
+  RAIIATH t3(to<AtenTensorHandle>(stack[2]));
+  RAIIATH t4(to<AtenTensorHandle>(stack[3]));
+  int64_t blank = to<int64_t>(stack[4]);
+  double clamp = to<double>(stack[5]);
+  RAIIATH result = compute_betas(
+      std::move(t1), std::move(t2), std::move(t3), std::move(t4), blank, clamp);
+  stack[0] = from(result.release());
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
-  m.impl("rnnt_loss_betas", &compute_betas);
+STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
+  m.impl("rnnt_loss_betas", &boxed_compute_betas);
 }
 
 } // namespace gpu