Commit c888e01

Refactor GEMM in blas
1 parent c93a624 commit c888e01

10 files changed, +357 -335 lines changed
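
The refactor replaces each free-function call math::gemm&lt;DeviceContext, T&gt;(dev_ctx, ...) with a Blas object obtained from math::GetBlas&lt;DeviceContext, T&gt;(...). The commit view below shows only the call sites and the two new implementation headers, not the header that declares Blas and GetBlas, so the following is a minimal interface sketch reconstructed from those pieces; treat the exact class layout and the GetBlas return type as assumptions.

// Sketch only: interface implied by the call sites and the new impl headers in
// this commit. The real declaration lives in a header not shown here, and
// GetBlas<DeviceContext, T> presumably returns a thin wrapper that fixes T so
// callers can write blas.GEMM(...) without repeating the element type.
namespace paddle {
namespace operators {
namespace math {

template <typename DeviceContext>
class Blas {
 public:
  explicit Blas(const DeviceContext& context) : context_(context) {}

  // cblas-style row-major GEMM: C = alpha * op(A) * op(B) + beta * C,
  // with op() selected by CBLAS_TRANSPOSE flags and tightly packed operands.
  template <typename T>
  void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
            int K, T alpha, const T* A, const T* B, T beta, T* C) const;

  // Overload with bool transpose flags and explicit leading dimensions, used
  // when the operands are sub-blocks of larger matrices (see gru_unit_op.h).
  template <typename T>
  void GEMM(bool transA, bool transB, int M, int N, int K, T alpha,
            const T* A, int lda, const T* B, int ldb, T beta, T* C,
            int ldc) const;

 private:
  const DeviceContext& context_;  // provides cublas_handle() on CUDA
};

}  // namespace math
}  // namespace operators
}  // namespace paddle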

paddle/fluid/operators/bilinear_tensor_product_op.h

Lines changed: 11 additions & 12 deletions
@@ -61,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
       auto output_col_vec = output_mat.chip(i, 1);
       Tensor weight_mat =
           weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim}));
-      math::gemm<DeviceContext, T>(dev_ctx, CblasNoTrans, CblasNoTrans,
-                                   batch_size, y_dim, x_dim, 1, x->data<T>(),
-                                   weight_mat.data<T>(), 0, left_mul.data<T>());
+      math::GetBlas<DeviceContext, T>(dev_ctx).GEMM(
+          CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data<T>(),
+          weight_mat.data<T>(), 0, left_mul.data<T>());
       output_col_vec.device(place) =
           (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
     }
@@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
       set_zero(dev_ctx, d_y, static_cast<T>(0));
     }
 
+    auto blas = math::GetBlas<DeviceContext, T>(ctx);
+
     // Calculate the Output(X@Grad) and Output(Y@Grad).
     if (d_x || d_y) {
       Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
@@ -138,18 +140,16 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
             output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                 .broadcast(bcast_for_x) *
             y_mat;
-        math::gemm<DeviceContext, T>(
-            dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
-            y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
+        blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
+                  y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
       }
       if (d_y) {
         x_scale_mat.device(place) =
             output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
                 .broadcast(bcast_for_y) *
             x_mat;
-        math::gemm<DeviceContext, T>(
-            dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
-            x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+        blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+                  x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
       }
     }
   }
@@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
           output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
               .broadcast(bcast_for_weight) *
           x_mat;
-      math::gemm<DeviceContext, T>(dev_ctx, CblasTrans, CblasNoTrans, x_dim,
-                                   y_dim, batch_size, 1, x_scale.data<T>(),
-                                   y->data<T>(), 0, d_weight_i.data<T>());
+      blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
+                x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
     }
   }
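
For reference, the first rewritten call above, GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x, weight_mat, 0, left_mul), computes the row-major product left_mul (batch_size x y_dim) = x (batch_size x x_dim) * weight_mat (x_dim x y_dim). A minimal reference loop with the same (M, N, K) = (batch_size, y_dim, x_dim) convention, written only to pin down the semantics (naive_gemm_nn is a hypothetical helper, not part of the codebase):

// Hypothetical reference implementation of the NoTrans/NoTrans case,
// row-major, C = alpha * A * B + beta * C with A: MxK, B: KxN, C: MxN.
void naive_gemm_nn(int M, int N, int K, float alpha, const float* A,
                   const float* B, float beta, float* C) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[k * N + n];
      }
      // beta == 0 means C is write-only, as in the left_mul call above.
      C[m * N + n] = alpha * acc + (beta == 0.0f ? 0.0f : beta * C[m * N + n]);
    }
  }
}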

paddle/fluid/operators/gru_unit_op.h

Lines changed: 23 additions & 29 deletions
@@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
     const T* weight_data = weight->data<T>();
     T* gate_data = gate->data<T>();
     T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, false,
-        batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size,
-        weight_data, frame_size * 2, 1, gate_data, frame_size * 3);
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1,
+              hidden_prev_data, frame_size, weight_data, frame_size * 2, 1,
+              gate_data, frame_size * 3);
 
     // calculate activated gate
     Eigen::array<int, 2> extents({{batch_size, frame_size}});
@@ -103,11 +103,10 @@ class GRUUnitKernel : public framework::OpKernel<T> {
                g.slice(r_offsets, extents), g.slice(r_offsets, extents));
     auto r = g.slice(r_offsets, extents);  // reset gate
     r_h_p.device(place) = r * h_p;  // reset previous hidden state
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, false,
-        batch_size, frame_size, frame_size, 1, reset_hidden_prev_data,
-        frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1,
-        gate_data + frame_size * 2, frame_size * 3);
+    blas.GEMM(false, false, batch_size, frame_size, frame_size, 1,
+              reset_hidden_prev_data, frame_size,
+              weight_data + frame_size * frame_size * 2, frame_size, 1,
+              gate_data + frame_size * 2, frame_size * 3);
 
     Eigen::array<int, 2> c_offsets({{0, frame_size * 2}});
     ActCompute(context.Attr<int>("activation"), place,
@@ -188,42 +187,37 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
     ActGradCompute(context.Attr<int>("activation"), place, c, c,
                    d_g.slice(c_offsets, extents), d_h * u);
     // backward for reset_hidden_prev
-    math::gemm<DeviceContext, T>(
-        context.template device_context<DeviceContext>(), false, true,
-        batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2,
-        frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size,
-        0, reset_hidden_prev_grad_data, frame_size);
+    auto blas = math::GetBlas<DeviceContext, T>(context);
+    blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
+              gate_grad_data + frame_size * 2, frame_size * 3,
+              weight_data + frame_size * frame_size * 2, frame_size, 0,
+              reset_hidden_prev_grad_data, frame_size);
     // backward for unactivated reset gate
     ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
                    d_g.slice(r_offsets, extents), d_r_h_p * h_p);
     // backward for weight
     if (weight_grad) {
       T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
       // backward for state_weight
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), true, false,
-          frame_size, frame_size, batch_size, 1, reset_hidden_prev_data,
-          frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0,
-          weight_grad_data + frame_size * frame_size * 2, frame_size);
+      blas.GEMM(true, false, frame_size, frame_size, batch_size, 1,
+                reset_hidden_prev_data, frame_size,
+                gate_grad_data + frame_size * 2, frame_size * 3, 0,
+                weight_grad_data + frame_size * frame_size * 2, frame_size);
 
       // backward for update_gate_weight and reset_gate_weight
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), true, false,
-          frame_size, frame_size * 2, batch_size, 1, hidden_prev_data,
-          frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data,
-          frame_size * 2);
+      blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1,
+                hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0,
+                weight_grad_data, frame_size * 2);
     }
     // backward for hidden_prev
     if (hidden_prev_grad) {
       T* hidden_prev_grad_data =
           hidden_prev_grad->mutable_data<T>(context.GetPlace());
       auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
       d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
-      math::gemm<DeviceContext, T>(
-          context.template device_context<DeviceContext>(), false, true,
-          batch_size, frame_size, frame_size * 2, 1, gate_grad_data,
-          frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data,
-          frame_size);
+      blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
+                gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
+                hidden_prev_grad_data, frame_size);
     }
     // backward for input
     if (input_grad) {
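
The GRU kernels above use the second GEMM overload (bool transpose flags plus explicit lda/ldb/ldc) because gate is a batch_size x (3 * frame_size) matrix holding the update, reset, and candidate blocks side by side, and each product reads or writes only a frame_size- or 2 * frame_size-wide slice of it. The leading dimension therefore stays 3 * frame_size even though N is only the slice width. A small sketch of the addressing this implies (gate_slice_at is a hypothetical helper for illustration, and the block ordering is inferred from the offsets used above):

// Hypothetical helper: element (i, j) of the slice of the gate matrix that
// starts at column offset `col_off`, given the row-major layout
// gate_data[batch_size][3 * frame_size] assumed by the calls above.
inline float* gate_slice_at(float* gate_data, int frame_size, int col_off,
                            int i, int j) {
  const int ldc = 3 * frame_size;  // row stride of the full gate matrix
  return gate_data + i * ldc + col_off + j;
}
// e.g. the candidate block written by the second forward GEMM starts at
// col_off = 2 * frame_size (pointer gate_data + frame_size * 2, ldc = 3 * frame_size).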
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/dynload/cublas.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+struct CUBlas;
+
+template <>
+struct CUBlas<float> {
+  template <typename... ARGS>
+  static void GEMM(ARGS... args) {
+    PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...));
+  }
+};
+
+template <>
+struct CUBlas<double> {
+  template <typename... ARGS>
+  static void GEMM(ARGS... args) {
+    PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...));
+  }
+};
+
+template <>
+struct CUBlas<platform::float16> {
+  template <typename... ARGS>
+  static void GEMM(ARGS... args) {
+    PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...));
+  }
+};
+
+template <>
+template <typename T>
+void Blas<platform::CUDADeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
+                                             const CBLAS_TRANSPOSE transB,
+                                             const int M, const int N,
+                                             const int K, const T alpha,
+                                             const T *A, const T *B,
+                                             const T beta, T *C) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
+                  B, ldb, A, lda, &beta, C, N);
+}
+
+template <>
+template <>
+inline void Blas<platform::CUDADeviceContext>::GEMM(
+    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
+    const int N, const int K, const platform::float16 alpha,
+    const platform::float16 *A, const platform::float16 *B,
+    const platform::float16 beta, platform::float16 *C) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA =
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+
+  // TODO(kexinzhao): add processing code for compute capability < 53 case
+  PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53,
+                    "cublas fp16 gemm requires GPU compute capability >= 53");
+
+#if CUDA_VERSION >= 8000
+  float h_alpha = static_cast<float>(alpha);
+  float h_beta = static_cast<float>(beta);
+
+  cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT;
+#if CUDA_VERSION >= 9000
+  if (context_.GetComputeCapability() >= 70) {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
+        context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH));
+    algo = CUBLAS_GEMM_DFALT_TENSOR_OP;
+  } else {
+    PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(
+        context_.cublas_handle(), CUBLAS_DEFAULT_MATH));
+  }
+#endif  // CUDA_VERSION >= 9000
+
+  // cublasHgemm does true FP16 computation which is slow for non-Volta
+  // GPUs. So use cublasGemmEx instead which does pseudo FP16 computation:
+  // input/output in fp16, computation in fp32, which can also be accelerated
+  // using tensor cores in volta GPUs.
+  PADDLE_ENFORCE(platform::dynload::cublasGemmEx(
+      context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B,
+      CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N,
+      CUDA_R_32F, algo));
+#else
+  // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
+  const half h_alpha = static_cast<const half>(alpha);
+  const half h_beta = static_cast<const half>(beta);
+  const half *h_A = reinterpret_cast<const half *>(A);
+  const half *h_B = reinterpret_cast<const half *>(B);
+  half *h_C = reinterpret_cast<half *>(C);
+
+  CUBlas<platform::float16>::GEMM(context_.cublas_handle(), cuTransB, cuTransA,
+                                  N, M, K, &h_alpha, h_B, ldb, h_A, lda,
+                                  &h_beta, h_C, N);
+#endif  // CUDA_VERSION >= 8000
+}
+
+template <>
+template <typename T>
+void Blas<platform::CUDADeviceContext>::GEMM(
+    const bool transA, const bool transB, const int M, const int N,
+    const int K, const T alpha, const T *A, const int lda, const T *B,
+    const int ldb, const T beta, T *C, const int ldc) const {
+  // Note that cublas follows fortran order, so the order is different from
+  // the cblas convention.
+  cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T;
+  CUBlas<T>::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha,
+                  B, ldb, A, lda, &beta, C, ldc);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
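
The cuTransB/cuTransA argument swap above is the standard trick for driving a column-major BLAS from row-major code: a row-major M x N matrix occupies exactly the same memory as its N x M transpose in column-major order, so computing C^T = op(B)^T * op(A)^T in cuBLAS yields C = op(A) * op(B) in the row-major view, with no transpose kernels or extra copies. Spelled out:

    row-major view:                 C (M x N)   = op(A) (M x K)   * op(B) (K x N)
    same memory, column-major view: C^T (N x M) = op(B)^T (N x K) * op(A)^T (K x M)

That is why the call passes (cuTransB, cuTransA, N, M, K, ..., B, ldb, A, lda, ..., C, N): cuBLAS computes the right-hand side in column-major, and the buffer it writes is exactly the row-major C. The same reasoning gives lda = K for an untransposed A and lda = M for a transposed one.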
Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+struct CBlas;
+
+template <>
+struct CBlas<float> {
+  static constexpr auto GEMM = cblas_sgemm;
+};
+
+template <>
+struct CBlas<double> {
+  static constexpr auto GEMM = cblas_dgemm;
+};
+
+template <>
+struct CBlas<platform::float16> {
+  static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
+};
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::GEMM(const CBLAS_TRANSPOSE transA,
+                                            const CBLAS_TRANSPOSE transB,
+                                            const int M, const int N,
+                                            const int K, const T alpha,
+                                            const T *A, const T *B,
+                                            const T beta, T *C) const {
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
+  int ldc = N;
+  CBlas<T>::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
+                 beta, C, ldc);
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::GEMM(
+    const bool transA, const bool transB, const int M, const int N,
+    const int K, const T alpha, const T *A, const int lda, const T *B,
+    const int ldb, const T beta, T *C, const int ldc) const {
+  CBlas<T>::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans,
+                 transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A,
+                 lda, B, ldb, beta, C, ldc);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
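
Taken together, the per-type CBlas/CUBlas traits let the single templated Blas&lt;DeviceContext&gt;::GEMM body resolve to cblas_sgemm/cblas_dgemm on CPU or cublasSgemm/cublasDgemm/cublasHgemm on GPU at compile time, replacing the per-type math::gemm specializations used before. A hypothetical post-refactor call site (variable names and shapes are illustrative only, assuming a prepared CPUDeviceContext dev_ctx and row-major float buffers a, b, c):

// Hypothetical usage sketch: c (M x N) = a (M x K) * b (K x N), row-major.
auto blas = math::GetBlas<platform::CPUDeviceContext, float>(dev_ctx);
blas.GEMM(CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a, b, 0.0f, c);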
