Skip to content

Commit ef6ea79

Browse files
committed
Clean and extract blas
1 parent d0785ce commit ef6ea79

22 files changed

+398
-400
lines changed

paddle/fluid/operators/bilinear_tensor_product_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ limitations under the License. */
1616

1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
19-
#include "paddle/fluid/operators/math/math_function.h"
19+
#include "paddle/fluid/operators/math/blas.h"
2020

2121
namespace paddle {
2222
namespace operators {

paddle/fluid/operators/conv_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ limitations under the License. */
1717
#include <vector>
1818
#include "paddle/fluid/framework/eigen.h"
1919
#include "paddle/fluid/framework/op_registry.h"
20+
#include "paddle/fluid/operators/math/blas.h"
2021
#include "paddle/fluid/operators/math/depthwise_conv.h"
2122
#include "paddle/fluid/operators/math/im2col.h"
22-
#include "paddle/fluid/operators/math/math_function.h"
2323
#include "paddle/fluid/operators/math/vol2col.h"
2424

2525
namespace paddle {

paddle/fluid/operators/conv_transpose_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ limitations under the License. */
1616
#include <vector>
1717
#include "paddle/fluid/framework/eigen.h"
1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/math/blas.h"
1920
#include "paddle/fluid/operators/math/im2col.h"
20-
#include "paddle/fluid/operators/math/math_function.h"
2121
#include "paddle/fluid/operators/math/vol2col.h"
2222

2323
namespace paddle {

paddle/fluid/operators/gru_unit_op.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@ limitations under the License. */
1414

1515
#pragma once
1616

17-
#include "paddle/fluid/operators/activation_op.h"
18-
#include "paddle/fluid/operators/math/math_function.h"
19-
2017
#include "paddle/fluid/framework/eigen.h"
2118
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/activation_op.h"
20+
#include "paddle/fluid/operators/math/blas.h"
2221

2322
namespace paddle {
2423
namespace operators {

paddle/fluid/operators/layer_norm_op.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ limitations under the License. */
1515
#pragma once
1616
#include "paddle/fluid/framework/eigen.h"
1717
#include "paddle/fluid/framework/op_registry.h"
18-
1918
#include "paddle/fluid/operators/elementwise_op_function.h"
19+
#include "paddle/fluid/operators/math/blas.h"
2020
#include "paddle/fluid/operators/math/math_function.h"
2121

2222
namespace paddle {
@@ -46,9 +46,9 @@ class RowwiseMean2D<platform::CUDADeviceContext, T> {
4646
}
4747
void operator()(const platform::CUDADeviceContext& context,
4848
const framework::Tensor& input, framework::Tensor* out) {
49-
math::gemv<platform::CUDADeviceContext, T>(
50-
context, false, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
51-
0., out->data<T>());
49+
math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
50+
false, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
51+
out->data<T>());
5252
}
5353

5454
private:
@@ -93,9 +93,9 @@ class ColwiseSum2D<platform::CUDADeviceContext, T> {
9393

9494
void operator()(const platform::CUDADeviceContext& context,
9595
const framework::Tensor& input, framework::Tensor* out) {
96-
math::gemv<platform::CUDADeviceContext, T>(
97-
context, true, left_, right_, 1., input.data<T>(), divisor_.data<T>(),
98-
0., out->data<T>());
96+
math::GetBlas<platform::CUDADeviceContext, T>(context).GEMV(
97+
true, left_, right_, 1., input.data<T>(), divisor_.data<T>(), 0.,
98+
out->data<T>());
9999
}
100100

101101
private:

paddle/fluid/operators/lstm_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ limitations under the License. */
1515
#pragma once
1616
#include <string>
1717
#include "paddle/fluid/framework/op_registry.h"
18+
#include "paddle/fluid/operators/math/blas.h"
1819
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1920
#include "paddle/fluid/operators/math/lstm_compute.h"
20-
#include "paddle/fluid/operators/math/math_function.h"
2121
#include "paddle/fluid/operators/math/sequence2batch.h"
2222

2323
namespace paddle {

paddle/fluid/operators/lstmp_op.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,14 @@ limitations under the License. */
1414

1515
#pragma once
1616
#include <string>
17+
#include "paddle/fluid/framework/eigen.h"
18+
#include "paddle/fluid/framework/op_registry.h"
1719
#include "paddle/fluid/operators/activation_op.h"
20+
#include "paddle/fluid/operators/math/blas.h"
1821
#include "paddle/fluid/operators/math/detail/activation_functions.h"
1922
#include "paddle/fluid/operators/math/lstm_compute.h"
20-
#include "paddle/fluid/operators/math/math_function.h"
2123
#include "paddle/fluid/operators/math/sequence2batch.h"
2224

23-
#include "paddle/fluid/framework/eigen.h"
24-
#include "paddle/fluid/framework/op_registry.h"
25-
2625
namespace paddle {
2726
namespace operators {
2827

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ math_library(depthwise_conv)
4141
math_library(gru_compute DEPS activation_functions math_function)
4242
math_library(im2col)
4343
math_library(lstm_compute DEPS activation_functions)
44-
math_library(math_function DEPS cblas)
44+
cc_library(blas SRCS blas.cc DEPS cblas framework_proto)
45+
math_library(math_function DEPS blas)
4546
math_library(maxouting)
4647
math_library(pooling)
4748
math_library(selected_rows_functor DEPS selected_rows math_function)

paddle/fluid/operators/math/blas.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/math/blas.h"
16+
namespace paddle {
17+
namespace operators {
18+
namespace math {
19+
// Do nothing. Blas is a header only library.
20+
} // namespace math
21+
} // namespace operators
22+
} // namespace paddle

paddle/fluid/operators/math/blas.h

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "paddle/fluid/framework/operator.h"
18+
#include "paddle/fluid/framework/tensor.h"
19+
20+
#ifdef PADDLE_WITH_MKLML
21+
#include <mkl_cblas.h>
22+
#include <mkl_lapacke.h>
23+
#include <mkl_vml_functions.h>
24+
#endif
25+
26+
#ifdef PADDLE_USE_OPENBLAS
27+
#include <cblas.h>
28+
#include <lapacke.h>
29+
#endif
30+
31+
#ifndef LAPACK_FOUND
32+
extern "C" {
33+
#include <cblas.h> // NOLINT
34+
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
35+
int* ipiv);
36+
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
37+
int* ipiv);
38+
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
39+
const int* ipiv);
40+
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
41+
const int* ipiv);
42+
}
43+
#endif
44+
45+
namespace paddle {
46+
namespace operators {
47+
namespace math {
48+
49+
template <typename DeviceContext>
50+
class Blas {
51+
public:
52+
explicit Blas(const DeviceContext& context) : context_(context) {}
53+
54+
template <typename T>
55+
void GEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N, int K,
56+
T alpha, const T* A, const T* B, T beta, T* C) const;
57+
58+
template <typename T>
59+
void GEMM(bool transA, bool transB, int M, int N, int K, T alpha, const T* A,
60+
int lda, const T* B, int ldb, T beta, T* C, int ldc) const;
61+
62+
template <typename T>
63+
void MatMul(const framework::Tensor& mat_a, bool trans_a,
64+
const framework::Tensor& mat_b, bool trans_b, T alpha,
65+
framework::Tensor* mat_out, T beta) const;
66+
67+
template <typename T>
68+
void MatMul(const framework::Tensor& mat_a, bool trans_a,
69+
const framework::Tensor& mat_b, bool trans_b,
70+
framework::Tensor* mat_out) const {
71+
MatMul(mat_a, trans_a, mat_b, trans_b, static_cast<T>(1.0), mat_out,
72+
static_cast<T>(0.0));
73+
}
74+
75+
template <typename T>
76+
void MatMul(const framework::Tensor& mat_a, const framework::Tensor& mat_b,
77+
framework::Tensor* mat_out) const {
78+
this->template MatMul<T>(mat_a, false, mat_b, false, mat_out);
79+
}
80+
81+
template <typename T>
82+
void AXPY(int n, T alpha, const T* x, T* y) const;
83+
84+
template <typename T>
85+
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
86+
T* C) const;
87+
88+
template <typename T>
89+
void BatchedGEMM(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB, int M, int N,
90+
int K, T alpha, const T* A, const T* B, T beta, T* C,
91+
int batchCount, int64_t strideA, int64_t strideB) const;
92+
93+
private:
94+
const DeviceContext& context_;
95+
};
96+
97+
template <typename DeviceContext, typename T>
98+
class BlasT : private Blas<DeviceContext> {
99+
public:
100+
using Blas<DeviceContext>::Blas;
101+
102+
template <typename... ARGS>
103+
void GEMM(ARGS... args) const {
104+
Base()->template GEMM<T>(args...);
105+
}
106+
107+
template <typename... ARGS>
108+
void MatMul(ARGS... args) const {
109+
Base()->template MatMul<T>(args...);
110+
}
111+
112+
template <typename... ARGS>
113+
void AXPY(ARGS... args) const {
114+
Base()->template AXPY<T>(args...);
115+
}
116+
117+
template <typename... ARGS>
118+
void GEMV(ARGS... args) const {
119+
Base()->template GEMV<T>(args...);
120+
}
121+
122+
template <typename... ARGS>
123+
void BatchedGEMM(ARGS... args) const {
124+
Base()->template BatchedGEMM<T>(args...);
125+
}
126+
127+
private:
128+
const Blas<DeviceContext>* Base() const {
129+
return static_cast<const Blas<DeviceContext>*>(this);
130+
}
131+
};
132+
133+
template <typename DeviceContext, typename T>
134+
inline BlasT<DeviceContext, T> GetBlas(
135+
const framework::ExecutionContext& exe_ctx) {
136+
return BlasT<DeviceContext, T>(
137+
exe_ctx.template device_context<DeviceContext>());
138+
}
139+
140+
template <typename DeviceContext, typename T>
141+
inline BlasT<DeviceContext, T> GetBlas(const DeviceContext& dev_ctx) {
142+
return BlasT<DeviceContext, T>(dev_ctx);
143+
}
144+
145+
} // namespace math
146+
} // namespace operators
147+
} // namespace paddle
148+
149+
#include "paddle/fluid/operators/math/blas_impl.h"
150+
#ifdef PADDLE_WITH_CUDA
151+
#include "paddle/fluid/operators/math/blas_impl.cu.h"
152+
#endif

0 commit comments

Comments (0)