Skip to content

Commit a914ad9

Browse files
authored
Merge branch 'PaddlePaddle:develop' into mmlu
2 parents 3d24a42 + 6601854 commit a914ad9

File tree

296 files changed

+19754
-3881
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

296 files changed

+19754
-3881
lines changed

.github/workflows/new_issue.yml

Lines changed: 0 additions & 21 deletions
This file was deleted.

.readthedocs.yaml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
# Required
66
version: 2
7+
build:
8+
os: "ubuntu-20.04"
9+
tools:
10+
python: "3.8"
711

812
submodules:
913
include: all
@@ -19,9 +23,5 @@ sphinx:
1923

2024
# Optionally set the version of Python and requirements required to build your docs
2125
python:
22-
version: 3.8
2326
install:
2427
- requirements: docs/requirements.txt
25-
- method: setuptools
26-
path: .
27-
system_packages: true

csrc/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# PaddleNLP 自定义 OP
2+
3+
此文档介绍如何编译安装 PaddleNLP 自定义 OP。
4+
5+
## 安装 C++ 依赖
6+
7+
```shell
8+
pip install -r requirements.txt
9+
```
10+
11+
## 编译 CUDA 算子
12+
13+
```shell
14+
python setup_cuda.py install
15+
```
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
17+
// Applies NeoX-style rotary position embedding to one (batch, head, position)
// slice per block; threads stride over the rotated half-dimension.
//
// NeoX pairing: element ti is rotated together with element ti + last_dim/2,
// and the cos/sin tables are read per element (distinct values for the left
// and right halves).
//
// input/output      : indexed as [batch_size? , head_num, seq_len, last_dim]
//                     (batch stride is head_num*seq_len*last_dim); the
//                     launcher passes the same buffer for both, so the
//                     rotation is in place.
// cos_emb/sin_emb   : float tables indexed as [batch, seq_len, last_dim].
// sequence_lengths  : optional per-batch valid lengths (may be null); blocks
//                     whose position is past the valid length return early.
// rotary_emb_dims   : factor by which the position axis was expanded; the
//                     grid z-dimension and seq_len are already scaled by it.
// batch_size        : unused inside the kernel; kept for signature symmetry
//                     with the launcher.
//
// Math is done in float regardless of T for precision, then cast back.
template <typename T>
__global__ void NeoXRotaryKernel(const T *input,
                                 const float *cos_emb,
                                 const float *sin_emb,
                                 const int *sequence_lengths,
                                 T *output,
                                 const int rotary_emb_dims,
                                 const int batch_size,
                                 const int head_num,
                                 const int seq_len,
                                 const int last_dim) {
  int bi = blockIdx.x;  // batch index
  int hi = blockIdx.y;  // head index
  int si = blockIdx.z;  // position index (already scaled by rotary_emb_dims)
  // Skip padded positions beyond the actual sequence length.
  if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
  int half_lastdim = last_dim / 2;
  // Hoisted out of the loop: both bases depend only on block indices and are
  // invariant across ti iterations.
  const int base_idx = bi * head_num * seq_len * last_dim +
                       hi * seq_len * last_dim + si * last_dim;
  const int emb_base = bi * seq_len * last_dim + si * last_dim;
  for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
    const int left_idx = base_idx + ti;
    const int right_idx = base_idx + ti + half_lastdim;
    const int emb_idx_left = emb_base + ti;
    const int emb_idx_right = emb_base + ti + half_lastdim;
    float input_left = static_cast<float>(input[left_idx]);
    float input_right = static_cast<float>(input[right_idx]);

    float cos_tmp_left = cos_emb[emb_idx_left];
    float sin_tmp_left = sin_emb[emb_idx_left];
    float cos_tmp_right = cos_emb[emb_idx_right];
    float sin_tmp_right = sin_emb[emb_idx_right];

    // 2D rotation of the (left, right) pair.
    T res1 =
        static_cast<T>(input_left * cos_tmp_left - input_right * sin_tmp_left);
    T res2 = static_cast<T>(input_right * cos_tmp_right +
                            input_left * sin_tmp_right);
    output[left_idx] = res1;
    output[right_idx] = res2;
  }
}
57+
58+
59+
// Applies interleaved (GPT-style) rotary position embedding to one
// (batch, head, position) slice per block; threads stride over the rotated
// half-dimension.
//
// Interleaved pairing: element 2*ti is rotated with element 2*ti + 1, and
// both share the cos/sin value at index 2*ti.
//
// input/output      : batch stride is head_num*seq_len*last_dim; the launcher
//                     passes the same buffer for both, so rotation is in
//                     place.
// cos_emb/sin_emb   : float tables indexed as [batch, seq_len, last_dim].
// sequence_lengths  : optional per-batch valid lengths (may be null); blocks
//                     whose position is past the valid length return early.
// rotary_emb_dims   : factor by which the position axis was expanded; the
//                     grid z-dimension and seq_len are already scaled by it.
// batch_size        : unused inside the kernel; kept for signature symmetry
//                     with the launcher.
//
// Math is done in float regardless of T for precision, then cast back.
template <typename T>
__global__ void RotaryKernel(const T *input,
                             const float *cos_emb,
                             const float *sin_emb,
                             const int *sequence_lengths,
                             T *output,
                             const int rotary_emb_dims,
                             const int batch_size,
                             const int head_num,
                             const int seq_len,
                             const int last_dim) {
  int bi = blockIdx.x;  // batch index
  int hi = blockIdx.y;  // head index
  int si = blockIdx.z;  // position index (already scaled by rotary_emb_dims)
  // Skip padded positions beyond the actual sequence length.
  if (sequence_lengths && si >= sequence_lengths[bi] * rotary_emb_dims) return;
  int half_lastdim = last_dim / 2;
  // Note(ZhenyuLi): Calculate the relevant data at one time, so that no
  // additional space is required.
  // Hoisted out of the loop: both bases depend only on block indices and are
  // invariant across ti iterations.
  const int base_idx = bi * head_num * seq_len * last_dim +
                       hi * seq_len * last_dim + si * last_dim;
  const int emb_base = bi * seq_len * last_dim + si * last_dim;
  for (int ti = threadIdx.x; ti < half_lastdim; ti += blockDim.x) {
    const int left_idx = base_idx + 2 * ti;
    const int right_idx = base_idx + 2 * ti + 1;
    const int emb_idx = emb_base + 2 * ti;
    float input_left = static_cast<float>(input[left_idx]);
    float input_right = static_cast<float>(input[right_idx]);
    float cos_tmp = cos_emb[emb_idx];
    float sin_tmp = sin_emb[emb_idx];
    // 2D rotation of the (left, right) pair.
    T res1 = static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
    T res2 = static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
    output[left_idx] = res1;
    output[right_idx] = res2;
  }
}
93+
94+
// Launches the rotary-embedding kernels over q and kv in place.
//
// q is laid out [batch, head, seq, dim_head] (read from q.shape()).
// rotary_emb packs the cos table followed by the sin table, each of
// batch_size * seq_len * dim_head floats.
// use_neox selects the half-split (NeoX) pairing; otherwise the interleaved
// pairing is used.  Both q and kv are overwritten with their rotated values.
template <paddle::DataType D>
void LaunchRotaryQK(const paddle::Tensor& q,
                    const paddle::Tensor& kv,
                    const paddle::Tensor& rotary_emb,
                    const paddle::Tensor& seq_lens,
                    const int32_t rotary_emb_dims,
                    bool use_neox) {
  using traits_ = PDTraits<D>;
  using DataType_ = typename traits_::DataType;
  using data_t = typename traits_::data_t;

  const int32_t batch_size = q.shape()[0];
  const int32_t head_num = q.shape()[1];
  const int32_t seq_len = q.shape()[2];
  const int32_t dim_head = q.shape()[3];

  auto cu_stream = q.stream();
  // One block per (batch, head, rotated position).
  dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims);
  const int last_dim = dim_head / rotary_emb_dims;

  // Smallest block size in {32, 64, 128, 256, 512} covering the rotated
  // half-dimension.
  auto pick_block_size = [](int dim) {
    if (dim > 256) {
      return 512;
    } else if (dim > 128) {
      return 256;
    } else if (dim > 64) {
      return 128;
    } else if (dim > 32) {
      return 64;
    }
    return 32;
  };
  const int block_size = pick_block_size(last_dim / 2);

  // cos table first, sin table immediately after it.
  const float* cos_emb = rotary_emb.data<float>();
  const float* sin_emb = cos_emb + batch_size * seq_len * dim_head;
  const int* seq_len_data = seq_lens.data<int>();  // sequence_lengths

  // In-place rotation: each tensor is both kernel input and output.
  const DataType_* in_ptrs[2] = {
      reinterpret_cast<const DataType_*>(q.data<data_t>()),
      reinterpret_cast<const DataType_*>(kv.data<data_t>())};
  DataType_* out_ptrs[2] = {
      reinterpret_cast<DataType_*>(const_cast<data_t*>(q.data<data_t>())),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(kv.data<data_t>()))};

  // Rotate q first, then kv, with identical launch configuration.
  for (int i = 0; i < 2; ++i) {
    if (use_neox) {
      NeoXRotaryKernel<<<grid, block_size, 0, cu_stream>>>(
          in_ptrs[i],
          cos_emb,
          sin_emb,
          seq_len_data,
          out_ptrs[i],
          rotary_emb_dims,
          batch_size,
          head_num,
          seq_len * rotary_emb_dims,
          last_dim);
    } else {
      RotaryKernel<<<grid, block_size, 0, cu_stream>>>(
          in_ptrs[i],
          cos_emb,
          sin_emb,
          seq_len_data,
          out_ptrs[i],
          rotary_emb_dims,
          batch_size,
          head_num,
          seq_len * rotary_emb_dims,
          last_dim);
    }
  }
}
186+
187+
// Dtype dispatcher for the encode_rotary_qk op: forwards to the
// LaunchRotaryQK specialization matching q's element type.  Supports
// bfloat16, float16 and float32; any other dtype raises via PD_THROW.
void RotaryQK(const paddle::Tensor& q,
              const paddle::Tensor& kv,
              const paddle::Tensor& rotary_emb,
              const paddle::Tensor& seq_lens,
              const int32_t rotary_emb_dims,
              bool use_neox) {
  const auto dtype = q.type();
  if (dtype == paddle::DataType::BFLOAT16) {
    return LaunchRotaryQK<paddle::DataType::BFLOAT16>(
        q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
  }
  if (dtype == paddle::DataType::FLOAT16) {
    return LaunchRotaryQK<paddle::DataType::FLOAT16>(
        q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
  }
  if (dtype == paddle::DataType::FLOAT32) {
    return LaunchRotaryQK<paddle::DataType::FLOAT32>(
        q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
  }
  PD_THROW(
      "NOT supported data type. "
      "Only bfloat16, float16 and float32 are supported. ");
}
217+
218+
219+
220+
// Registers the encode_rotary_qk custom op with Paddle.
// The inplace map aliases q -> rotary_q_out and kv -> rotary_kv_out, i.e. the
// inputs are rotated in place and returned under the output names (matching
// RotaryQK, which writes back into q and kv).
PD_BUILD_OP(encode_rotary_qk)
    .Inputs({"q", "kv", "rotary_emb", "seq_lens"})
    .Outputs({"rotary_q_out", "rotary_kv_out"})
    .SetInplaceMap({{"q", "rotary_q_out"}, {"kv", "rotary_kv_out"}})
    .Attrs({"rotary_emb_dims: int", "use_neox: bool"})
    .SetKernelFn(PD_KERNEL(RotaryQK));

0 commit comments

Comments
 (0)