@@ -307,6 +307,29 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
307
307
n);
308
308
}
309
309
310
+ DEFINE_MATRIX_BINARY_OP (vInvSqrt, b = 1 .0f / std::sqrt(a));
311
+ template <class T >
312
+ void vInvSqrt (const int n, const T* a, T* r) {
313
+ hl_cpu_apply_binary_op<T, binary::vInvSqrt<T>, 0 , 0 >(
314
+ binary::vInvSqrt<T>(), const_cast <T*>(a), r, 1 , n, n, n);
315
+ }
316
+
317
+ DEFINE_MATRIX_BINARY_OP (vLog1p, b = std::log(1 .0f + a));
318
+ template <class T >
319
+ void vLog1p (const int n, const T* a, T* r) {
320
+ hl_cpu_apply_binary_op<T, binary::vLog1p<T>, 0 , 0 >(
321
+ binary::vLog1p<T>(), const_cast <T*>(a), r, 1 , n, n, n);
322
+ }
323
+
324
+ DEFINE_MATRIX_BINARY_OP (vTanh, T tmp = -2.0 * a;
325
+ tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
326
+ b = 2.0 / (1.0 + std::exp(tmp)) - 1.0 );
327
+ template <class T >
328
+ void vTanh (const int n, const T* a, T* r) {
329
+ hl_cpu_apply_binary_op<T, binary::vTanh<T>, 0 , 0 >(
330
+ binary::vTanh<T>(), const_cast <T*>(a), r, 1 , n, n, n);
331
+ }
332
+
310
333
template void vExp (const int n, const float * a, float * r);
311
334
template void vExp (const int n, const double * a, double * r);
312
335
template void vLog (const int n, const float * a, float * r);
@@ -315,6 +338,11 @@ template void vPow(const int n, const float* a, const float b, float* r);
315
338
template void vPow (const int n, const double * a, const double b, double * r);
316
339
template void vAdd (const int n, const float * a, const float * b, float * r);
317
340
template void vAdd (const int n, const double * a, const double * b, double * r);
318
-
341
+ template void vInvSqrt (const int n, const double * a, double * r);
342
+ template void vInvSqrt (const int n, const float * a, float * r);
343
+ template void vLog1p (const int n, const float * a, float * r);
344
+ template void vLog1p (const int n, const double * a, double * r);
345
+ template void vTanh (const int n, const float * a, float * r);
346
+ template void vTanh (const int n, const double * a, double * r);
319
347
#endif
320
348
} // namespace paddle
0 commit comments