Skip to content

Commit caa4027

Browse files
committed
Follow comments
1 parent 4db43c6 commit caa4027

File tree

1 file changed

+3
-8
lines changed

1 file changed

+3
-8
lines changed

paddle/fluid/operators/math/blas_impl.cu.h

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -126,14 +126,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
       CUDA_R_32F, algo));
 #else
   // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
-  const half h_alpha = static_cast<const half>(alpha);
-  const half h_beta = static_cast<const half>(beta);
-  const half *h_A = reinterpret_cast<const half *>(A);
-  const half *h_B = reinterpret_cast<const half *>(B);
-  half *h_C = reinterpret_cast<half *>(C);
-
-  CUBlas<platform::float16>(context_.cublas_handle(), cuTransB, cuTransA, N, M,
-                            K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N);
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
+                                  &h_beta, h_C, N);
 #endif  // CUDA_VERSION >= 8000
 }
 

0 commit comments

Comments
 (0)