Commit d5d02a3

Author: Runchu Zhao

Add cuda_ops folder and cuda method for jagged_2D_tensor_concat

1 parent: cfe8c92

File tree: 5 files changed, +606 −0 lines
Lines changed: 69 additions & 0 deletions
import torch
from typing import List

import fbgemm_gpu  # side-effect import: registers fbgemm CUDA ops

import jagged_tensor_op


class _JaggedTensorOpFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, offsets_list: List[torch.Tensor], max_seqlens: List[int], *values_list):
        # Fast path: a single jagged tensor needs no concatenation.
        if len(offsets_list) == 1:
            single_offsets = offsets_list[0]
            lengths = single_offsets[1:] - single_offsets[:-1]
            ctx.mark_non_differentiable(lengths)
            return values_list[0], lengths

        dim_list = [v.size(-1) for v in values_list]
        assert all(dim == dim_list[0] for dim in dim_list), \
            "All tensors must have the same value dimension"

        with torch.cuda.nvtx.range("Calculate merged offsets", color="purple"):
            # Merged offsets are the elementwise sum of the input offsets:
            # output row r starts where the preceding merged rows end.
            merged_offsets = offsets_list[0].clone()
            for offsets_tensor in offsets_list[1:]:
                merged_offsets.add_(offsets_tensor)

        ctx.save_for_backward(merged_offsets, *offsets_list)
        total_length = merged_offsets[-1].item()
        hidden_dim = values_list[0].size(-1)

        # Merged per-row lengths: the sum of each input's row lengths.
        per_tensor_lengths = [
            offsets_tensor[1:] - offsets_tensor[:-1] for offsets_tensor in offsets_list
        ]
        merged_lengths = torch.sum(
            torch.concat([lengths.view(-1, 1) for lengths in per_tensor_lengths], dim=1),
            dim=1,
        )
        ctx.mark_non_differentiable(merged_lengths)

        with torch.cuda.nvtx.range("merged values mem alloc", color="purple"):
            merged_values = torch.empty(
                (total_length, hidden_dim),
                dtype=values_list[0].dtype,
                device=values_list[0].device,
            ).requires_grad_(True)

        with torch.cuda.nvtx.range("Cpp part forward", color="purple"):
            jagged_tensor_op.concat_2D_jagged_tensors_forward(
                values_list,
                offsets_list,
                merged_values,
                merged_offsets,
            )

        return merged_values, merged_lengths

    @staticmethod
    def backward(ctx, grad_output, grad_lengths):
        # Single-tensor fast path: forward returned the values unchanged,
        # so the gradient passes straight through.
        if len(ctx.saved_tensors) == 0:
            return None, None, grad_output
        merged_offsets, *offsets_list = ctx.saved_tensors
        grad_inputs = jagged_tensor_op.concat_2D_jagged_tensors_backward(
            grad_output, grad_lengths, offsets_list, merged_offsets
        )
        return None, None, *grad_inputs


def jagged_2D_tensor_concat(
    values_list: List[torch.Tensor],
    offsets_list: List[torch.Tensor],
    max_seqlens: List[int],
):
    assert len(values_list) == len(offsets_list)
    return _JaggedTensorOpFunction.apply(offsets_list, max_seqlens, *values_list)
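For reference, a minimal end-to-end sketch of jagged_2D_tensor_concat. The shapes and values are illustrative only, and it assumes the jagged_tensor_op extension is built and a CUDA device is available:

import torch

# Two jagged tensors with 2 rows each: row lengths [2, 1] and [1, 2].
offsets_a = torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda")
values_a = torch.randn(3, 4, device="cuda", requires_grad=True)
offsets_b = torch.tensor([0, 1, 3], dtype=torch.int32, device="cuda")
values_b = torch.randn(3, 4, device="cuda", requires_grad=True)

merged_values, merged_lengths = jagged_2D_tensor_concat(
    [values_a, values_b], [offsets_a, offsets_b], max_seqlens=[2, 2])
# merged_lengths == [3, 3]: output row r is A's row r followed by B's row r.
merged_values.sum().backward()   # scatters gradients back to values_a / values_b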
Lines changed: 47 additions & 0 deletions
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include <vector>

// Implemented in the .cu translation unit.
void concat_2D_jagged_tensors_cuda_forward(
    const std::vector<torch::Tensor>& values_list,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_values,
    torch::Tensor merged_offsets);

std::vector<torch::Tensor> concat_2D_jagged_tensors_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor grad_lengths,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_offsets);

void concat_2D_jagged_tensors_forward(
    const std::vector<torch::Tensor>& values_list,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_values,
    torch::Tensor merged_offsets) {
  // The caller (the Python wrapper) pre-allocates merged_values.
  TORCH_CHECK(merged_values.defined(), "merged_values must be pre-allocated");
  concat_2D_jagged_tensors_cuda_forward(
      values_list, offsets_list, merged_values, merged_offsets);
}

std::vector<torch::Tensor> concat_2D_jagged_tensors_backward(
    torch::Tensor grad_output,
    torch::Tensor grad_lengths,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_offsets) {
  return concat_2D_jagged_tensors_cuda_backward(
      grad_output, grad_lengths, offsets_list, merged_offsets);
}

PYBIND11_MODULE(jagged_tensor_op, m) {
  m.def("concat_2D_jagged_tensors_forward", &concat_2D_jagged_tensors_forward,
        "JaggedTensor concat forward (CUDA)");
  m.def("concat_2D_jagged_tensors_backward", &concat_2D_jagged_tensors_backward,
        "JaggedTensor concat backward (CUDA)");
}
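pybind11 converts Python lists of tensors to std::vector<torch::Tensor> automatically (torch/extension.h pulls in the needed casters), so the bound functions can be called directly. A minimal sketch of the raw binding contract, with illustrative shapes, assuming the extension is built — note that the caller pre-allocates merged_values:

import torch
import jagged_tensor_op

offsets = [torch.tensor([0, 2, 3], dtype=torch.int32, device="cuda"),
           torch.tensor([0, 1, 3], dtype=torch.int32, device="cuda")]
values = [torch.randn(3, 4, device="cuda"),
          torch.randn(3, 4, device="cuda")]
merged_offsets = offsets[0] + offsets[1]            # [0, 3, 6]
merged_values = torch.empty(6, 4, device="cuda")    # pre-allocated by the caller
jagged_tensor_op.concat_2D_jagged_tensors_forward(
    values, offsets, merged_values, merged_offsets)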
Lines changed: 172 additions & 0 deletions
#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

// Raw device pointers for up to kMaxNumTensors inputs, passed to the kernels
// by value (a device-friendly stand-in for a vector of tensors).
constexpr int kMaxNumTensors = 32;

template <typename T>
struct InputJaggedTensor {
  T* value_list[kMaxNumTensors];
  int32_t* offsets_list[kMaxNumTensors];
};

// One thread per output row: the thread reads its write cursor from
// merged_offsets[row], then copies that row from each input in order.
template <typename T>
__global__ void concat_2D_jagged_tensors_forward_kernel(
    const InputJaggedTensor<T> input_jagged_tensor,
    const int32_t num_tensors,
    const int32_t num_rows,
    const int32_t hidden_dim,
    T* merged_values,
    const int* merged_offsets) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;
  int out_idx = merged_offsets[row];

  for (int t = 0; t < num_tensors; ++t) {
    const T* values = input_jagged_tensor.value_list[t];
    const int32_t* offsets = input_jagged_tensor.offsets_list[t];
    int start = offsets[row];
    int end = offsets[row + 1];
    for (int i = start; i < end; ++i) {
      for (int h = 0; h < hidden_dim; ++h) {
        merged_values[out_idx * hidden_dim + h] = values[i * hidden_dim + h];
      }
      out_idx++;
    }
  }
}

// 1-D variant (values with no hidden dimension); not exercised by the Python
// wrapper in this commit.
__global__ void concat_1D_jagged_tensor_kernel(
    const float** values_list,
    const int** offsets_list,
    int num_tensors,
    int num_rows,  // number of jagged rows
    float* merged_values,
    const int* merged_offsets) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;

  int out_idx = merged_offsets[row];  // this row's data starts here
  for (int i = 0; i < num_tensors; i++) {
    const float* values = values_list[i];
    const int* offsets = offsets_list[i];
    int start = offsets[row];
    int end = offsets[row + 1];
    for (int j = start; j < end; j++) {
      merged_values[out_idx++] = values[j];
    }
  }
}

void concat_2D_jagged_tensors_cuda_forward(
    const std::vector<torch::Tensor>& values_list,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_values,
    torch::Tensor merged_offsets) {
  int num_tensors = values_list.size();
  int num_rows = offsets_list[0].size(0) - 1;
  int hidden_dim = values_list[0].size(-1);
  TORCH_CHECK(num_tensors <= kMaxNumTensors, "at most ", kMaxNumTensors, " inputs supported");
  TORCH_CHECK(merged_values.is_contiguous());

  InputJaggedTensor<float> input_jagged_tensor;
  for (int i = 0; i < num_tensors; ++i) {
    input_jagged_tensor.value_list[i] = values_list[i].data_ptr<float>();
    input_jagged_tensor.offsets_list[i] = offsets_list[i].data_ptr<int32_t>();
  }

  int threads = 128;
  int blocks = (num_rows + threads - 1) / threads;
  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

  concat_2D_jagged_tensors_forward_kernel<float><<<blocks, threads, 0, stream>>>(
      input_jagged_tensor,
      num_tensors,
      num_rows,
      hidden_dim,
      merged_values.data_ptr<float>(),
      merged_offsets.data_ptr<int>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Backward mirrors the forward copy: each thread scatters grad_output rows
// back into per-input gradient buffers using the same offsets.
template <typename T>
__global__ void concat_2D_jagged_tensors_backward_kernel(
    const InputJaggedTensor<T> grad_jagged_tensor,
    const int32_t num_tensors,
    const int32_t num_rows,
    const int32_t hidden_dim,
    const T* grad_output,
    const int* merged_offsets) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;
  int out_idx = merged_offsets[row];

  for (int t = 0; t < num_tensors; ++t) {
    T* grad_values = grad_jagged_tensor.value_list[t];
    const int32_t* offsets = grad_jagged_tensor.offsets_list[t];
    int start = offsets[row];
    int end = offsets[row + 1];
    for (int i = start; i < end; ++i) {
      for (int h = 0; h < hidden_dim; ++h) {
        grad_values[i * hidden_dim + h] = grad_output[out_idx * hidden_dim + h];
      }
      out_idx++;
    }
  }
}

std::vector<torch::Tensor> concat_2D_jagged_tensors_cuda_backward(
    torch::Tensor grad_output,
    torch::Tensor grad_lengths,
    const std::vector<torch::Tensor>& offsets_list,
    torch::Tensor merged_offsets) {
  int num_tensors = offsets_list.size();
  int num_rows = grad_lengths.size(0);
  int hidden_dim = grad_output.size(-1);
  TORCH_CHECK(num_tensors <= kMaxNumTensors, "at most ", kMaxNumTensors, " inputs supported");

  // One gradient buffer per input, sized by that input's last offset; every
  // element is written by the kernel, so empty() needs no zero-fill.
  std::vector<torch::Tensor> grad_inputs(num_tensors);
  for (int i = 0; i < num_tensors; ++i) {
    int tensor_size = offsets_list[i][-1].item<int>();
    grad_inputs[i] = torch::empty({tensor_size, hidden_dim}, grad_output.options());
  }

  InputJaggedTensor<float> grad_jagged_tensor;
  for (int i = 0; i < num_tensors; ++i) {
    grad_jagged_tensor.value_list[i] = grad_inputs[i].data_ptr<float>();
    grad_jagged_tensor.offsets_list[i] = offsets_list[i].data_ptr<int32_t>();
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int threads = 128;
  int blocks = (num_rows + threads - 1) / threads;

  concat_2D_jagged_tensors_backward_kernel<float><<<blocks, threads, 0, stream>>>(
      grad_jagged_tensor,
      num_tensors,
      num_rows,
      hidden_dim,
      grad_output.data_ptr<float>(),
      merged_offsets.data_ptr<int>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  return grad_inputs;
}
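For intuition, the indexing the kernels implement can be replayed on the CPU: one "thread" per row reads merged_offsets[row] as its write cursor and walks the inputs in order. A small illustrative sketch (not part of the commit):

import torch

offsets = [torch.tensor([0, 2, 3]), torch.tensor([0, 1, 3])]
values = [torch.arange(12).view(3, 4).float(),
          torch.arange(100, 112).view(3, 4).float()]
merged_offsets = offsets[0] + offsets[1]      # [0, 3, 6]
merged = torch.empty(int(merged_offsets[-1]), 4)

for row in range(len(merged_offsets) - 1):    # one "thread" per output row
    out_idx = int(merged_offsets[row])        # this row's write cursor
    for t in range(len(values)):              # inputs are walked in order
        start, end = int(offsets[t][row]), int(offsets[t][row + 1])
        for i in range(start, end):
            merged[out_idx] = values[t][i]
            out_idx += 1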
Lines changed: 45 additions & 0 deletions
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


def nvcc_threads_args():
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    return ["--threads", nvcc_threads]


nvcc_flags = [
    "-g",
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_BFLOAT16_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "-U__CUDA_NO_BFLOAT162_OPERATORS__",
    "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--use_fast_math",
]

cc_flag = []
cc_flag.append("-gencode")
cc_flag.append("arch=compute_80,code=sm_80")

setup(
    name='jagged_tensor_op',
    author='Runchu Zhao',
    description='JaggedTensor concat forward and backward',
    ext_modules=[
        CUDAExtension(
            name='jagged_tensor_op',
            sources=['csrc/jagged_tensor_op_cuda.cpp', 'csrc/jagged_tensor_op_kernel.cu'],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                # "nvcc": nvcc_threads_args() + nvcc_flags + cc_flag,
                "nvcc": nvcc_threads_args() + nvcc_flags,
                # "nvcc": ["-O2"],
            },
        )
    ],
    cmdclass={
        'build_ext': BuildExtension
    },
)
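To build and smoke-test, a standard setuptools/cpp_extension workflow should apply; the exact commands below are an assumption, not part of this commit:

# Assumed build step: python setup.py build_ext --inplace
# (or `pip install .`), run from the directory containing setup.py.
import jagged_tensor_op  # the module bound via PYBIND11_MODULE above

print(jagged_tensor_op.concat_2D_jagged_tensors_forward.__doc__)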
