Skip to content

Commit 49dedfa

Browse files
committed
Polish code and tests
1 parent c888e01 commit 49dedfa

File tree

3 files changed

+59
-37
lines changed

3 files changed

+59
-37
lines changed

paddle/fluid/operators/math/blas_impl.cu.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,20 @@ struct CUBlas<double> {
4242

4343
template <>
4444
struct CUBlas<platform::float16> {
45-
template <typename... ARGS>
46-
static void GEMM(ARGS... args) {
47-
PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...));
45+
using float16 = platform::float16;
46+
47+
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
48+
cublasOperation_t transb, int m, int n, int k,
49+
const float16 *alpha, const float16 *A, int lda,
50+
const float16 *B, int ldb, const float16 *beta, float16 *C,
51+
int ldc) {
52+
PADDLE_ENFORCE(
53+
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
54+
reinterpret_cast<const __half *>(alpha),
55+
reinterpret_cast<const __half *>(A), lda,
56+
reinterpret_cast<const __half *>(B), ldb,
57+
reinterpret_cast<const __half *>(beta),
58+
reinterpret_cast<__half *>(C), ldc));
4859
}
4960
};
5061

paddle/fluid/operators/math/math_function_test.cc

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
#include "paddle/fluid/operators/math/math_function.h"
1515
#include "gtest/gtest.h"
1616

17+
template <typename T>
18+
inline paddle::operators::math::BlasT<paddle::platform::CPUDeviceContext, T>
19+
GetBlas(const paddle::platform::CPUDeviceContext& context) {
20+
return paddle::operators::math::GetBlas<paddle::platform::CPUDeviceContext,
21+
T>(context);
22+
}
23+
1724
TEST(math_function, gemm_notrans_cblas) {
1825
paddle::framework::Tensor input1;
1926
paddle::framework::Tensor input2;
@@ -34,9 +41,8 @@ TEST(math_function, gemm_notrans_cblas) {
3441
memcpy(input3_ptr, arr3, 8 * sizeof(float));
3542

3643
paddle::platform::CPUDeviceContext context(*cpu_place);
37-
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
38-
context, false, false, m, n, k, 1, input1_ptr, 3, input2_ptr + 1, 4, 1,
39-
input3_ptr + 1, 4);
44+
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, input1_ptr, 3,
45+
input2_ptr + 1, 4, 1, input3_ptr + 1, 4);
4046

4147
EXPECT_EQ(input3_ptr[0], 0);
4248
EXPECT_EQ(input3_ptr[1], 24);
@@ -68,9 +74,8 @@ TEST(math_function, gemm_trans_clbas) {
6874
memcpy(input3_ptr, arr3, 8 * sizeof(float));
6975

7076
paddle::platform::CPUDeviceContext context(*cpu_place);
71-
paddle::operators::math::gemm<paddle::platform::CPUDeviceContext, float>(
72-
context, false, true, m, n, k, 1, input1_ptr, 3, input2_ptr + 3, 3, 1,
73-
input3_ptr + 1, 4);
77+
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, input1_ptr, 3,
78+
input2_ptr + 3, 3, 1, input3_ptr + 1, 4);
7479

7580
EXPECT_EQ(input3_ptr[0], 0);
7681
EXPECT_EQ(input3_ptr[1], 24);

paddle/fluid/operators/math/math_function_test.cu

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414
#include "gtest/gtest.h"
1515
#include "paddle/fluid/operators/math/math_function.h"
16+
#include "paddle/fluid/platform/device_context.h"
1617

1718
void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
1819
const std::vector<float>& data) {
@@ -23,8 +24,8 @@ void fill_fp16_data(paddle::platform::float16* in_ptr, size_t size,
2324
}
2425

2526
TEST(math_function, notrans_mul_trans_fp32) {
26-
using namespace paddle::framework;
27-
using namespace paddle::platform;
27+
using namespace paddle::framework; // NOLINT
28+
using namespace paddle::platform; // NOLINT
2829

2930
Tensor input1;
3031
Tensor input1_gpu;
@@ -59,8 +60,8 @@ TEST(math_function, notrans_mul_trans_fp32) {
5960
}
6061

6162
TEST(math_function, notrans_mul_trans_fp16) {
62-
using namespace paddle::framework;
63-
using namespace paddle::platform;
63+
using namespace paddle::framework; // NOLINT
64+
using namespace paddle::platform; // NOLINT
6465

6566
Tensor input1;
6667
Tensor input1_gpu;
@@ -100,8 +101,8 @@ TEST(math_function, notrans_mul_trans_fp16) {
100101
}
101102

102103
TEST(math_function, trans_mul_notrans_fp32) {
103-
using namespace paddle::framework;
104-
using namespace paddle::platform;
104+
using namespace paddle::framework; // NOLINT
105+
using namespace paddle::platform; // NOLINT
105106

106107
Tensor input1;
107108
Tensor input1_gpu;
@@ -141,8 +142,8 @@ TEST(math_function, trans_mul_notrans_fp32) {
141142
}
142143

143144
TEST(math_function, trans_mul_notrans_fp16) {
144-
using namespace paddle::framework;
145-
using namespace paddle::platform;
145+
using namespace paddle::framework; // NOLINT
146+
using namespace paddle::platform; // NOLINT
146147

147148
Tensor input1;
148149
Tensor input1_gpu;
@@ -186,9 +187,16 @@ TEST(math_function, trans_mul_notrans_fp16) {
186187
EXPECT_EQ(static_cast<float>(out_ptr[8]), 29);
187188
}
188189

190+
template <typename T>
191+
inline paddle::operators::math::BlasT<paddle::platform::CUDADeviceContext, T>
192+
GetBlas(const paddle::platform::CUDADeviceContext& context) {
193+
return paddle::operators::math::GetBlas<paddle::platform::CUDADeviceContext,
194+
T>(context);
195+
}
196+
189197
TEST(math_function, gemm_notrans_cublas_fp32) {
190-
using namespace paddle::framework;
191-
using namespace paddle::platform;
198+
using namespace paddle::framework; // NOLINT
199+
using namespace paddle::platform; // NOLINT
192200

193201
Tensor input1;
194202
Tensor input2;
@@ -221,8 +229,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
221229
float* b = input2_gpu.data<float>();
222230
float* c = input3_gpu.mutable_data<float>(gpu_place);
223231

224-
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
225-
context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4);
232+
GetBlas<float>(context).GEMM(false, false, m, n, k, 1, a, 3, b + 1, 4, 1,
233+
c + 1, 4);
226234

227235
TensorCopySync(input3_gpu, cpu_place, &input3);
228236

@@ -244,8 +252,8 @@ TEST(math_function, gemm_notrans_cublas_fp32) {
244252
}
245253

246254
TEST(math_function, gemm_notrans_cublas_fp16) {
247-
using namespace paddle::framework;
248-
using namespace paddle::platform;
255+
using namespace paddle::framework; // NOLINT
256+
using namespace paddle::platform; // NOLINT
249257

250258
Tensor input1;
251259
Tensor input2;
@@ -281,9 +289,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
281289
float16* b = input2_gpu.data<float16>();
282290
float16* c = input3_gpu.mutable_data<float16>(gpu_place);
283291

284-
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
285-
context, false, false, m, n, k, float16(1), a, 3, b + 1, 4, float16(1),
286-
c + 1, 4);
292+
GetBlas<float16>(context).GEMM(false, false, m, n, k, float16(1), a, 3, b + 1,
293+
4, float16(1), c + 1, 4);
287294

288295
TensorCopySync(input3_gpu, cpu_place, &input3);
289296

@@ -305,8 +312,8 @@ TEST(math_function, gemm_notrans_cublas_fp16) {
305312
}
306313

307314
TEST(math_function, gemm_trans_cublas_fp32) {
308-
using namespace paddle::framework;
309-
using namespace paddle::platform;
315+
using namespace paddle::framework; // NOLINT
316+
using namespace paddle::platform; // NOLINT
310317

311318
Tensor input1;
312319
Tensor input2;
@@ -339,8 +346,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
339346
float* b = input2_gpu.data<float>();
340347
float* c = input3_gpu.mutable_data<float>(gpu_place);
341348

342-
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float>(
343-
context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4);
349+
GetBlas<float>(context).GEMM(false, true, m, n, k, 1, a, 3, b + 3, 3, 1,
350+
c + 1, 4);
344351

345352
TensorCopySync(input3_gpu, cpu_place, &input3);
346353

@@ -356,8 +363,8 @@ TEST(math_function, gemm_trans_cublas_fp32) {
356363
}
357364

358365
TEST(math_function, gemm_trans_cublas_fp16) {
359-
using namespace paddle::framework;
360-
using namespace paddle::platform;
366+
using namespace paddle::framework; // NOLINT
367+
using namespace paddle::platform; // NOLINT
361368

362369
Tensor input1;
363370
Tensor input2;
@@ -393,9 +400,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
393400
float16* b = input2_gpu.data<float16>();
394401
float16* c = input3_gpu.mutable_data<float16>(gpu_place);
395402

396-
paddle::operators::math::gemm<paddle::platform::CUDADeviceContext, float16>(
397-
context, false, true, m, n, k, float16(1), a, 3, b + 3, 3, float16(1),
398-
c + 1, 4);
403+
GetBlas<float16>(context).GEMM(false, true, m, n, k, float16(1), a, 3, b + 3,
404+
3, float16(1), c + 1, 4);
399405

400406
TensorCopySync(input3_gpu, cpu_place, &input3);
401407

@@ -412,8 +418,8 @@ TEST(math_function, gemm_trans_cublas_fp16) {
412418

413419
template <typename T>
414420
void GemvTest(int m, int n, bool trans) {
415-
using namespace paddle::framework;
416-
using namespace paddle::platform;
421+
using namespace paddle::framework; // NOLINT
422+
using namespace paddle::platform; // NOLINT
417423

418424
Tensor mat_a;
419425
Tensor vec_b;

0 commit comments

Comments (0)