
Commit b9a1d95

[cherry-pick] Fix softmax cuda bug (#21720) (#22160)
* Fix softmax cuda bug
* Refine multihead log and softmax logic
* Align block to 32
Parent: 835201b
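
In short (a reading of the diff below, not wording from the original commit message): the old softmax_kernel_with_eltadd computed its bias offset as blockIdx.x % (head_num * seq_len) * seq_len, which wraps around the batch dimension, so every batch reread the first batch's BiasQK slice. The fix gives BiasQK the same [batch_size, head_number, seq_len, seq_len] layout as the QK buffer and indexes both with one offset. The contrast, excerpted (and lightly renamed) from the kernel hunk below:

    // Before (buggy): the offset wraps modulo head_num * seq_len,
    // so all batches read the same BiasQK slice.
    int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
    float qk_before = qk_buf_[threadIdx.x + qk_offset] +
                      bias_qk_[threadIdx.x + bias_offset];

    // After (fixed): BiasQK is batch-sized and laid out like qk_buf_,
    // so the QK offset indexes both buffers correctly.
    float qk_after = qk_buf_[threadIdx.x + qk_offset] +
                     bias_qk_[threadIdx.x + qk_offset];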

3 files changed: 54 additions & 16 deletions

paddle/fluid/operators/multihead_matmul_op.cc

Lines changed: 29 additions & 5 deletions
@@ -84,15 +84,39 @@ class MultiHeadMatMulOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
                       "Multihead input bias should have same batch size");
 
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_k[1],
-                      "Multihead input bias should have same size");
-    PADDLE_ENFORCE_EQ(dim_bias_q[1], dim_bias_v[1],
-                      "Multihead input bias should have same size");
-
     auto dim_bias_qk = context->GetInputDim("BiasQK");
     PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
                       "Multihead input bias qk should be at least 4-D tensor.");
 
+    int b_indx = dim_bias_q.size() - 1;
+    int indx = dim_q.size() - 1;
+
+    PADDLE_ENFORCE_EQ(
+        dim_bias_q[b_indx], dim_q[indx],
+        platform::errors::InvalidArgument(
+            "bias_q's last dim size should equal to"
+            " q last dim size, but received bias_q's size is:%d q is:%d",
+            dim_bias_q[b_indx], dim_q[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_k[b_indx], dim_k[indx],
+        platform::errors::InvalidArgument(
+            "bias_k's last dim size should equal to"
+            " k last dim size, but received bias_k's size is:%d k is:%d",
+            dim_bias_k[b_indx], dim_k[indx]));
+    PADDLE_ENFORCE_EQ(
+        dim_bias_v[b_indx], dim_v[indx],
+        platform::errors::InvalidArgument(
+            "bias_v's last dim size should equal to"
+            " v last dim size, but received bias_v's size is:%d v is:%d",
+            dim_bias_v[b_indx], dim_v[indx]));
+
+    PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0],
+                      platform::errors::InvalidArgument(
+                          "q should have same batch size"
+                          "with bias_qk, but received q's batch size is:%d "
+                          "bias_qk's batch size is:%d",
+                          dim_q[0], dim_bias_qk[0]));
+
     int head_number = context->Attrs().Get<int>("head_number");
     PADDLE_ENFORCE_GT(head_number, 1,
                       "Multihead input head number should be at least 1.");

paddle/fluid/operators/multihead_matmul_op.cu

Lines changed: 23 additions & 10 deletions
@@ -196,15 +196,14 @@ __global__ void softmax_kernel_with_eltadd(T *qk_buf_, const T *bias_qk_,
                                            const int head_num,
                                            const int seq_len,
                                            const unsigned mask) {
-  int seq_id = blockIdx.x % seq_len;
   int qk_offset = blockIdx.x * seq_len;
-  int bias_offset = blockIdx.x % (head_num * seq_len) * seq_len;
+  assert(blockDim.x % 32 == 0);
 
   __shared__ float s_sum, s_max;
 
   float qk = threadIdx.x < seq_len
                  ? static_cast<float>((qk_buf_[threadIdx.x + qk_offset] +
-                                       bias_qk_[threadIdx.x + bias_offset]))
+                                       bias_qk_[threadIdx.x + qk_offset]))
                  : 0.0f;
   float tmp = threadIdx.x < seq_len ? static_cast<float>(qk) : -1e20f;
 
@@ -259,15 +258,29 @@ void MatMulWithHeadQK(const platform::CUDADeviceContext &context, int head_num,
       q_buf_, k_buf_, beta, qk_buf_, batch_size * head_num,
       seq_len * size_per_head, seq_len * size_per_head);
 
-  int m = batch_size * head_num * seq_len;
-  int k = seq_len;
-
-  int grid = m;
-  int block = k;
+  int grid = batch_size * head_num * seq_len;
+  int block = seq_len;
+
+  // Align block to 32, also limit seq_len to max block size.
+  PADDLE_ENFORCE_LE(seq_len, 1024, platform::errors::InvalidArgument(
+                                       "seq_len should <= 1024, "
+                                       "but received seq_len is:%d",
+                                       seq_len));
+  if (seq_len <= 32)
+    block = 32;
+  else if (seq_len > 32 && seq_len <= 64)
+    block = 64;
+  else if (seq_len > 64 && seq_len <= 128)
+    block = 128;
+  else if (seq_len > 128 && seq_len <= 256)
+    block = 256;
+  else if (seq_len > 256 && seq_len <= 512)
+    block = 512;
+  else
+    block = 1024;
 
-  unsigned mask = block < 32 ? (((unsigned)1 << block) - 1) : FINAL_MASK;
   softmax_kernel_with_eltadd<T><<<grid, block, 0, stream>>>(
-      qk_buf_, bias_qk, batch_size, head_num, seq_len, mask);
+      qk_buf_, bias_qk, batch_size, head_num, seq_len, FINAL_MASK);
 }
 
 template <typename T>
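
The if/else ladder above rounds the launch block up to the next value in {32, 64, 128, 256, 512, 1024}. A compact equivalent, written as a hypothetical helper (RoundBlockSize is not in the patch; it just restates the rule):

    #include <cassert>

    // Hypothetical helper: round seq_len up to the next power of two,
    // floored at 32 (the warp size) and capped at 1024 (CUDA's maximum
    // threads per block). Produces the same values as the ladder above.
    static int RoundBlockSize(int seq_len) {
      assert(seq_len >= 1 && seq_len <= 1024);
      int block = 32;
      while (block < seq_len) block <<= 1;  // 32, 64, 128, 256, 512, 1024
      return block;
    }

Keeping blockDim.x a multiple of 32 is exactly what the new assert(blockDim.x % 32 == 0) in the kernel checks; and because every launched block now spans whole warps, the mask argument can always be FINAL_MASK, which is why the old partial-mask computation (block < 32 ? ((1u << block) - 1) : FINAL_MASK) was dropped.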

python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ def setUp(self):
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
         self.BiasQK = np.random.random(
-            (1, self.head_number, self.seq_len, self.seq_len)).astype("float32")
+            (self.batch_size, self.head_number, self.seq_len,
+             self.seq_len)).astype("float32")
         # Compute Q path
         fc_q = self.Q + self.BiasQ
         reshape_q = np.reshape(fc_q, (self.batch_size, self.seq_len,
