Skip to content

Commit bc81603

Browse files
committed
Fix compile
1 parent a6edeb3 commit bc81603

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

paddle/fluid/operators/math/math_function_test.cu

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,9 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
279279
paddle::platform::float16* c =
280280
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
281281

282-
GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
283-
4, float16(1), c + 1, 4);
282+
GetBlas<paddle::platform::float16>(context).GEMM(
283+
false, false, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
284+
b + 1, 4, static_cast<paddle::platform::float16>(1), c + 1, 4);
284285

285286
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
286287

@@ -388,12 +389,9 @@ TEST(math_function, gemm_trans_cublas_fp16) {
388389
paddle::platform::float16* c =
389390
input3_gpu.mutable_data<paddle::platform::float16>(gpu_place);
390391

391-
GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
392-
3, float16(1), c + 1, 4);
393-
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext,
394-
paddle::platform::float16>(
395-
context, false, true, m, n, k, paddle::platform::float16(1), a, 3, b + 3,
396-
3, paddle::platform::float16(1), c + 1, 4);
392+
GetBlas<paddle::platform::float16>(context).GEMM(
393+
false, true, m, n, k, static_cast<paddle::platform::float16>(1), a, 3,
394+
b + 3, 3, static_cast<paddle::platform::float16>(1), c + 1, 4);
397395

398396
paddle::framework::TensorCopySync(input3_gpu, cpu_place, &input3);
399397

0 commit comments

Comments
 (0)