
Commit 20e50fb

Add fused_partial_rope op (#74577)
1 parent 1d57606 commit 20e50fb

File tree

11 files changed: +658 -0 lines changed

paddle/phi/infermeta/fusion.cc

Lines changed: 83 additions & 0 deletions
@@ -2420,6 +2420,89 @@ void FusedMultiTransformerInt8InferMeta(
  out->set_dtype(x.dtype());
}

void FusedPartialRopeInferMeta(const MetaTensor& x,
                               const MetaTensor& cos,
                               const MetaTensor& sin,
                               MetaTensor* out) {
  const auto x_dims = x.dims();
  PADDLE_ENFORCE_EQ(
      x_dims.size(),
      4,
      common::errors::InvalidArgument("The input x must be a 4D tensor"));

  const int64_t batch_size = x_dims[0];
  const int64_t seq_len = x_dims[1];
  const int64_t num_heads = x_dims[2];
  const int64_t head_dim = x_dims[3];

  PADDLE_ENFORCE_LE(
      batch_size * seq_len * num_heads,
      std::numeric_limits<int>::max(),
      common::errors::InvalidArgument("Currently only supports batch_size * "
                                      "seq_len * num_heads <= INT_MAX"));
  PADDLE_ENFORCE_LE(head_dim,
                    std::numeric_limits<int>::max(),
                    common::errors::InvalidArgument(
                        "Currently only supports head_dim <= INT_MAX"));

  const auto cos_dims = cos.dims();
  PADDLE_ENFORCE_EQ(
      cos_dims.size(),
      4,
      common::errors::InvalidArgument("The input cos must be a 4D tensor"));
  PADDLE_ENFORCE_EQ(
      cos_dims[0],
      1,
      common::errors::InvalidArgument("The batch_size of cos must be 1"));
  PADDLE_ENFORCE_EQ(
      cos_dims[1],
      seq_len,
      common::errors::InvalidArgument("The seq_len of cos must match x"));
  PADDLE_ENFORCE_EQ(
      cos_dims[2],
      1,
      common::errors::InvalidArgument("The num_heads of cos must be 1"));

  const int64_t pe_head_dim = cos_dims[3];
  PADDLE_ENFORCE_LE(pe_head_dim,
                    head_dim,
                    common::errors::InvalidArgument(
                        "pe_head_dim must be no larger than head_dim"));
  PADDLE_ENFORCE_EQ(
      pe_head_dim % 2,
      0,
      common::errors::InvalidArgument("pe_head_dim must be multiple of 2"));
  PADDLE_ENFORCE_LE(pe_head_dim,
                    1024,
                    common::errors::InvalidArgument(
                        "Currently only supports pe_head_dim <= 1024"));

  const auto sin_dims = sin.dims();
  PADDLE_ENFORCE_EQ(
      sin_dims.size(),
      4,
      common::errors::InvalidArgument("The input sin must be a 4D tensor"));
  PADDLE_ENFORCE_EQ(
      sin_dims[0],
      1,
      common::errors::InvalidArgument("The batch_size of sin must be 1"));
  PADDLE_ENFORCE_EQ(
      sin_dims[1],
      seq_len,
      common::errors::InvalidArgument("The seq_len of sin must match x"));
  PADDLE_ENFORCE_EQ(
      sin_dims[2],
      1,
      common::errors::InvalidArgument("The num_heads of sin must be 1"));
  PADDLE_ENFORCE_EQ(
      sin_dims[3],
      pe_head_dim,
      common::errors::InvalidArgument("The pe_head_dim of sin must match cos"));

  out->set_dims(x.dims());
  out->set_dtype(x.dtype());
}

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
                                       const MetaTensor& input_scales,
                                       const IntArray& tokens_per_expert,
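Taken together, the checks above amount to a simple shape contract. The standalone sketch below (not part of the diff; the concrete sizes are illustrative) restates it with plain asserts:

#include <cassert>
#include <cstdint>
#include <limits>

int main() {
  // Illustrative sizes only; any values satisfying the asserts are accepted.
  const int64_t batch_size = 2, seq_len = 128, num_heads = 16;
  const int64_t head_dim = 192, pe_head_dim = 64;

  // x:   [batch_size, seq_len, num_heads, head_dim]
  // cos: [1, seq_len, 1, pe_head_dim]   (shared across batch and heads)
  // sin: [1, seq_len, 1, pe_head_dim]   (same shape as cos)
  // out: same dims and dtype as x
  assert(pe_head_dim <= head_dim);   // only part of head_dim is rotated
  assert(pe_head_dim % 2 == 0);      // rotary channels come in pairs
  assert(pe_head_dim <= 1024);       // current kernel limit
  assert(head_dim <= std::numeric_limits<int>::max());
  assert(batch_size * seq_len * num_heads <= std::numeric_limits<int>::max());
  return 0;
}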

paddle/phi/infermeta/fusion.h

Lines changed: 5 additions & 0 deletions
@@ -668,6 +668,11 @@ void FusedMultiTransformerInt8InferMeta(
    std::vector<MetaTensor*> cache_kv_out,
    MetaTensor* out);

void FusedPartialRopeInferMeta(const MetaTensor& x,
                               const MetaTensor& cos,
                               const MetaTensor& sin,
                               MetaTensor* out);

void FusedTransposeSplitQuantInferMeta(const MetaTensor& x,
                                       const MetaTensor& input_scales,
                                       const IntArray& tokens_per_expert,

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"

namespace phi {
namespace fusion {

using FastDivMod = phi::funcs::FastDivMod<uint32_t>;

template <typename T, int VecSize, int NopeSize, int PeSize>
__global__ void rope_grad_kernel(const T* __restrict__ cos,
                                 const T* __restrict__ sin,
                                 const T* __restrict__ out_grad,
                                 T* __restrict__ x_grad,
                                 FastDivMod seq_len,
                                 FastDivMod num_heads,
                                 uint32_t nope_head_dim,
                                 uint32_t pe_head_dim,
                                 uint32_t block_num) {
  using VT = phi::kps::details::VectorType<T, VecSize>;
  extern __shared__ T shm[];

  const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
  if (block_idx >= block_num) return;
  const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
  const size_t block_offset =
      static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
  T* const pe_buffer = shm + threadIdx.y * pe_head_dim;

  // copy nope part
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
    size_t idx = block_offset + i;
    *reinterpret_cast<VT*>(x_grad + idx) =
        *reinterpret_cast<const VT*>(out_grad + idx);
  }

  // load pe part, apply embedding and transpose in shared memory
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
    VT grad = *reinterpret_cast<const VT*>(out_grad + block_offset +
                                           nope_head_dim + i);
    VT grad_rot;
    if (i < pe_head_dim / 2) {
      grad_rot = *reinterpret_cast<const VT*>(
          out_grad + block_offset + nope_head_dim + (i + pe_head_dim / 2));
    } else {
      grad_rot = *reinterpret_cast<const VT*>(
          out_grad + block_offset + nope_head_dim + (i - pe_head_dim / 2));
    }

    VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
    VT sin_v;
    if (i < pe_head_dim / 2) {
      sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
                                           (i + pe_head_dim / 2));
    } else {
      sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim +
                                           (i - pe_head_dim / 2));
    }

    for (uint32_t j = 0; j < VecSize; j++) {
      uint32_t pe_idx = i + j;
      if (pe_idx < pe_head_dim / 2) {
        pe_buffer[pe_idx * 2] =
            grad.val[j] * cos_v.val[j] + grad_rot.val[j] * sin_v.val[j];
      } else {
        pe_buffer[(pe_idx - pe_head_dim / 2) * 2 + 1] =
            grad.val[j] * cos_v.val[j] - grad_rot.val[j] * sin_v.val[j];
      }
    }
  }
#ifdef PADDLE_WITH_HIP
  __syncthreads();
#else
  __syncwarp();
#endif

  // store
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
    VT tmp;
    for (uint32_t j = 0; j < VecSize; j++) {
      tmp.val[j] = pe_buffer[i + j];
    }
    *reinterpret_cast<VT*>(x_grad + block_offset + nope_head_dim + i) = tmp;
  }
}

template <typename T, typename Context>
void FusedPartialRoPEGradKernel(const Context& dev_ctx,
                                const DenseTensor& cos,
                                const DenseTensor& sin,
                                const DenseTensor& out_grad,
                                DenseTensor* x_grad) {
  const auto x_dims = out_grad.dims();
  const int64_t batch_size = x_dims[0];
  const int64_t seq_len = x_dims[1];
  const int64_t num_heads = x_dims[2];
  const int64_t head_dim = x_dims[3];
  const int64_t pe_head_dim = cos.dims()[3];
  const int64_t nope_head_dim = head_dim - pe_head_dim;

  // Allocate x_grad
  dev_ctx.template Alloc<T>(x_grad);

  if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
    return;
  }

  // Launch kernel
  int64_t block_num = batch_size * seq_len * num_heads;
  dim3 grid((block_num + 7) / 8);
  dim3 block(32, 8);
  int64_t shm_size = block.y * pe_head_dim * sizeof(T);

  auto kernel = [&]() {
    SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
      return rope_grad_kernel<T, VecSize, NopeSize, PeSize>;
    });
  }();

  kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
      cos.data<T>(),
      sin.data<T>(),
      out_grad.data<T>(),
      x_grad->data<T>(),
      static_cast<uint32_t>(seq_len),
      static_cast<uint32_t>(num_heads),
      static_cast<uint32_t>(nope_head_dim),
      static_cast<uint32_t>(pe_head_dim),
      static_cast<uint32_t>(block_num));
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fused_partial_rope_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedPartialRoPEGradKernel,
                   phi::dtype::bfloat16) {}
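A quick sign check on the pe-part arithmetic above (not part of the diff): write the de-interleaved pe slice as $u \in \mathbb{R}^{d}$ with $d = $ pe_head_dim, and let $c, s$ be the cos/sin rows for the current sequence position. The forward kernel computes, for $0 \le i < d/2$,

$$ o_i = u_i c_i - u_{i+d/2}\, s_i, \qquad o_{i+d/2} = u_{i+d/2}\, c_{i+d/2} + u_i\, s_{i+d/2}, $$

so with upstream gradient $g = \partial L / \partial o$ the chain rule gives

$$ \frac{\partial L}{\partial u_i} = g_i c_i + g_{i+d/2}\, s_{i+d/2}, \qquad \frac{\partial L}{\partial u_{i+d/2}} = g_{i+d/2}\, c_{i+d/2} - g_i\, s_i, $$

which matches the two branches of the pe_buffer update; the results are then written back in the interleaved layout (indices $2i$ and $2i+1$) so that x_grad has the same memory layout as x.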
Lines changed: 138 additions & 0 deletions
@@ -0,0 +1,138 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/fusion/gpu/fused_partial_rope_utils.h"

namespace phi {
namespace fusion {

using FastDivMod = phi::funcs::FastDivMod<uint32_t>;

template <typename T, int VecSize, int NopeSize, int PeSize>
__global__ void rope_kernel(const T* __restrict__ x,
                            const T* __restrict__ cos,
                            const T* __restrict__ sin,
                            T* __restrict__ out,
                            FastDivMod seq_len,
                            FastDivMod num_heads,
                            uint32_t nope_head_dim,
                            uint32_t pe_head_dim,
                            uint32_t block_num) {
  using VT = phi::kps::details::VectorType<T, VecSize>;
  extern __shared__ T shm[];

  const uint32_t block_idx = blockIdx.x * 8 + threadIdx.y;
  if (block_idx >= block_num) return;
  const uint32_t seq_idx = seq_len.Divmod(num_heads.Div(block_idx))[1];
  const size_t block_offset =
      static_cast<size_t>(block_idx) * (nope_head_dim + pe_head_dim);
  T* const pe_buffer = shm + threadIdx.y * pe_head_dim;

  // copy nope part
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, nope_head_dim, 32 * VecSize, NopeSize) {
    size_t idx = block_offset + i;
    *reinterpret_cast<VT*>(out + idx) = *reinterpret_cast<const VT*>(x + idx);
  }

  // load pe part and transpose in shared memory
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
    VT tmp = *reinterpret_cast<const VT*>(x + block_offset + nope_head_dim + i);
    for (uint32_t j = 0; j < VecSize; j++) {
      uint32_t pe_idx = i + j;
      if (pe_idx % 2 == 0) {
        pe_buffer[pe_idx / 2] = tmp.val[j];
      } else {
        pe_buffer[pe_idx / 2 + pe_head_dim / 2] = tmp.val[j];
      }
    }
  }
#ifdef PADDLE_WITH_HIP
  __syncthreads();
#else
  __syncwarp();
#endif

  // apply embedding and store
  LOOP_WITH_SIZE_HINT(
      i, threadIdx.x * VecSize, pe_head_dim, 32 * VecSize, PeSize) {
    VT cos_v = *reinterpret_cast<const VT*>(cos + seq_idx * pe_head_dim + i);
    VT sin_v = *reinterpret_cast<const VT*>(sin + seq_idx * pe_head_dim + i);
    VT tmp;
    for (uint32_t j = 0; j < VecSize; j++) {
      uint32_t pe_idx = i + j;
      T x_pe = pe_buffer[pe_idx];
      T x_pe_rot = (pe_idx < pe_head_dim / 2)
                       ? -pe_buffer[pe_idx + pe_head_dim / 2]
                       : pe_buffer[pe_idx - pe_head_dim / 2];
      tmp.val[j] = (x_pe * cos_v.val[j]) + (x_pe_rot * sin_v.val[j]);
    }
    *reinterpret_cast<VT*>(out + block_offset + nope_head_dim + i) = tmp;
  }
}

template <typename T, typename Context>
void FusedPartialRoPEKernel(const Context& dev_ctx,
                            const DenseTensor& x,
                            const DenseTensor& cos,
                            const DenseTensor& sin,
                            DenseTensor* out) {
  const auto x_dims = x.dims();
  const int64_t batch_size = x_dims[0];
  const int64_t seq_len = x_dims[1];
  const int64_t num_heads = x_dims[2];
  const int64_t head_dim = x_dims[3];
  const int64_t pe_head_dim = cos.dims()[3];
  const int64_t nope_head_dim = head_dim - pe_head_dim;

  // Allocate out
  dev_ctx.template Alloc<T>(out);

  if (batch_size == 0 || seq_len == 0 || num_heads == 0 || head_dim == 0) {
    return;
  }

  // Launch kernel
  int64_t block_num = batch_size * seq_len * num_heads;
  dim3 grid((block_num + 7) / 8);
  dim3 block(32, 8);
  int64_t shm_size = block.y * pe_head_dim * sizeof(T);

  auto kernel = [&]() {
    SWITCH_ROPE_KERNEL(nope_head_dim, pe_head_dim, {
      return rope_kernel<T, VecSize, NopeSize, PeSize>;
    });
  }();

  kernel<<<grid, block, shm_size, dev_ctx.stream()>>>(
      x.data<T>(),
      cos.data<T>(),
      sin.data<T>(),
      out->data<T>(),
      static_cast<uint32_t>(seq_len),
      static_cast<uint32_t>(num_heads),
      static_cast<uint32_t>(nope_head_dim),
      static_cast<uint32_t>(pe_head_dim),
      static_cast<uint32_t>(block_num));
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fused_partial_rope,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusedPartialRoPEKernel,
                   phi::dtype::bfloat16) {}
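For orientation (not part of the diff), the effect of rope_kernel on one [head_dim] slice can be mirrored on the CPU as below. This is an illustrative sketch with a hypothetical helper name, using float instead of bfloat16 and ignoring the vectorized/shared-memory layout of the real kernel:

#include <cstddef>
#include <vector>

// Reference for one [head_dim] slice of one (batch, seq, head) position:
// the leading nope_head_dim channels pass through unchanged; the trailing
// pe_head_dim channels are de-interleaved and rotated with the per-position
// cos/sin rows (rotate-half convention).
void partial_rope_reference(const std::vector<float>& x,    // [head_dim]
                            const std::vector<float>& cos,  // [pe_head_dim]
                            const std::vector<float>& sin,  // [pe_head_dim]
                            std::vector<float>* out,        // [head_dim]
                            size_t nope_head_dim,
                            size_t pe_head_dim) {
  const size_t half = pe_head_dim / 2;

  // nope part: plain copy
  for (size_t i = 0; i < nope_head_dim; ++i) (*out)[i] = x[i];

  // de-interleave the pe part: even lanes -> first half, odd lanes -> second half
  std::vector<float> pe(pe_head_dim);
  for (size_t i = 0; i < pe_head_dim; ++i) {
    if (i % 2 == 0) {
      pe[i / 2] = x[nope_head_dim + i];
    } else {
      pe[i / 2 + half] = x[nope_head_dim + i];
    }
  }

  // rotate-half RoPE on the de-interleaved vector
  for (size_t i = 0; i < pe_head_dim; ++i) {
    const float rot = (i < half) ? -pe[i + half] : pe[i - half];
    (*out)[nope_head_dim + i] = pe[i] * cos[i] + rot * sin[i];
  }
}

The de-interleave step implies the pe channels of x are expected in pairwise-interleaved order, while cos and sin are indexed directly, i.e. already stored in split first-half/second-half order.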
