 from paddle.fluid import core
 
 
-@unittest.skipIf(not core.is_compiled_with_cuda(),
-                 "Paddle is not compiled with CUDA")
+@unittest.skipIf(
+    not core.is_compiled_with_cuda(), "Paddle is not compiled with CUDA"
+)
 class TestFusedGateAttentionOp(OpTest):
-
     def setUp(self):
         self.__class__.op_type = "fused_gate_attention"
         # use autograd to check grad in this unittest.
@@ -57,7 +57,6 @@ def config(self):
         self.bias_attr = True
 
     def generate_input_data(self):
-
         def _random(shape):
             if self.dtype == "bfloat16":
                 data = np.random.random(shape).astype("float32")
@@ -67,7 +66,8 @@ def _random(shape):
 
         np.random.seed(123)
         self.query = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
         self.q_weight = _random((self.q_dim, self.num_heads, self.head_dim))
         self.k_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
         self.v_weight = _random((self.kv_dim, self.num_heads, self.head_dim))
@@ -80,15 +80,18 @@ def _random(shape):
             self.qkv_weight = np.stack([q_weight_t, k_weight_t, v_weight_t])
         else:
             self.key = _random(
-                (self.batch_size, self.msa_len, self.m_size, self.kv_dim))
+                (self.batch_size, self.msa_len, self.m_size, self.kv_dim)
+            )
             self.qkv_weight = None
 
         self.attn_mask = _random(
-            (self.batch_size, self.msa_len, 1, 1, self.m_size))
+            (self.batch_size, self.msa_len, 1, 1, self.m_size)
+        )
 
         if self.bias_attr:
             self.nonbatched_bias = _random(
-                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size))
+                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size)
+            )
 
         if self.has_gating:
             self.gating_w = _random((self.q_dim, self.num_heads, self.head_dim))
@@ -98,27 +101,35 @@ def _random(shape):
         self.output_b = _random((self.out_dim))
 
         self.dout = _random(
-            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
+            (self.batch_size, self.msa_len, self.res_len, self.q_dim)
+        )
 
     def collect_outputs(self, query, key, softmax_out, fmha_out, gate_out, out):
         outputs = [
-            softmax_out, fmha_out, gate_out if self.has_gating else None, out,
-            query.grad, None if self.merge_qkv else key.grad
+            softmax_out,
+            fmha_out,
+            gate_out if self.has_gating else None,
+            out,
+            query.grad,
+            None if self.merge_qkv else key.grad,
         ]
         return outputs
 
     def get_reference_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
 
         query = paddle.to_tensor(self.query, stop_gradient=False)
-        key = query if self.merge_qkv else paddle.to_tensor(self.key,
-                                                            stop_gradient=False)
+        key = (
+            query
+            if self.merge_qkv
+            else paddle.to_tensor(self.key, stop_gradient=False)
+        )
         q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
         k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
         v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
-        c = self.head_dim**(-0.5)
+        c = self.head_dim ** (-0.5)
         # [batch_size, msa_len, res_len, q_dim], [q_dim, num_heads, head_dim]
         # -> [batch_size, msa_len, res_len, num_heads, head_dim]
         q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
@@ -136,8 +147,9 @@ def get_reference_out(self):
         # -> [batch_size, msa_len, num_heads, res_len, m_size]
         logits = logits + src_mask
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
             # [batch_size, msa_len, num_heads, res_len, m_size], [batch_size, 1, num_heads, res_len, m_size]
             # -> [batch_size, msa_len, num_heads, res_len, m_size]
             logits = logits + nonbatched_bias
@@ -159,14 +171,22 @@ def get_reference_out(self):
             # gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
             #                             gating_w) + gating_b
             gating_w_2d = paddle.reshape(
-                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim])
+                gating_w, shape=[self.q_dim, self.num_heads * self.head_dim]
+            )
             gate_values_4d = paddle.matmul(query, gating_w_2d)
-            gate_values = paddle.reshape(
-                gate_values_4d,
-                shape=[
-                    self.batch_size, self.msa_len, self.res_len, self.num_heads,
-                    self.head_dim
-                ]) + gating_b
+            gate_values = (
+                paddle.reshape(
+                    gate_values_4d,
+                    shape=[
+                        self.batch_size,
+                        self.msa_len,
+                        self.res_len,
+                        self.num_heads,
+                        self.head_dim,
+                    ],
+                )
+                + gating_b
+            )
             gate_values = nn.functional.sigmoid(gate_values)
             gate_out = fmha_out * gate_values
         else:
@@ -183,20 +203,32 @@ def get_reference_out(self):
             gate_out,
             shape=[
                 self.batch_size * self.msa_len * self.res_len,
-                self.num_heads * self.head_dim
-            ])
+                self.num_heads * self.head_dim,
+            ],
+        )
         output_w_2d = paddle.reshape(
-            output_w, shape=[self.num_heads * self.head_dim, self.out_dim])
+            output_w, shape=[self.num_heads * self.head_dim, self.out_dim]
+        )
         out_2d = paddle.matmul(gate_out_2d, output_w_2d)
-        out = paddle.reshape(
-            out_2d,
-            shape=[self.batch_size, self.msa_len, self.res_len, self.out_dim
-                   ]) + output_b
-
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        out = (
+            paddle.reshape(
+                out_2d,
+                shape=[
+                    self.batch_size,
+                    self.msa_len,
+                    self.res_len,
+                    self.out_dim,
+                ],
+            )
+            + output_b
+        )
+
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def get_fused_gate_attention_out(self):
         paddle.disable_static(place=paddle.CUDAPlace(0))
@@ -218,8 +250,9 @@ def get_fused_gate_attention_out(self):
         src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)
 
         if self.bias_attr:
-            nonbatched_bias = paddle.to_tensor(self.nonbatched_bias,
-                                               stop_gradient=False)
+            nonbatched_bias = paddle.to_tensor(
+                self.nonbatched_bias, stop_gradient=False
+            )
         else:
             nonbatched_bias = None
         if self.has_gating:
@@ -232,18 +265,42 @@ def get_fused_gate_attention_out(self):
         output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
         output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
 
-        _, _, _, _, softmax_out, fmha_out, gate_out, out = _legacy_C_ops.fused_gate_attention(
-            query, key, q_weight, k_weight, v_weight, qkv_weight,
-            nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
-            'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)
-
-        paddle.autograd.backward([out], [paddle.to_tensor(self.dout)],
-                                 retain_graph=True)
-        return self.collect_outputs(query, key, softmax_out, fmha_out, gate_out,
-                                    out)
+        (
+            _,
+            _,
+            _,
+            _,
+            softmax_out,
+            fmha_out,
+            gate_out,
+            out,
+        ) = _legacy_C_ops.fused_gate_attention(
+            query,
+            key,
+            q_weight,
+            k_weight,
+            v_weight,
+            qkv_weight,
+            nonbatched_bias,
+            src_mask,
+            gating_w,
+            gating_b,
+            output_w,
+            output_b,
+            'has_gating',
+            self.has_gating,
+            'merge_qkv',
+            self.merge_qkv,
+        )
+
+        paddle.autograd.backward(
+            [out], [paddle.to_tensor(self.dout)], retain_graph=True
+        )
+        return self.collect_outputs(
+            query, key, softmax_out, fmha_out, gate_out, out
+        )
 
     def check(self, ref, out, atol, rtol, check_equal, name):
-
         def _convert(value):
             if self.dtype == "bfloat16":
                 return convert_uint16_to_float(value)
@@ -252,19 +309,25 @@ def _convert(value):
         if check_equal:
             self.assertTrue(
                 np.equal(_convert(ref), _convert(out)).all(),
-                "Checking < {} > failed!".format(name))
+                "Checking < {} > failed!".format(name),
+            )
         else:
             np.testing.assert_allclose(
                 _convert(ref),
                 _convert(out),
                 atol=atol,
                 rtol=rtol,
-                err_msg="Checking < {} > failed!".format(name))
+                err_msg="Checking < {} > failed!".format(name),
+            )
 
     def check_output_and_grad(self, atol, rtol):
         output_names = [
-            "softmax_out", "fmha_out", "gate_out", "out", "query_grad",
-            "key_grad"
+            "softmax_out",
+            "fmha_out",
+            "gate_out",
+            "out",
+            "query_grad",
+            "key_grad",
         ]
         outputs_ref = self.get_reference_out()
         outputs_fused = self.get_fused_gate_attention_out()
@@ -280,22 +343,26 @@ def check_output_and_grad(self, atol, rtol):
                 # that in fused ops, check_equal is set to False and we use allclose
                 # to check the correctness.
                 check_equal = False
-                self.check(ref_res.numpy(), fused_res.numpy(), atol, rtol,
-                           check_equal, output_names[i])
+                self.check(
+                    ref_res.numpy(),
+                    fused_res.numpy(),
+                    atol,
+                    rtol,
+                    check_equal,
+                    output_names[i],
+                )
 
     def test_output_and_grad(self):
         self.check_output_and_grad(atol=1e-5, rtol=1e-6)
 
 
 class TestMergeQKVLargeBatchSizeCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 class TestSeparatedQKVCase(TestFusedGateAttentionOp):
-
     def config(self):
         self.dtype = "float32"
         self.has_gating = False
@@ -312,15 +379,13 @@ def config(self):
 
 
 class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.has_gating = False
         self.bias_attr = False
 
 
 class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "float16"
@@ -332,18 +397,18 @@ def test_output_and_grad(self):
 
 
 class TestMergeQKVLargeBatchSizeFp16Case(TestMergeQKVFp16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2
 
 
 @unittest.skipIf(
-    not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
-    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3"
+    not core.is_compiled_with_cuda()
+    or get_cuda_version() < 11000
+    or paddle.device.cuda.get_device_capability()[0] < 8,
+    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3",
 )
 class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
-
     def config(self):
         super().config()
         self.dtype = "bfloat16"
@@ -353,7 +418,6 @@ def test_output_and_grad(self):
 
 
 class TestMergeQKVLargeBatchSizeBF16Case(TestMergeQKVBF16Case):
-
     def config(self):
         super().config()
         self.batch_size = 2
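
The hunks above only reformat the test; the computation that get_reference_out checks is plain gated attention. For orientation, here is a minimal NumPy sketch of that math, not part of the diff: the function name gate_attention_reference and its argument list are illustrative, only the default gated branch is sketched, and the shapes and einsum equations follow the comments in the test.

import numpy as np


def gate_attention_reference(query, key, q_w, k_w, v_w, attn_mask,
                             nonbatched_bias, gating_w, gating_b,
                             output_w, output_b):
    # Hypothetical helper mirroring the reference path of the test; shapes
    # follow the test's comments, e.g. query is [batch, msa, res, q_dim].
    head_dim = q_w.shape[-1]
    c = head_dim ** (-0.5)

    # Project query/key/value to [batch, msa, seq, heads, head_dim].
    q = np.einsum('nbqa,ahc->nbqhc', query, q_w) * c
    k = np.einsum('nbka,ahc->nbkhc', key, k_w)
    v = np.einsum('nbka,ahc->nbkhc', key, v_w)

    # Attention logits [batch, msa, heads, res, m_size], plus mask and bias.
    logits = np.einsum('nbqhc,nbkhc->nbhqk', q, k) + attn_mask
    if nonbatched_bias is not None:
        logits = logits + nonbatched_bias

    # Softmax over the last (m_size) axis.
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    softmax_out = weights / weights.sum(axis=-1, keepdims=True)

    # Weighted sum of values: [batch, msa, res, heads, head_dim].
    fmha_out = np.einsum('nbhqk,nbkhc->nbqhc', softmax_out, v)

    # Sigmoid gate computed from the query, then the output projection.
    gate_values = np.einsum('nbqa,ahc->nbqhc', query, gating_w) + gating_b
    gate_out = fmha_out * (1.0 / (1.0 + np.exp(-gate_values)))
    out = np.einsum('nbqhc,hco->nbqo', gate_out, output_w) + output_b
    return softmax_out, fmha_out, gate_out, out

With merge_qkv enabled (the default config), key is the same tensor as query, which is why get_reference_out passes query for both; the fused _legacy_C_ops.fused_gate_attention call is then compared against these outputs with the atol/rtol used in check_output_and_grad.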