1+ // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
#include "helper.h"
16+
// NeoX-style rotary embedding: the head dim is split into two contiguous
// halves that are rotated against each other.
//
// Expected launch: grid = (batch, head_num, seq_len * rotary_emb_dims),
// block = 1D over the half head-dim. `input` and `output` may alias (the
// caller applies this in place); that is safe because each thread reads both
// of its elements before writing either.
template <typename T>
__global__ void NeoXRotaryKernel(const T *input,
                                 const float *cos_emb,
                                 const float *sin_emb,
                                 const int *sequence_lengths,
                                 T *output,
                                 const int rotary_emb_dims,
                                 const int batch_size,
                                 const int head_num,
                                 const int seq_len,
                                 const int last_dim) {
  const int batch_id = blockIdx.x;
  const int head_id = blockIdx.y;
  const int pos_id = blockIdx.z;
  // Positions past this sequence's valid length do no work. The valid length
  // is scaled by rotary_emb_dims because the grid z-dim was multiplied by it.
  if (sequence_lengths != nullptr &&
      pos_id >= sequence_lengths[batch_id] * rotary_emb_dims) {
    return;
  }
  const int half_dim = last_dim / 2;
  // Offsets are invariant over the loop, so compute them once.
  const int token_offset = batch_id * head_num * seq_len * last_dim +
                           head_id * seq_len * last_dim + pos_id * last_dim;
  // Embedding table is shared across heads: indexed by (batch, pos, dim) only.
  const int emb_offset = batch_id * seq_len * last_dim + pos_id * last_dim;
  for (int d = threadIdx.x; d < half_dim; d += blockDim.x) {
    const int lo_idx = token_offset + d;
    const int hi_idx = token_offset + d + half_dim;

    const float x_lo = static_cast<float>(input[lo_idx]);
    const float x_hi = static_cast<float>(input[hi_idx]);

    const float cos_lo = cos_emb[emb_offset + d];
    const float sin_lo = sin_emb[emb_offset + d];
    const float cos_hi = cos_emb[emb_offset + d + half_dim];
    const float sin_hi = sin_emb[emb_offset + d + half_dim];

    // Standard 2D rotation applied across the two halves.
    output[lo_idx] = static_cast<T>(x_lo * cos_lo - x_hi * sin_lo);
    output[hi_idx] = static_cast<T>(x_hi * cos_hi + x_lo * sin_hi);
  }
}
57+
58+
// Interleaved (GPT-style) rotary embedding: elements are rotated in adjacent
// pairs (2*d, 2*d + 1), both sharing the cos/sin value at index 2*d.
//
// Expected launch: grid = (batch, head_num, seq_len * rotary_emb_dims),
// block = 1D over the half head-dim. `input` and `output` may alias; each
// thread reads its pair before writing it, so in-place use is safe.
template <typename T>
__global__ void RotaryKernel(const T *input,
                             const float *cos_emb,
                             const float *sin_emb,
                             const int *sequence_lengths,
                             T *output,
                             const int rotary_emb_dims,
                             const int batch_size,
                             const int head_num,
                             const int seq_len,
                             const int last_dim) {
  const int batch_id = blockIdx.x;
  const int head_id = blockIdx.y;
  const int pos_id = blockIdx.z;
  // Skip padded positions; valid length is scaled by rotary_emb_dims to match
  // the expanded grid z-dim.
  if (sequence_lengths != nullptr &&
      pos_id >= sequence_lengths[batch_id] * rotary_emb_dims) {
    return;
  }
  const int half_dim = last_dim / 2;
  // Note(ZhenyuLi): Calculate the relevant data at one time, so that no
  // additional space is required.
  const int token_offset = batch_id * head_num * seq_len * last_dim +
                           head_id * seq_len * last_dim + pos_id * last_dim;
  // cos/sin table is per (batch, pos, dim) — heads share it.
  const int emb_offset = batch_id * seq_len * last_dim + pos_id * last_dim;
  for (int d = threadIdx.x; d < half_dim; d += blockDim.x) {
    const int even_idx = token_offset + 2 * d;
    const int odd_idx = even_idx + 1;

    const float x_even = static_cast<float>(input[even_idx]);
    const float x_odd = static_cast<float>(input[odd_idx]);
    const float cos_v = cos_emb[emb_offset + 2 * d];
    const float sin_v = sin_emb[emb_offset + 2 * d];

    // Rotate the (even, odd) pair by the position angle.
    output[even_idx] = static_cast<T>(x_even * cos_v - x_odd * sin_v);
    output[odd_idx] = static_cast<T>(x_odd * cos_v + x_even * sin_v);
  }
}
93+
// Applies rotary position embedding to `q` and `kv` in place on q's stream.
//
// Layout assumptions (from the index math in the kernels):
//   q, kv:      [batch, head_num, seq_len, dim_head]
//   rotary_emb: cos table followed by sin table, each of
//               batch * seq_len * dim_head floats
//   seq_lens:   [batch] valid sequence lengths (int)
// When rotary_emb_dims > 1 the head dim is split into rotary_emb_dims chunks
// and the grid z-dim / effective seq_len are scaled accordingly.
template <paddle::DataType D>
void LaunchRotaryQK(const paddle::Tensor& q,
                    const paddle::Tensor& kv,
                    const paddle::Tensor& rotary_emb,
                    const paddle::Tensor& seq_lens,
                    const int32_t rotary_emb_dims,
                    bool use_neox) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  const int32_t batch_size = q.shape()[0];
  const int32_t head_num = q.shape()[1];
  const int32_t seq_len = q.shape()[2];
  const int32_t dim_head = q.shape()[3];

  auto cu_stream = q.stream();
  dim3 grid(batch_size, head_num, seq_len * rotary_emb_dims);
  const int last_dim = dim_head / rotary_emb_dims;

  // Smallest power of two >= work, clamped to [32, 512] — the same mapping as
  // the original threshold chain (work <= 32 -> 32, ..., work > 256 -> 512).
  auto pick_block_size = [](int work) {
    int threads = 32;
    while (threads < 512 && threads < work) {
      threads <<= 1;
    }
    return threads;
  };
  const int block_size = pick_block_size(last_dim / 2);

  // cos table first, sin table immediately after it.
  const float *cos_emb = rotary_emb.data<float>();
  const float *sin_emb = cos_emb + batch_size * seq_len * dim_head;

  const DataType_ *q_data =
      reinterpret_cast<const DataType_ *>(q.data<data_t>());
  const DataType_ *k_data =
      reinterpret_cast<const DataType_ *>(kv.data<data_t>());
  // In-place: outputs alias the inputs (safe — see kernel comments).
  DataType_ *q_out_data =
      reinterpret_cast<DataType_ *>(const_cast<data_t *>(q.data<data_t>()));
  DataType_ *k_out_data =
      reinterpret_cast<DataType_ *>(const_cast<data_t *>(kv.data<data_t>()));

  const int kernel_seq_len = seq_len * rotary_emb_dims;
  const int *sequence_lengths = seq_lens.data<int>();

  // Q and K differ only in their data/output pointers; share one launcher.
  auto launch = [&](const DataType_ *in, DataType_ *out) {
    if (use_neox) {
      NeoXRotaryKernel<<<grid, block_size, 0, cu_stream>>>(in,
                                                           cos_emb,
                                                           sin_emb,
                                                           sequence_lengths,
                                                           out,
                                                           rotary_emb_dims,
                                                           batch_size,
                                                           head_num,
                                                           kernel_seq_len,
                                                           last_dim);
    } else {
      RotaryKernel<<<grid, block_size, 0, cu_stream>>>(in,
                                                       cos_emb,
                                                       sin_emb,
                                                       sequence_lengths,
                                                       out,
                                                       rotary_emb_dims,
                                                       batch_size,
                                                       head_num,
                                                       kernel_seq_len,
                                                       last_dim);
    }
  };
  launch(q_data, q_out_data);
  launch(k_data, k_out_data);
}
186+
// Dtype dispatcher for the rotary-embedding op: routes to the templated
// LaunchRotaryQK instantiation matching q's runtime dtype.
//
// Fix: the PD_THROW message literals contained stray leading/embedded spaces
// (" NOT supported data type. ") — restored to the intended message text.
void RotaryQK(const paddle::Tensor& q,
              const paddle::Tensor& kv,
              const paddle::Tensor& rotary_emb,
              const paddle::Tensor& seq_lens,
              const int32_t rotary_emb_dims,
              bool use_neox) {
  switch (q.type()) {
    case paddle::DataType::BFLOAT16: {
      return LaunchRotaryQK<paddle::DataType::BFLOAT16>(
          q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
    }
    case paddle::DataType::FLOAT16: {
      return LaunchRotaryQK<paddle::DataType::FLOAT16>(
          q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
    }
    case paddle::DataType::FLOAT32: {
      return LaunchRotaryQK<paddle::DataType::FLOAT32>(
          q, kv, rotary_emb, seq_lens, rotary_emb_dims, use_neox);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only bfloat16, float16 and float32 are supported. ");
      break;
    }
  }
}
217+
218+
219+
// Registers the custom op `encode_rotary_qk`. The op rotates q/kv in place,
// so inputs are mapped onto the outputs via SetInplaceMap.
//
// Fix: the registration string literals contained stray spaces after the
// opening quotes (e.g. " q", " rotary_emb_dims: int"), which would register
// wrong tensor/attribute names — restored the intended identifiers.
PD_BUILD_OP(encode_rotary_qk)
    .Inputs({"q", "kv", "rotary_emb", "seq_lens"})
    .Outputs({"rotary_q_out", "rotary_kv_out"})
    .SetInplaceMap({{"q", "rotary_q_out"}, {"kv", "rotary_kv_out"}})
    .Attrs({"rotary_emb_dims: int", "use_neox: bool"})
    .SetKernelFn(PD_KERNEL(RotaryQK));