@@ -279,8 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
   paddle::platform::float16* c =
       input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);

-  GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
-                                 4, float16(1), c + 1, 4);
+  GetBlas<paddle::platform::float16>(context).GEMM(
+      false, false, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
+      b + 1, 4, static_cast<paddle::platform::float16>(1), c + 1, 4);

   paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
@@ -388,12 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
   paddle::platform::float16* c =
       input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);

-  GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
-                                 3, float16(1), c + 1, 4);
-  paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
-                                paddle::platform::float16>(
-      context, false, true, m, n, k, paddle::platform::float16(1), a, 3, b + 3,
-      3, paddle::platform::float16(1), c + 1, 4);
+  GetBlas<paddle::platform::float16>(context).GEMM(
+      false, true, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
+      b + 3, 3, static_cast<paddle::platform::float16>(1), c + 1, 4);

   paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
0 commit comments