Skip to content

Commit 537f57a

Browse files
committed
fix undefine error on gpu
1 parent 315e08e commit 537f57a

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

paddle/math/MathFunctions.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,29 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
307307
n);
308308
}
309309

310+
DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
311+
template <class T>
312+
void vInvSqrt(const int n, const T* a, T* r) {
313+
hl_cpu_apply_binary_op<T, binary::vInvSqrt<T>, 0, 0>(
314+
binary::vInvSqrt<T>(), const_cast<T*>(a), r, 1, n, n, n);
315+
}
316+
317+
DEFINE_MATRIX_BINARY_OP(vLog1p, b = std::log(1.0f + a));
318+
template <class T>
319+
void vLog1p(const int n, const T* a, T* r) {
320+
hl_cpu_apply_binary_op<T, binary::vLog1p<T>, 0, 0>(
321+
binary::vLog1p<T>(), const_cast<T*>(a), r, 1, n, n, n);
322+
}
323+
324+
DEFINE_MATRIX_BINARY_OP(vTanh, T tmp = -2.0 * a;
325+
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
326+
b = 2.0 / (1.0 + std::exp(tmp)) - 1.0);
327+
template <class T>
328+
void vTanh(const int n, const T* a, T* r) {
329+
hl_cpu_apply_binary_op<T, binary::vTanh<T>, 0, 0>(
330+
binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
331+
}
332+
310333
template void vExp(const int n, const float* a, float* r);
311334
template void vExp(const int n, const double* a, double* r);
312335
template void vLog(const int n, const float* a, float* r);
@@ -315,6 +338,11 @@ template void vPow(const int n, const float* a, const float b, float* r);
315338
template void vPow(const int n, const double* a, const double b, double* r);
316339
template void vAdd(const int n, const float* a, const float* b, float* r);
317340
template void vAdd(const int n, const double* a, const double* b, double* r);
318-
341+
template void vInvSqrt(const int n, const double* a, double* r);
342+
template void vInvSqrt(const int n, const float* a, float* r);
343+
template void vLog1p(const int n, const float* a, float* r);
344+
template void vLog1p(const int n, const double* a, double* r);
345+
template void vTanh(const int n, const float* a, float* r);
346+
template void vTanh(const int n, const double* a, double* r);
319347
#endif
320348
} // namespace paddle

0 commit comments

Comments
 (0)