13
13
// limitations under the License.
14
14
#include " gtest/gtest.h"
15
15
#include " paddle/fluid/operators/math/math_function.h"
16
+ #include " paddle/fluid/platform/device_context.h"
16
17
17
18
void fill_fp16_data (paddle::platform::float16* in_ptr, size_t size,
18
19
const std::vector<float >& data) {
@@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
23
24
}
24
25
25
26
TEST (math_function, notrans_mul_trans_fp32) {
26
- using namespace paddle ::framework;
27
- using namespace paddle ::platform;
27
+ using namespace paddle ::framework; // NOLINT
28
+ using namespace paddle ::platform; // NOLINT
28
29
29
30
Tensor input1;
30
31
Tensor input1_gpu;
@@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
59
60
}
60
61
61
62
TEST (math_function, notrans_mul_trans_fp16) {
62
- using namespace paddle ::framework;
63
- using namespace paddle ::platform;
63
+ using namespace paddle ::framework; // NOLINT
64
+ using namespace paddle ::platform; // NOLINT
64
65
65
66
Tensor input1;
66
67
Tensor input1_gpu;
@@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) {
100
101
}
101
102
102
103
TEST (math_function, trans_mul_notrans_fp32) {
103
- using namespace paddle ::framework;
104
- using namespace paddle ::platform;
104
+ using namespace paddle ::framework; // NOLINT
105
+ using namespace paddle ::platform; // NOLINT
105
106
106
107
Tensor input1;
107
108
Tensor input1_gpu;
@@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
141
142
}
142
143
143
144
TEST (math_function, trans_mul_notrans_fp16) {
144
- using namespace paddle ::framework;
145
- using namespace paddle ::platform;
145
+ using namespace paddle ::framework; // NOLINT
146
+ using namespace paddle ::platform; // NOLINT
146
147
147
148
Tensor input1;
148
149
Tensor input1_gpu;
@@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) {
186
187
EXPECT_EQ (static_cast <float >(out_ptr[8 ]), 29 );
187
188
}
188
189
190
// Test-local shorthand for obtaining a BLAS wrapper on the CUDA device.
// Forwards `context` unchanged to paddle::operators::math::GetBlas with the
// device-context template argument fixed to CUDADeviceContext, so call sites
// only spell out the element type T, e.g. GetBlas<float>(context).GEMM(...).
// Returns the resulting BlasT<CUDADeviceContext, T> by value.
+ template <typename T>
191
+ inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
192
+ GetBlas (const paddle::platform::CUDADeviceContext& context) {
193
+ return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
194
+ T>(context);
195
+ }
196
+
189
197
TEST (math_function, gemm_notrans_cublas_fp32) {
190
- using namespace paddle ::framework;
191
- using namespace paddle ::platform;
198
+ using namespace paddle ::framework; // NOLINT
199
+ using namespace paddle ::platform; // NOLINT
192
200
193
201
Tensor input1;
194
202
Tensor input2;
@@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
221
229
float * b = input2_gpu.data <float >();
222
230
float * c = input3_gpu.mutable_data <float >(gpu_place);
223
231
224
- paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float >(
225
- context, false , false , m, n, k, 1 , a, 3 , b + 1 , 4 , 1 , c + 1 , 4 );
232
+ GetBlas< float >(context). GEMM ( false , false , m, n, k, 1 , a, 3 , b + 1 , 4 , 1 ,
233
+ c + 1 , 4 );
226
234
227
235
TensorCopySync (input3_gpu, cpu_place, &input3);
228
236
@@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
244
252
}
245
253
246
254
TEST (math_function, gemm_notrans_cublas_fp16) {
247
- using namespace paddle ::framework;
248
- using namespace paddle ::platform;
255
+ using namespace paddle ::framework; // NOLINT
256
+ using namespace paddle ::platform; // NOLINT
249
257
250
258
Tensor input1;
251
259
Tensor input2;
@@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
281
289
float16* b = input2_gpu.data <float16>();
282
290
float16* c = input3_gpu.mutable_data <float16>(gpu_place);
283
291
284
- paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
285
- context, false , false , m, n, k, float16 (1 ), a, 3 , b + 1 , 4 , float16 (1 ),
286
- c + 1 , 4 );
292
+ GetBlas<float16>(context).GEMM (false , false , m, n, k, float16 (1 ), a, 3 , b + 1 ,
293
+ 4 , float16 (1 ), c + 1 , 4 );
287
294
288
295
TensorCopySync (input3_gpu, cpu_place, &input3);
289
296
@@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
305
312
}
306
313
307
314
TEST (math_function, gemm_trans_cublas_fp32) {
308
- using namespace paddle ::framework;
309
- using namespace paddle ::platform;
315
+ using namespace paddle ::framework; // NOLINT
316
+ using namespace paddle ::platform; // NOLINT
310
317
311
318
Tensor input1;
312
319
Tensor input2;
@@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
339
346
float * b = input2_gpu.data <float >();
340
347
float * c = input3_gpu.mutable_data <float >(gpu_place);
341
348
342
- paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float >(
343
- context, false , true , m, n, k, 1 , a, 3 , b + 3 , 3 , 1 , c + 1 , 4 );
349
+ GetBlas< float >(context). GEMM ( false , true , m, n, k, 1 , a, 3 , b + 3 , 3 , 1 ,
350
+ c + 1 , 4 );
344
351
345
352
TensorCopySync (input3_gpu, cpu_place, &input3);
346
353
@@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
356
363
}
357
364
358
365
TEST (math_function, gemm_trans_cublas_fp16) {
359
- using namespace paddle ::framework;
360
- using namespace paddle ::platform;
366
+ using namespace paddle ::framework; // NOLINT
367
+ using namespace paddle ::platform; // NOLINT
361
368
362
369
Tensor input1;
363
370
Tensor input2;
@@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
393
400
float16* b = input2_gpu.data <float16>();
394
401
float16* c = input3_gpu.mutable_data <float16>(gpu_place);
395
402
396
- paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
397
- context, false , true , m, n, k, float16 (1 ), a, 3 , b + 3 , 3 , float16 (1 ),
398
- c + 1 , 4 );
403
+ GetBlas<float16>(context).GEMM (false , true , m, n, k, float16 (1 ), a, 3 , b + 3 ,
404
+ 3 , float16 (1 ), c + 1 , 4 );
399
405
400
406
TensorCopySync (input3_gpu, cpu_place, &input3);
401
407
@@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
412
418
413
419
template <typename T>
414
420
void GemvTest (int m, int n, bool trans) {
415
- using namespace paddle ::framework;
416
- using namespace paddle ::platform;
421
+ using namespace paddle ::framework; // NOLINT
422
+ using namespace paddle ::platform; // NOLINT
417
423
418
424
Tensor mat_a;
419
425
Tensor vec_b;
0 commit comments