|
12 | 12 | // See the License for the specific language governing permissions and
|
13 | 13 | // limitations under the License.
|
14 | 14 | #pragma once
|
| 15 | +#include <cmath> |
15 | 16 | #include <limits>
|
16 | 17 | #include <vector>
|
17 | 18 | #include "paddle/fluid/operators/math/math_function.h"
|
@@ -102,6 +103,16 @@ struct CBlas<float> {
|
102 | 103 | static void VEXP(ARGS... args) {
|
103 | 104 | platform::dynload::vsExp(args...);
|
104 | 105 | }
|
| 106 | + |
| 107 | + template <typename... ARGS> |
| 108 | + static void VSQR(ARGS... args) { |
| 109 | + platform::dynload::vsSqr(args...); |
| 110 | + } |
| 111 | + |
| 112 | + template <typename... ARGS> |
| 113 | + static void VPOW(ARGS... args) { |
| 114 | + platform::dynload::vsPowx(args...); |
| 115 | + } |
105 | 116 | };
|
106 | 117 |
|
107 | 118 | template <>
|
@@ -182,6 +193,16 @@ struct CBlas<double> {
|
182 | 193 | static void VEXP(ARGS... args) {
|
183 | 194 | platform::dynload::vdExp(args...);
|
184 | 195 | }
|
| 196 | + |
| 197 | + template <typename... ARGS> |
| 198 | + static void VSQR(ARGS... args) { |
| 199 | + platform::dynload::vdSqr(args...); |
| 200 | + } |
| 201 | + |
| 202 | + template <typename... ARGS> |
| 203 | + static void VPOW(ARGS... args) { |
| 204 | + platform::dynload::vdPowx(args...); |
| 205 | + } |
185 | 206 | };
|
186 | 207 |
|
187 | 208 | #else
|
@@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
|
241 | 262 | }
|
242 | 263 | static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
|
243 | 264 | static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
|
| 265 | + static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); } |
| 266 | + static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); } |
244 | 267 | static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
|
245 | 268 | static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
|
246 | 269 | #ifdef PADDLE_WITH_MKLML
|
@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
|
398 | 421 | #endif
|
399 | 422 | }
|
400 | 423 |
|
| 424 | +template <> |
| 425 | +template <typename T> |
| 426 | +void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const { |
| 427 | +#ifdef PADDLE_WITH_MKLML |
| 428 | + CBlas<T>::VSQR(n, x, y); |
| 429 | +#else |
| 430 | + for (int i = 0; i < n; ++i) { |
| 431 | + y[i] = std::sqrt(x[i]); |
| 432 | + } |
| 433 | +#endif |
| 434 | +} |
| 435 | + |
| 436 | +template <> |
| 437 | +template <typename T> |
| 438 | +void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a, |
| 439 | + T *y) const { |
| 440 | +#ifdef PADDLE_WITH_MKLML |
| 441 | + CBlas<T>::VPOW(n, x, a, y); |
| 442 | +#else |
| 443 | + for (int i = 0; i < n; ++i) { |
| 444 | + y[i] = std::pow(x[i], a); |
| 445 | + } |
| 446 | +#endif |
| 447 | +} |
| 448 | + |
401 | 449 | template <>
|
402 | 450 | template <typename T>
|
403 | 451 | T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
|
|
0 commit comments