Skip to content

Commit 2abcf37

Browse files
authored
Merge pull request #10327 from reyoung/feature/clean_blas
Feature/clean blas
2 parents 54797ab + bc81603 commit 2abcf37

12 files changed

+398
-353
lines changed

paddle/fluid/operators/bilinear_tensor_product_op.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
6161
auto output_col_vec = output_mat.chip(i, 1);
6262
Tensor weight_mat =
6363
weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
64-
math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
65-
batch_size, y_dim, x_dim, 1, x->data<T>(),
66-
weight_mat.data<T>(), 0, left_mul.data<T>());
64+
math::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
65+
CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
66+
weight_mat.data<T>(), 0, left_mul.data<T>());
6767
output_col_vec.device(place) =
6868
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
6969
}
@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
125125
set_zero(dev_ctx, d_y, static_cast<T>(0));
126126
}
127127

128+
auto blas = math::GetBlas<DeviceContext, T>(ctx);
129+
128130
// Calculate the Output(X@Grad) and Output(Y@Grad).
129131
if (d_x || d_y) {
130132
Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
@@ -138,18 +140,16 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
138140
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
139141
.broadcast(bcast_for_x) *
140142
y_mat;
141-
math::gemm<DeviceContext, T>(
142-
dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
143-
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
143+
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
144+
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
144145
}
145146
if (d_y) {
146147
x_scale_mat.device(place) =
147148
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
148149
.broadcast(bcast_for_y) *
149150
x_mat;
150-
math::gemm<DeviceContext, T>(
151-
dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
152-
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
151+
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
152+
x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
153153
}
154154
}
155155
}
@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
166166
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
167167
.broadcast(bcast_for_weight) *
168168
x_mat;
169-
math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
170-
y_dim, batch_size, 1, x_scale.data<T>(),
171-
y->data<T>(), 0, d_weight_i.data<T>());
169+
blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
170+
x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
172171
}
173172
}
174173

paddle/fluid/operators/gru_unit_op.h

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
8787
const T* weight_data = weight->data<T>();
8888
T* gate_data = gate->data<T>();
8989
T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
90-
math::gemm<DeviceContext, T>(
91-
context.template device_context<DeviceContext>(), false, false,
92-
batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size,
93-
weight_data, frame_size * 2, 1, gate_data, frame_size * 3);
90+
auto blas = math::GetBlas<DeviceContext, T>(context);
91+
blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1,
92+
hidden_prev_data, frame_size, weight_data, frame_size * 2, 1,
93+
gate_data, frame_size * 3);
9494

9595
// calculate activated gate
9696
Eigen::array<int, 2> extents({{batch_size, frame_size}});
@@ -103,11 +103,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
103103
g.slice(r_offsets, extents), g.slice(r_offsets, extents));
104104
auto r = g.slice(r_offsets, extents); // reset gate
105105
r_h_p.device(place) = r * h_p; // reset previous hidden state
106-
math::gemm<DeviceContext, T>(
107-
context.template device_context<DeviceContext>(), false, false,
108-
batch_size, frame_size, frame_size, 1, reset_hidden_prev_data,
109-
frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1,
110-
gate_data + frame_size * 2, frame_size * 3);
106+
blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
107+
reset_hidden_prev_data, frame_size,
108+
weight_data + frame_size * frame_size * 2, frame_size, 1,
109+
gate_data + frame_size * 2, frame_size * 3);
111110

112111
Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
113112
ActCompute(context.Attr<int>("activation"), place,
@@ -188,42 +187,37 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
188187
ActGradCompute(context.Attr<int>("activation"), place, c, c,
189188
d_g.slice(c_offsets, extents), d_h * u);
190189
// backward for reset_hidden_prev
191-
math::gemm<DeviceContext, T>(
192-
context.template device_context<DeviceContext>(), false, true,
193-
batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2,
194-
frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size,
195-
0, reset_hidden_prev_grad_data, frame_size);
190+
auto blas = math::GetBlas<DeviceContext, T>(context);
191+
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
192+
gate_grad_data + frame_size * 2, frame_size * 3,
193+
weight_data + frame_size * frame_size * 2, frame_size, 0,
194+
reset_hidden_prev_grad_data, frame_size);
196195
// backward for unactivated reset gate
197196
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
198197
d_g.slice(r_offsets, extents), d_r_h_p * h_p);
199198
// backward for weight
200199
if (weight_grad) {
201200
T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
202201
// backward for state_weight
203-
math::gemm<DeviceContext, T>(
204-
context.template device_context<DeviceContext>(), true, false,
205-
frame_size, frame_size, batch_size, 1, reset_hidden_prev_data,
206-
frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0,
207-
weight_grad_data + frame_size * frame_size * 2, frame_size);
202+
blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
203+
reset_hidden_prev_data, frame_size,
204+
gate_grad_data + frame_size * 2, frame_size * 3, 0,
205+
weight_grad_data + frame_size * frame_size * 2, frame_size);
208206

209207
// backward for update_gate_weight and reset_gate_weight
210-
math::gemm<DeviceContext, T>(
211-
context.template device_context<DeviceContext>(), true, false,
212-
frame_size, frame_size * 2, batch_size, 1, hidden_prev_data,
213-
frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data,
214-
frame_size * 2);
208+
blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
209+
hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0,
210+
weight_grad_data, frame_size * 2);
215211
}
216212
// backward for hidden_prev
217213
if (hidden_prev_grad) {
218214
T* hidden_prev_grad_data =
219215
hidden_prev_grad->mutable_data<T>(context.GetPlace());
220216
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
221217
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
222-
math::gemm<DeviceContext, T>(
223-
context.template device_context<DeviceContext>(), false, true,
224-
batch_size, frame_size, frame_size * 2, 1, gate_grad_data,
225-
frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data,
226-
frame_size);
218+
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
219+
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
220+
hidden_prev_grad_data, frame_size);
227221
}
228222
// backward for input
229223
if (input_grad) {
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/operators/math/math_function.h"
18+
#include "paddle/fluid/platform/dynload/cublas.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
namespace math {
23+
24+
template <typename T>
25+
struct CUBlas;
26+
27+
template <>
28+
struct CUBlas<float> {
29+
template <typename... ARGS>
30+
static void GEMM(ARGS... args) {
31+
PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
32+
}
33+
};
34+
35+
template <>
36+
struct CUBlas<double> {
37+
template <typename... ARGS>
38+
static void GEMM(ARGS... args) {
39+
PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
40+
}
41+
};
42+
43+
template <>
44+
struct CUBlas<platform::float16> {
45+
using float16 = platform::float16;
46+
47+
static void GEMM(cublasHandle_t handle, cublasOperation_t transa,
48+
cublasOperation_t transb, int m, int n, int k,
49+
const float16 *alpha, const float16 *A, int lda,
50+
const float16 *B, int ldb, const float16 *beta, float16 *C,
51+
int ldc) {
52+
PADDLE_ENFORCE(
53+
platform::dynload::cublasHgemm(handle, transa, transb, m, n, k,
54+
reinterpret_cast<const __half *>(alpha),
55+
reinterpret_cast<const __half *>(A), lda,
56+
reinterpret_cast<const __half *>(B), ldb,
57+
reinterpret_cast<const __half *>(beta),
58+
reinterpret_cast<__half *>(C), ldc));
59+
}
60+
};
61+
62+
template <>
63+
template <typename T>
64+
void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
65+
const CBLAS_TRANSPOSE transB,
66+
const int M, const int N,
67+
const int K, const T alpha,
68+
const T *A, const T *B,
69+
const T beta, T *C) const {
70+
// Note that cublas follows fortran order, so the order is different from
71+
// the cblas convention.
72+
int lda = (transA == CblasNoTrans) ? K : M;
73+
int ldb = (transB == CblasNoTrans) ? N : K;
74+
cublasOperation_t cuTransA =
75+
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
76+
cublasOperation_t cuTransB =
77+
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
78+
79+
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
80+
B, ldb, A, lda, &beta, C, N);
81+
}
82+
83+
template <>
84+
template <>
85+
inline void Blas<platform::CUDADeviceContext>::GEMM(
86+
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
87+
const int N, const int K, const platform::float16 alpha,
88+
const platform::float16 *A, const platform::float16 *B,
89+
const platform::float16 beta, platform::float16 *C) const {
90+
// Note that cublas follows fortran order, so the order is different from
91+
// the cblas convention.
92+
int lda = (transA == CblasNoTrans) ? K : M;
93+
int ldb = (transB == CblasNoTrans) ? N : K;
94+
cublasOperation_t cuTransA =
95+
(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
96+
cublasOperation_t cuTransB =
97+
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
98+
99+
// TODO(kexinzhao): add processing code for compute capability < 53 case
100+
PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
101+
"cublas fp16 gemm requires GPU compute capability >= 53");
102+
103+
#if CUDA_VERSION >= 8000
104+
float h_alpha = static_cast<float>(alpha);
105+
float h_beta = static_cast<float>(beta);
106+
107+
cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
108+
#if CUDA_VERSION >= 9000
109+
if (context_.GetComputeCapability() >= 70) {
110+
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
111+
context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
112+
algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
113+
} else {
114+
PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
115+
context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
116+
}
117+
#endif // CUDA_VERSION >= 9000
118+
119+
// cublasHgemm does true FP16 computation which is slow for non-Volta
120+
// GPUs. So use cublasGemmEx instead which does pesudo FP16 computation:
121+
// input/output in fp16, computation in fp32, which can also be accelerated
122+
// using tensor cores in volta GPUs.
123+
PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
124+
context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
125+
CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
126+
CUDA_R_32F, algo));
127+
#else
128+
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
129+
CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
130+
N, M, K, &h_alpha, h_B, ldb, h_A, lda,
131+
&h_beta, h_C, N);
132+
#endif // CUDA_VERSION >= 8000
133+
}
134+
135+
template <>
136+
template <typename T>
137+
void Blas<platform::CUDADeviceContext>::GEMM(
138+
const bool transA, const bool transB, const int M, const int N, const int K,
139+
const T alpha, const T *A, const int lda, const T *B, const int ldb,
140+
const T beta, T *C, const int ldc) const {
141+
// Note that cublas follows fortran order, so the order is different from
142+
// the cblas convention.
143+
cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
144+
cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
145+
CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
146+
B, ldb, A, lda, &beta, C, ldc);
147+
}
148+
149+
} // namespace math
150+
} // namespace operators
151+
} // namespace paddle
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#pragma once
15+
16+
#include "paddle/fluid/operators/math/math_function.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
namespace math {
21+
22+
template <typename T>
23+
struct CBlas;
24+
25+
template <>
26+
struct CBlas<float> {
27+
template <typename... ARGS>
28+
static void GEMM(ARGS... args) {
29+
cblas_sgemm(args...);
30+
}
31+
};
32+
33+
template <>
34+
struct CBlas<double> {
35+
template <typename... ARGS>
36+
static void GEMM(ARGS... args) {
37+
cblas_dgemm(args...);
38+
}
39+
};
40+
41+
template <>
42+
struct CBlas<platform::float16> {
43+
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
44+
};
45+
46+
template <>
47+
template <typename T>
48+
void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
49+
const CBLAS_TRANSPOSE transB,
50+
const int M, const int N,
51+
const int K, const T alpha,
52+
const T *A, const T *B,
53+
const T beta, T *C) const {
54+
int lda = (transA == CblasNoTrans) ? K : M;
55+
int ldb = (transB == CblasNoTrans) ? N : K;
56+
int ldc = N;
57+
CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
58+
beta, C, ldc);
59+
}
60+
61+
template <>
62+
template <typename T>
63+
void Blas<platform::CPUDeviceContext>::GEMM(
64+
const bool transA, const bool transB, const int M, const int N, const int K,
65+
const T alpha, const T *A, const int lda, const T *B, const int ldb,
66+
const T beta, T *C, const int ldc) const {
67+
CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
68+
transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
69+
lda, B, ldb, beta, C, ldc);
70+
}
71+
72+
} // namespace math
73+
} // namespace operators
74+
} // namespace paddle

0 commit comments

Comments
 (0)