
Commit 7fd6ffb

add num_splits to support deterministic mode for flash_attn_bwd and FlashAttnUnpaddedGradKernel (#56363)

* add num_splits for flash_attn_bwd and FlashAttnUnpaddedGradKernel
* Add assertTrue
* Update submodule to a specific commit
1 parent eddf6d0 commit 7fd6ffb

File tree: 3 files changed, +219 -6 lines


paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu

Lines changed: 10 additions & 5 deletions
@@ -28,6 +28,11 @@ PD_DECLARE_bool(cudnn_deterministic);

namespace phi {

+int get_num_split() {
+  // 0 for an internal heuristic, which is optimal
+  return FLAGS_cudnn_deterministic ? 1 : 0;
+}
+
template <typename T, typename Context>
void FlashAttnUnpaddedGradImpl(const Context& ctx,
                               const DenseTensor& q,
@@ -236,11 +241,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
  const int64_t total_k = k.dims()[0];
  const int64_t num_heads_k = k.dims()[1];

-  // TODO(umiswing): add deterministic in fa2.
-  // int num_splits = 0; // 0 for an internal heuristic, which is optimal
-  // if (FLAGS_cudnn_deterministic) {
-  //   num_splits = 1;
-  // }
+  int num_splits = get_num_split();

  // TODO(umiswing): add shape check
  PADDLE_ENFORCE_EQ(
@@ -294,6 +295,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
      params.scale,
      params.causal,
      params.is_bf16,
+      num_splits,
      stream,
      params.seed,
      params.offset);
@@ -401,6 +403,8 @@ void FlashAttnGradKernel(const Context& ctx,
  VLOG(10) << "FlashAttn bwd seed: " << params.seed
           << ", offset: " << params.offset;

+  int num_splits = get_num_split();
+
  bool succ = phi::dynload::flash_attn_bwd(dout.data(),
                                           q.data(),
                                           k.data(),
@@ -426,6 +430,7 @@ void FlashAttnGradKernel(const Context& ctx,
      params.scale,
      params.causal,
      params.is_bf16,
+      num_splits,
      stream,
      params.seed,
      params.offset);
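
The user-visible effect of this change, exercised by the new test below: with FLAGS_cudnn_deterministic enabled, the backward kernels pass num_splits = 1 to the FlashAttention library, so repeated runs on the same inputs produce identical outputs and gradients. A minimal usage sketch (the shapes, the run_once helper, and the exact-equality checks are illustrative, not part of this commit):

import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_flags({'FLAGS_cudnn_deterministic': 1})


def run_once(qkv):
    # qkv: three numpy arrays of shape (batch, seqlen, num_heads, head_dim)
    q, k, v = [
        paddle.to_tensor(
            x, dtype='float16', place=paddle.CUDAPlace(0), stop_gradient=False
        )
        for x in qkv
    ]
    out, _ = flash_attention(q, k, v, 0.0, False, False)
    out.backward()
    return out.numpy(), q.grad.numpy()


qkv = [np.random.random((2, 128, 8, 16)) for _ in range(3)]
out1, dq1 = run_once(qkv)
out2, dq2 = run_once(qkv)

# With the flag set, both runs are expected to match bit for bit.
assert np.array_equal(out1, out2)
assert np.array_equal(dq1, dq2)

paddle.set_flags({'FLAGS_cudnn_deterministic': 0})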
Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle.device import core
from paddle.nn.functional.flash_attention import (
    flash_attention,
    scaled_dot_product_attention,
)


def get_cuda_version():
    result = os.popen("nvcc --version").read()
    regex = r'release (\S+),'
    match = re.search(regex, result)
    if match:
        num = str(match.group(1))
        integer, decimal = num.split('.')
        return int(integer) * 1000 + int(float(decimal) * 10)
    else:
        return -1


def attention_naive(q, k, v, causal=False):
    qt = paddle.transpose(q, [0, 2, 1, 3])
    kt = paddle.transpose(k, [0, 2, 1, 3])
    vt = paddle.transpose(v, [0, 2, 1, 3])
    scale = 1.0 / np.sqrt(q.shape[-1])
    s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
    s = paddle.scale(s, scale)
    p = (
        paddle.incubate.softmax_mask_fuse_upper_triangle(s)
        if causal
        else F.softmax(s)
    )
    o = paddle.matmul(p, vt)
    return paddle.transpose(o, [0, 2, 1, 3])


is_sm8x = (
    core.is_compiled_with_cuda()
    and paddle.device.cuda.get_device_capability()[0] == 8
    and paddle.device.cuda.get_device_capability()[1] >= 0
)

is_sm90 = (
    core.is_compiled_with_cuda()
    and paddle.device.cuda.get_device_capability()[0] == 9
    and paddle.device.cuda.get_device_capability()[1] == 0
)

is_sm_supported = is_sm8x or is_sm90


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or get_cuda_version() < 11040
    or not is_sm_supported,
    "core is not compiled with CUDA and cuda version need larger than or equal to 11.4"
    "and device's compute capability must be 8.x or 90",
)
class TestFlashAttentionAPIFlag(unittest.TestCase):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 128, 8, 16)
        self.dtype = 'float16'
        self.dropout = 0.0
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = False
        self.use_sdp_api = False

    def flash_attn_compute(self, query, key, value):
        # test dynamic
        paddle.disable_static()

        q = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        q_ = paddle.to_tensor(
            query, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        k_ = paddle.to_tensor(
            key, place=self.place, dtype=self.dtype, stop_gradient=False
        )
        v_ = paddle.to_tensor(
            value, place=self.place, dtype=self.dtype, stop_gradient=False
        )

        if self.use_sdp_kernel:
            with paddle.nn.functional.sdp_kernel(
                enable_math=self.enable_math,
                enable_flash=self.enable_flash,
                enable_mem_efficient=self.enable_mem_efficient,
            ):
                if self.use_sdp_api:
                    out = scaled_dot_product_attention(
                        q, k, v, None, self.dropout, self.causal
                    )
                else:
                    out, _ = flash_attention(
                        q, k, v, self.dropout, self.causal, self.return_softmax
                    )

        else:
            out, _ = flash_attention(
                q, k, v, self.dropout, self.causal, self.return_softmax
            )
        out_ = attention_naive(q_, k_, v_, self.causal)

        out.backward()
        out_.backward()

        self.assertEqual(q.grad.shape, q.shape)
        self.assertEqual(q_.grad.shape, q.shape)

        np.testing.assert_allclose(
            q.grad.numpy(), q_.grad.numpy(), rtol=5e-03, atol=1e-03
        )

        return out, out_, q.grad.numpy(), k.grad.numpy(), v.grad.numpy()

    def test_all_flag(self):
        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        query = np.random.random(self.shape)
        key = np.random.random(self.shape)
        value = np.random.random(self.shape)

        out1, out1_, q_grad1, k_grad1, v_grad1 = self.flash_attn_compute(
            query, key, value
        )

        np.testing.assert_allclose(out1.numpy(), out1_, rtol=5e-03, atol=1e-03)

        out2, out2_, q_grad2, k_grad2, v_grad2 = self.flash_attn_compute(
            query, key, value
        )
        self.assertTrue(np.equal(out1.numpy(), out2.numpy()).all())
        self.assertTrue(np.equal(q_grad1, q_grad2).all())
        self.assertTrue(np.equal(k_grad1, k_grad2).all())
        self.assertTrue(np.equal(v_grad1, v_grad2).all())
        paddle.set_flags({'FLAGS_cudnn_deterministic': 0})


class TestFlashAttentionAPIFlagTest1(TestFlashAttentionAPIFlag):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (2, 128, 8, 16)
        self.dtype = paddle.float16
        self.dropout = 0.0
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = False


class TestFlashAttentionAPIFlagTest2(TestFlashAttentionAPIFlag):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (8, 1024, 16, 256)
        self.dtype = paddle.float16
        self.dropout = 0.0
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = False


class TestSDPAttentionAPIFlagTest(TestFlashAttentionAPIFlag):
    def setUp(self):
        self.place = paddle.CUDAPlace(0)
        self.shape = (8, 1024, 16, 128)
        self.dtype = paddle.float16
        self.dropout = 0.0
        self.causal = False
        self.return_softmax = False
        self.use_sdp_kernel = True
        self.use_sdp_api = True
        self.enable_math = True
        self.enable_flash = False
        self.enable_mem_efficient = False


if __name__ == '__main__':
    unittest.main()
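
A side note on the design choice (not part of the commit): per the comment in get_num_split(), num_splits = 0 asks the FlashAttention library to pick the split count through an internal heuristic, described as optimal, while num_splits = 1 forces a single split; enabling FLAGS_cudnn_deterministic therefore trades some backward-pass speed for reproducible gradients. A tiny Python mirror of that mapping, for illustration only:

def get_num_split(cudnn_deterministic: bool) -> int:
    # 0 -> let the library choose the split count heuristically (fastest)
    # 1 -> force a single split so the backward pass is deterministic
    return 1 if cudnn_deterministic else 0


assert get_num_split(True) == 1
assert get_num_split(False) == 0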
