
Commit 7696ae0

[Cherry-pick] add condition of skipif (#49407)
* resolve conflict
* fix format error
1 parent 2a438b0 commit 7696ae0

File tree

1 file changed (+126, -62 lines)


python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py

Lines changed: 126 additions & 62 deletions
@@ -30,10 +30,10 @@
 from paddle.fluid import core
 
 
-@unittest.skipIf(not core.is_compiled_with_cuda(),
-                 "Paddle is not compiled with CUDA")
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "Paddle is not compiled with CUDA"
+)
 class TestFusedGateAttentionOp(OpTest):
-
     def setUp(self):
         self.__class__.op_type = "fused_gate_attention"
         # use autograd to check grad in this unittest.
@@ -57,7 +57,6 @@ def config(self):
         self.bias_attr = True
 
     def generate_input_data(self):
-
         def _random(shape):
             if self.dtype == "bfloat16":
                 data = np.random.random(shape).astype("float32")
@@ -67,7 +66,8 @@ def _random(shape):
 
         np.random.seed(123)
         self.query = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
         self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
         self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
         self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
@@ -80,15 +80,18 @@ def _random(shape):
             self.qkv_weight = np.stack([q_weight_t, k_weight_t, v_weight_t])
         else:
             self.key = _random(
-                (self.batch_size, self.msa_len, self.m_size, self.kv_dim))
+                (self.batch_size, self.msa_len, self.m_size, self.kv_dim)
+            )
             self.qkv_weight = None
 
         self.attn_mask = _random(
-            (self.batch_size, self.msa_len, 1, 1, self.m_size))
+            (self.batch_size, self.msa_len, 1, 1, self.m_size)
+        )
 
         if self.bias_attr:
             self.nonbatched_bias = _random(
-                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
+                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size)
+            )
 
         if self.has_gating:
             self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
@@ -98,27 +101,35 @@ def _random(shape):
         self.output_b = _random((self.out_dim))
 
         self.dout = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
 
     def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
         outputs = [
-            softmax_out, fmha_out, gate_out if self.has_gating else None, out,
-            query.grad, None if self.merge_qkv else key.grad
+            softmax_out,
+            fmha_out,
+            gate_out if self.has_gating else None,
+            out,
+            query.grad,
+            None if self.merge_qkv else key.grad,
         ]
         return outputs
 
     def get_reference_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
 
         query = paddle.to_tensor(self.query, stop_gradient=False)
-        key = query if self.merge_qkv else paddle.to_tensor(self.key,
-                                                            stop_gradient=False)
+        key = (
+            query
+            if self.merge_qkv
+            else paddle.to_tensor(self.key, stop_gradient=False)
+        )
         q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
         k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
         v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
-        c = self.head_dim**(-0.5)
+        c = self.head_dim ** (-0.5)
         # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
         # -> [batch_size, msa_len, res_len, num_heads, head_dim]
         q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
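Note on the reference path: the einsum string 'nbqa,ahc->nbqhc' projects the query into per-head subspaces, and the factor c applies the usual 1/sqrt(head_dim) attention scaling. Below is a tiny NumPy sketch of the same shape transformation, illustrative only, with made-up sizes and NumPy standing in for Paddle:

```python
import numpy as np

batch_size, msa_len, res_len, q_dim = 1, 2, 3, 4
num_heads, head_dim = 2, 8

query = np.random.rand(batch_size, msa_len, res_len, q_dim)
q_weight = np.random.rand(q_dim, num_heads, head_dim)

c = head_dim ** (-0.5)
# [batch_size, msa_len, res_len, q_dim] x [q_dim, num_heads, head_dim]
# -> [batch_size, msa_len, res_len, num_heads, head_dim], scaled by 1/sqrt(head_dim)
q = np.einsum("nbqa,ahc->nbqhc", query, q_weight) * c
assert q.shape == (batch_size, msa_len, res_len, num_heads, head_dim)
```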
@@ -136,8 +147,9 @@ def get_reference_out(self):
         # -> [batch_size, msa_len, num_heads, res_len, m_size]
         logits = logits + src_mask
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
             # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
             # -> [batch_size, msa_len, num_heads, res_len, m_size]
             logits = logits + nonbatched_bias
@@ -159,14 +171,22 @@ def get_reference_out(self):
             # gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
             #                             gating_w) + gating_b
             gating_w_2d = paddle.reshape(
-                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
+                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim]
+            )
             gate_values_4d = paddle.matmul(query, gating_w_2d)
-            gate_values = paddle.reshape(
-                gate_values_4d,
-                shape=[
-                    self.batch_size, self.msa_len, self.res_len, self.num_heads,
-                    self.head_dim
-                ]) + gating_b
+            gate_values = (
+                paddle.reshape(
+                    gate_values_4d,
+                    shape=[
+                        self.batch_size,
+                        self.msa_len,
+                        self.res_len,
+                        self.num_heads,
+                        self.head_dim,
+                    ],
+                )
+                + gating_b
+            )
             gate_values = nn.functional.sigmoid(gate_values)
             gate_out = fmha_out * gate_values
         else:
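The einsum kept in the comment above ('nbqc,chv->nbqhv') documents what the reshape-plus-matmul path computes; the two are numerically equivalent. A small check of that equivalence, illustrative only, with made-up sizes and NumPy standing in for Paddle:

```python
import numpy as np

batch_size, msa_len, res_len, q_dim = 1, 2, 3, 4
num_heads, head_dim = 2, 8

query = np.random.rand(batch_size, msa_len, res_len, q_dim)
gating_w = np.random.rand(q_dim, num_heads, head_dim)
gating_b = np.random.rand(num_heads, head_dim)

# Reference path: the einsum kept in the comment.
ref = np.einsum("nbqc,chv->nbqhv", query, gating_w) + gating_b

# Test path: flatten the weight, matmul, then restore the head dimensions.
gating_w_2d = gating_w.reshape(q_dim, num_heads * head_dim)
alt = (query @ gating_w_2d).reshape(
    batch_size, msa_len, res_len, num_heads, head_dim
) + gating_b

np.testing.assert_allclose(ref, alt, rtol=1e-12)
```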
@@ -183,20 +203,32 @@ def get_reference_out(self):
             gate_out,
             shape=[
                 self.batch_size * self.msa_len * self.res_len,
-                self.num_heads * self.head_dim
-            ])
+                self.num_heads * self.head_dim,
+            ],
+        )
         output_w_2d = paddle.reshape(
-            output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
+            output_w, shape=[self.num_heads * self.head_dim, self.out_dim]
+        )
         out_2d = paddle.matmul(gate_out_2d, output_w_2d)
-        out = paddle.reshape(
-            out_2d,
-            shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
-                   ]) + output_b
-
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        out = (
+            paddle.reshape(
+                out_2d,
+                shape=[
+                    self.batch_size,
+                    self.msa_len,
+                    self.res_len,
+                    self.out_dim,
+                ],
+            )
+            + output_b
+        )
+
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def get_fused_gate_attention_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -218,8 +250,9 @@ def get_fused_gate_attention_out(self):
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
         else:
             nonbatched_bias = None
         if self.has_gating:
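In both the reference and fused paths the nonbatched bias is created with stop_gradient=False so its gradient path is exercised, while the attention mask stays non-trainable. A minimal, hedged illustration of what that flag does, assuming a recent Paddle build running in its default dynamic-graph mode:

```python
import numpy as np
import paddle

x = paddle.to_tensor(np.ones([2, 2], dtype="float32"), stop_gradient=False)
m = paddle.to_tensor(np.zeros([2, 2], dtype="float32"), stop_gradient=True)

y = (x + m).sum()
y.backward()

print(x.grad)  # gradients flow into x
print(m.grad)  # None: the mask is excluded from autograd
```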
@@ -232,18 +265,42 @@ def get_fused_gate_attention_out(self):
         output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
         output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
 
-        _, _, _, _, softmax_out, fmha_out, gate_out, out = _legacy_C_ops.fused_gate_attention(
-            query, key, q_weight, k_weight, v_weight, qkv_weight,
-            nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
-            'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
-
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        (
+            _,
+            _,
+            _,
+            _,
+            softmax_out,
+            fmha_out,
+            gate_out,
+            out,
+        ) = _legacy_C_ops.fused_gate_attention(
+            query,
+            key,
+            q_weight,
+            k_weight,
+            v_weight,
+            qkv_weight,
+            nonbatched_bias,
+            src_mask,
+            gating_w,
+            gating_b,
+            output_w,
+            output_b,
+            'has_gating',
+            self.has_gating,
+            'merge_qkv',
+            self.merge_qkv,
+        )
+
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def check(self, ref, out, atol, rtol, check_equal, name):
-
         def _convert(value):
             if self.dtype == "bfloat16":
                 return convert_uint16_to_float(value)
@@ -252,19 +309,25 @@ def _convert(value):
         if check_equal:
             self.assertTrue(
                 np.equal(_convert(ref), _convert(out)).all(),
-                "Checking < {} > failed!".format(name))
+                "Checking < {} > failed!".format(name),
+            )
         else:
             np.testing.assert_allclose(
                 _convert(ref),
                 _convert(out),
                 atol=atol,
                 rtol=rtol,
-                err_msg="Checking < {} > failed!".format(name))
+                err_msg="Checking < {} > failed!".format(name),
+            )
 
     def check_output_and_grad(self, atol, rtol):
         output_names = [
-            "softmax_out", "fmha_out", "gate_out", "out", "query_grad",
-            "key_grad"
+            "softmax_out",
+            "fmha_out",
+            "gate_out",
+            "out",
+            "query_grad",
+            "key_grad",
         ]
         outputs_ref = self.get_reference_out()
         outputs_fused = self.get_fused_gate_attention_out()
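The _convert helper above relies on convert_uint16_to_float from the OpTest utilities because bfloat16 outputs are materialized as uint16 bit patterns. The sketch below only illustrates the idea behind such a conversion (bfloat16 keeps the high 16 bits of a float32); it is not the helper's actual source:

```python
import numpy as np


def uint16_to_float(x):
    # bfloat16 stores the high 16 bits of a float32, so appending 16 zero bits
    # on the right recovers a float32 with the same value (up to bf16 precision).
    x = np.asarray(x, dtype=np.uint16)
    return (x.astype(np.uint32) << 16).view(np.float32)


vals = np.array([1.5, -2.25], dtype=np.float32)
bf16_bits = (vals.view(np.uint32) >> 16).astype(np.uint16)  # truncate to bf16 bits
print(uint16_to_float(bf16_bits))  # -> [ 1.5  -2.25]
```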
@@ -280,22 +343,26 @@ def check_output_and_grad(self, atol, rtol):
             # that in fused ops, check_equal is set to False and we use allclose
             # to check the correctness.
             check_equal = False
-            self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
-                       check_equal, output_names[i])
+            self.check(
+                ref_res.numpy(),
+                fused_res.numpy(),
+                atol,
+                rtol,
+                check_equal,
+                output_names[i],
+            )
 
     def test_output_and_grad(self):
         self.check_output_and_grad(atol=1e-5, rtol=1e-6)
 
 
 class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 class TestSeparatedQKVCase(TestFusedGateAttentionOp):
-
     def config(self):
         self.dtype = "float32"
         self.has_gating = False
@@ -312,15 +379,13 @@ def config(self):
 
 
 class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.has_gating = False
         self.bias_attr = False
 
 
 class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "float16"
@@ -332,18 +397,18 @@ def test_output_and_grad(self):
 
 
 class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11000
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
 )
 class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "bfloat16"
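This hunk carries the substantive change of the commit: the BF16 case is now also skipped when the GPU's compute capability is below 8.0, since bfloat16 kernels require Ampere-or-newer hardware. Below is a minimal, standalone sketch of the same guard, assuming paddle and paddle.fluid.core are importable as in the test file; the CUDA-version check is omitted because get_cuda_version is a helper local to the test, and the test class is a placeholder:

```python
import unittest

import paddle
from paddle.fluid import core


def _bf16_supported():
    # Mirrors the short-circuit order of the skipIf condition in the diff:
    # first require a CUDA build, then an Ampere-or-newer GPU (major >= 8).
    if not core.is_compiled_with_cuda():
        return False
    return paddle.device.cuda.get_device_capability()[0] >= 8


@unittest.skipIf(
    not _bf16_supported(),
    "bfloat16 needs a CUDA build and compute capability >= 8.0",
)
class PlaceholderBF16Case(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```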
@@ -353,7 +418,6 @@ def test_output_and_grad(self):
 
 
 class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2
