Skip to content

Commit 1be85d0

Browse files
committed
add mkl vsqr and vpow
1 parent 38f499d commit 1be85d0

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

paddle/fluid/operators/math/blas.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ class Blas {
152152
template <typename T>
153153
void VEXP(int n, const T* x, T* y) const;
154154

155+
// Element-wise vector routine over n values: reads from x, writes to y.
// NOTE(review): the MKL-backed CPU path in this change forwards to
// vsSqr/vdSqr, which Intel MKL documents as squaring each element
// (y[i] = x[i]^2), not taking a square root -- confirm callers expect a
// square.
template <typename T>
156+
void VSQR(int n, const T* x, T* y) const;
157+
158+
// Element-wise power with a scalar exponent: for n values, y[i] receives
// x[i] raised to alpha. The CPU implementation in this change uses std::pow
// when MKLML is not compiled in and forwards to vsPowx/vdPowx otherwise.
template <typename T>
159+
void VPOW(int n, const T* x, T alpha, T* y) const;
160+
155161
template <typename T>
156162
void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
157163
T* C) const;
@@ -238,6 +244,16 @@ class BlasT : private Blas<DeviceContext> {
238244
Base()->template VEXP<T>(args...);
239245
}
240246

247+
template <typename... ARGS>
248+
void VSQR(ARGS... args) const {
249+
Base()->template VSQR<T>(args...);
250+
}
251+
252+
template <typename... ARGS>
253+
void VPOW(ARGS... args) const {
254+
Base()->template VPOW<T>(args...);
255+
}
256+
241257
template <typename... ARGS>
242258
void GEMV(ARGS... args) const {
243259
Base()->template GEMV<T>(args...);

paddle/fluid/operators/math/blas_impl.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414
#pragma once
15+
#include <cmath>
1516
#include <limits>
1617
#include <vector>
1718
#include "paddle/fluid/operators/math/math_function.h"
@@ -102,6 +103,16 @@ struct CBlas<float> {
102103
static void VEXP(ARGS... args) {
103104
platform::dynload::vsExp(args...);
104105
}
106+
107+
// Single-precision VSQR: hands its arguments (taken by value) to the
// dynamically loaded MKL symbol vsSqr. Blas<CPUDeviceContext>::VSQR calls
// this as (n, x, y).
// NOTE(review): Intel MKL documents vsSqr as element-wise squaring, not
// square root -- confirm that is the intended meaning of "VSQR".
template <typename... ARGS>
108+
static void VSQR(ARGS... args) {
109+
platform::dynload::vsSqr(args...);
110+
}
111+
112+
// Single-precision VPOW: hands its arguments (taken by value) to the
// dynamically loaded MKL symbol vsPowx. Blas<CPUDeviceContext>::VPOW calls
// this as (n, x, a, y), i.e. input vector, scalar exponent, output vector.
template <typename... ARGS>
113+
static void VPOW(ARGS... args) {
114+
platform::dynload::vsPowx(args...);
115+
}
105116
};
106117

107118
template <>
@@ -182,6 +193,16 @@ struct CBlas<double> {
182193
static void VEXP(ARGS... args) {
183194
platform::dynload::vdExp(args...);
184195
}
196+
197+
// Double-precision VSQR: hands its arguments (taken by value) to the
// dynamically loaded MKL symbol vdSqr. Blas<CPUDeviceContext>::VSQR calls
// this as (n, x, y).
// NOTE(review): Intel MKL documents vdSqr as element-wise squaring, not
// square root -- confirm that is the intended meaning of "VSQR".
template <typename... ARGS>
198+
static void VSQR(ARGS... args) {
199+
platform::dynload::vdSqr(args...);
200+
}
201+
202+
// Double-precision VPOW: hands its arguments (taken by value) to the
// dynamically loaded MKL symbol vdPowx. Blas<CPUDeviceContext>::VPOW calls
// this as (n, x, a, y), i.e. input vector, scalar exponent, output vector.
template <typename... ARGS>
203+
static void VPOW(ARGS... args) {
204+
platform::dynload::vdPowx(args...);
205+
}
185206
};
186207

187208
#else
@@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
241262
}
242263
// Placeholder entry points for float16 on CPU. Each stub accepts any
// argument list through a C-style variadic parameter and does nothing but
// invoke PADDLE_THROW with a message naming the unsupported routine, so a
// float16 call that reaches one of them fails with that message.
static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
243264
static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
265+
static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
266+
static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
244267
static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
245268
static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
246269
#ifdef PADDLE_WITH_MKLML
@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
398421
#endif
399422
}
400423

424+
template <>
425+
template <typename T>
426+
void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
427+
#ifdef PADDLE_WITH_MKLML
428+
CBlas<T>::VSQR(n, x, y);
429+
#else
430+
for (int i = 0; i < n; ++i) {
431+
y[i] = std::sqrt(x[i]);
432+
}
433+
#endif
434+
}
435+
436+
template <>
437+
template <typename T>
438+
void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
439+
T *y) const {
440+
#ifdef PADDLE_WITH_MKLML
441+
CBlas<T>::VPOW(n, x, a, y);
442+
#else
443+
for (int i = 0; i < n; ++i) {
444+
y[i] = std::pow(x[i], a);
445+
}
446+
#endif
447+
}
448+
401449
template <>
402450
template <typename T>
403451
T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {

paddle/fluid/platform/dynload/mklml.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ extern void* mklml_dso_handle;
7676
__macro(vdMul); \
7777
__macro(vsExp); \
7878
__macro(vdExp); \
79+
__macro(vsSqr); \
80+
__macro(vdSqr); \
81+
__macro(vsPowx); \
82+
__macro(vdPowx); \
7983
__macro(MKL_Set_Num_Threads)
8084

8185
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);

0 commit comments

Comments
 (0)