|
| 1 | +#include <ATen/Functions.h> |
| 2 | +#include <torch/torch.h> |
| 3 | +#include <torch/extension.h> |
| 4 | +#include <vector> |
| 5 | +#include <cuda.h> |
| 6 | +#include <cuda_runtime.h> |
| 7 | +#include <vector> |
| 8 | +#include <c10/cuda/CUDAException.h> |
| 9 | +#include <ATen/cuda/CUDAContext.h> |
| 10 | +constexpr int kMaxNumTensors = 32; |
| 11 | +template <typename T> |
| 12 | +struct InputJaggedTensor { |
| 13 | + T* value_list[kMaxNumTensors]; |
| 14 | + int32_t* offsets_list[kMaxNumTensors]; |
| 15 | +}; |
| 16 | + |
| 17 | + |
| 18 | +template <typename T> |
| 19 | +__global__ void concat_2D_jagged_tensors_forward_kernel( |
| 20 | + const InputJaggedTensor<T> input_jagged_tensor, |
| 21 | + const int32_t num_tensors, |
| 22 | + const int32_t num_rows, |
| 23 | + const int32_t hidden_dim, |
| 24 | + T* merged_values, |
| 25 | + int* merged_offsets) { |
| 26 | + |
| 27 | + int row = blockIdx.x * blockDim.x + threadIdx.x; |
| 28 | + if (row >= num_rows) return; |
| 29 | + int out_idx = merged_offsets[row]; |
| 30 | + |
| 31 | + for (int t = 0; t < num_tensors; ++t) { |
| 32 | + const T* values = input_jagged_tensor.value_list[t]; |
| 33 | + const int32_t* offsets = input_jagged_tensor.offsets_list[t]; |
| 34 | + int start = offsets[row]; |
| 35 | + int end = offsets[row + 1]; |
| 36 | + |
| 37 | + for (int i = start; i < end; ++i) { |
| 38 | + for (int h = 0; h < hidden_dim; ++h) { |
| 39 | + merged_values[out_idx * hidden_dim + h] = values[i * hidden_dim + h]; |
| 40 | + } |
| 41 | + out_idx++; |
| 42 | + } |
| 43 | + } |
| 44 | +} |
| 45 | + |
| 46 | +__global__ void concat_1D_jagged_tensor_kernel( |
| 47 | + const float** values_list, |
| 48 | + const int** offsets_list, |
| 49 | + int num_tensor, |
| 50 | + int num_rows,//total_length |
| 51 | + float* merged_values, |
| 52 | + int* merged_offsets){ |
| 53 | + |
| 54 | + int row = blockIdx.x * blockDim.x + threadIdx.x; |
| 55 | + if (row >= num_rows) return; |
| 56 | + |
| 57 | + int out_idx = merged_offsets[row]; // data start from this row |
| 58 | + for(int i = 0; i < num_tensor; i++){ |
| 59 | + const float* values = values_list[i]; |
| 60 | + const int* offsets = offsets_list[i]; |
| 61 | + int st = offsets[row]; |
| 62 | + int end = offsets[row+1]; |
| 63 | + for(int j = st; j < end; j++){ |
| 64 | + merged_values[out_idx++] = values[j]; |
| 65 | + } |
| 66 | + } |
| 67 | +} |
| 68 | + |
| 69 | +void concat_2D_jagged_tensors_cuda_forward ( |
| 70 | + const std::vector<torch::Tensor>& values_list, |
| 71 | + const std::vector<torch::Tensor>& offsets_list, |
| 72 | + torch::Tensor merged_values, |
| 73 | + torch::Tensor merged_offsets){ |
| 74 | + |
| 75 | + int num_tensors = values_list.size(); |
| 76 | + int num_rows = offsets_list[0].size(0) - 1; |
| 77 | + int hidden_dim = values_list[0].size(-1); |
| 78 | + |
| 79 | + InputJaggedTensor<float> input_jagged_tensor; |
| 80 | + for (int i = 0; i < num_tensors; ++i) { |
| 81 | + input_jagged_tensor.value_list[i] = values_list[i].data_ptr<float>(); |
| 82 | + input_jagged_tensor.offsets_list[i] = offsets_list[i].data_ptr<int32_t>(); |
| 83 | + } |
| 84 | + |
| 85 | + int threads = 128; |
| 86 | + int blocks = (num_rows + threads - 1) / threads; |
| 87 | + |
| 88 | + assert(merged_values.is_contiguous()); |
| 89 | + |
| 90 | + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); |
| 91 | + |
| 92 | + concat_2D_jagged_tensors_forward_kernel<float><<<blocks, threads, 0, stream>>>( |
| 93 | + input_jagged_tensor, |
| 94 | + num_tensors, |
| 95 | + num_rows, |
| 96 | + hidden_dim, |
| 97 | + merged_values.data_ptr<float>(), |
| 98 | + merged_offsets.data_ptr<int>() |
| 99 | + ); |
| 100 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 101 | + |
| 102 | + return; |
| 103 | +} |
| 104 | + |
| 105 | +template <typename T> |
| 106 | +__global__ void concat_2D_jagged_tensors_backward_kernel( |
| 107 | + const InputJaggedTensor<T> grad_jagged_tensor, |
| 108 | + const int32_t num_tensors, |
| 109 | + const int32_t num_rows, |
| 110 | + const int32_t hidden_dim, |
| 111 | + const T* grad_output, |
| 112 | + int* merged_offsets) { |
| 113 | + |
| 114 | + int row = blockIdx.x * blockDim.x + threadIdx.x; |
| 115 | + if (row >= num_rows) return; |
| 116 | + int out_idx = merged_offsets[row]; |
| 117 | + |
| 118 | + for (int t = 0; t < num_tensors; ++t) { |
| 119 | + T* grad_values = grad_jagged_tensor.value_list[t]; |
| 120 | + const int32_t* offsets = grad_jagged_tensor.offsets_list[t]; |
| 121 | + int start = offsets[row]; |
| 122 | + int end = offsets[row + 1]; |
| 123 | + for (int i = start; i < end; ++i) { |
| 124 | + for (int h = 0; h < hidden_dim; ++h) { |
| 125 | + grad_values[i * hidden_dim + h] = grad_output[out_idx * hidden_dim + h]; |
| 126 | + } |
| 127 | + out_idx++; |
| 128 | + } |
| 129 | + } |
| 130 | +} |
| 131 | + |
| 132 | +std::vector<torch::Tensor> concat_2D_jagged_tensors_cuda_backward( |
| 133 | + torch::Tensor grad_output, |
| 134 | + torch::Tensor grad_lengths, |
| 135 | + const std::vector<torch::Tensor>& offsets_list, |
| 136 | + torch::Tensor merged_offsets) { |
| 137 | + |
| 138 | + int num_tensors = offsets_list.size(); |
| 139 | + int num_rows = grad_lengths.size(0); |
| 140 | + int hidden_dim = grad_output.size(-1); |
| 141 | + |
| 142 | + std::vector<torch::Tensor> grad_inputs(num_tensors); |
| 143 | + for (int i = 0; i < num_tensors; ++i) { |
| 144 | + int tensor_size = offsets_list[i][-1].item<int>(); |
| 145 | + grad_inputs[i] = torch::empty( |
| 146 | + {tensor_size, hidden_dim}, |
| 147 | + grad_output.options() |
| 148 | + ); |
| 149 | + } |
| 150 | + |
| 151 | + InputJaggedTensor<float> grad_jagged_tensor; |
| 152 | + for (int i = 0; i < num_tensors; ++i) { |
| 153 | + grad_jagged_tensor.value_list[i] = grad_inputs[i].data_ptr<float>(); |
| 154 | + grad_jagged_tensor.offsets_list[i] = offsets_list[i].data_ptr<int32_t>(); |
| 155 | + } |
| 156 | + |
| 157 | + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); |
| 158 | + int threads = 128; |
| 159 | + int blocks = (num_rows + threads - 1) / threads; |
| 160 | + |
| 161 | + concat_2D_jagged_tensors_backward_kernel<float><<<blocks, threads, 0, stream>>>( |
| 162 | + grad_jagged_tensor, |
| 163 | + num_tensors, |
| 164 | + num_rows, |
| 165 | + hidden_dim, |
| 166 | + grad_output.data_ptr<float>(), |
| 167 | + merged_offsets.data_ptr<int>() |
| 168 | + ); |
| 169 | + C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| 170 | + |
| 171 | + return grad_inputs; |
| 172 | +} |
0 commit comments