Skip to content

Commit 25aa453

Browse files
authored
Merge pull request #10934 from tensor-tang/mklml_funcs
speedup vInvSqrt vLogqp vTanh with mklml
2 parents 9d723b8 + afbc4ce commit 25aa453

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

paddle/math/MathFunctions.cpp

Lines changed: 40 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "MathFunctions.h"
15+
#include "paddle/math/MathFunctions.h"
1616
#include "hl_matrix_apply.cuh"
1717
#include "hl_matrix_ops.cuh"
1818
#include "paddle/utils/DynamicLoader.h"
@@ -240,6 +240,36 @@ template <>
240240
void vAdd<double>(const int n, const double* a, const double* b, double* r) {
241241
vdAdd(n, a, b, r);
242242
}
243+
244+
template <>
245+
void vTanh<float>(const int n, const float* a, float* r) {
246+
vsTanh(n, a, r);
247+
}
248+
249+
template <>
250+
void vTanh<double>(const int n, const double* a, double* r) {
251+
vdTanh(n, a, r);
252+
}
253+
254+
template <>
255+
void vInvSqrt<float>(const int n, const float* a, float* r) {
256+
vsInvSqrt(n, a, r);
257+
}
258+
259+
template <>
260+
void vInvSqrt<double>(const int n, const double* a, double* r) {
261+
vdInvSqrt(n, a, r);
262+
}
263+
264+
template <>
265+
void vLog1p<float>(const int n, const float* a, float* r) {
266+
vsLog1p(n, a, r);
267+
}
268+
269+
template <>
270+
void vLog1p<double>(const int n, const double* a, double* r) {
271+
vdLog1p(n, a, r);
272+
}
243273
#else
244274

245275
DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a));
@@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
277307
n);
278308
}
279309

280-
template void vExp(const int n, const float* a, float* r);
281-
template void vExp(const int n, const double* a, double* r);
282-
template void vLog(const int n, const float* a, float* r);
283-
template void vLog(const int n, const double* a, double* r);
284-
template void vPow(const int n, const float* a, const float b, float* r);
285-
template void vPow(const int n, const double* a, const double b, double* r);
286-
template void vAdd(const int n, const float* a, const float* b, float* r);
287-
template void vAdd(const int n, const double* a, const double* b, double* r);
288-
289-
#endif
290-
291310
DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
292311
template <class T>
293312
void vInvSqrt(const int n, const T* a, T* r) {
@@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) {
311330
binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
312331
}
313332

333+
template void vExp(const int n, const float* a, float* r);
334+
template void vExp(const int n, const double* a, double* r);
335+
template void vLog(const int n, const float* a, float* r);
336+
template void vLog(const int n, const double* a, double* r);
337+
template void vPow(const int n, const float* a, const float b, float* r);
338+
template void vPow(const int n, const double* a, const double b, double* r);
339+
template void vAdd(const int n, const float* a, const float* b, float* r);
340+
template void vAdd(const int n, const double* a, const double* b, double* r);
314341
template void vInvSqrt(const int n, const double* a, double* r);
315342
template void vInvSqrt(const int n, const float* a, float* r);
316343
template void vLog1p(const int n, const float* a, float* r);
317344
template void vLog1p(const int n, const double* a, double* r);
318345
template void vTanh(const int n, const float* a, float* r);
319346
template void vTanh(const int n, const double* a, double* r);
320-
347+
#endif
321348
} // namespace paddle

0 commit comments

Comments
 (0)