@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " MathFunctions.h"
15
+ #include " paddle/math/ MathFunctions.h"
16
16
#include " hl_matrix_apply.cuh"
17
17
#include " hl_matrix_ops.cuh"
18
18
#include " paddle/utils/DynamicLoader.h"
@@ -240,6 +240,36 @@ template <>
240
240
void vAdd<double >(const int n, const double * a, const double * b, double * r) {
241
241
vdAdd (n, a, b, r);
242
242
}
243
+
244
+ template <>
245
+ void vTanh<float >(const int n, const float * a, float * r) {
246
+ vsTanh (n, a, r);
247
+ }
248
+
249
+ template <>
250
+ void vTanh<double >(const int n, const double * a, double * r) {
251
+ vdTanh (n, a, r);
252
+ }
253
+
254
+ template <>
255
+ void vInvSqrt<float >(const int n, const float * a, float * r) {
256
+ vsInvSqrt (n, a, r);
257
+ }
258
+
259
+ template <>
260
+ void vInvSqrt<double >(const int n, const double * a, double * r) {
261
+ vdInvSqrt (n, a, r);
262
+ }
263
+
264
+ template <>
265
+ void vLog1p<float >(const int n, const float * a, float * r) {
266
+ vsLog1p (n, a, r);
267
+ }
268
+
269
+ template <>
270
+ void vLog1p<double >(const int n, const double * a, double * r) {
271
+ vdLog1p (n, a, r);
272
+ }
243
273
#else
244
274
245
275
DEFINE_MATRIX_BINARY_OP (vExp, b = std::exp(a));
@@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
277
307
n);
278
308
}
279
309
280
- template void vExp (const int n, const float * a, float * r);
281
- template void vExp (const int n, const double * a, double * r);
282
- template void vLog (const int n, const float * a, float * r);
283
- template void vLog (const int n, const double * a, double * r);
284
- template void vPow (const int n, const float * a, const float b, float * r);
285
- template void vPow (const int n, const double * a, const double b, double * r);
286
- template void vAdd (const int n, const float * a, const float * b, float * r);
287
- template void vAdd (const int n, const double * a, const double * b, double * r);
288
-
289
- #endif
290
-
291
310
DEFINE_MATRIX_BINARY_OP (vInvSqrt, b = 1 .0f / std::sqrt(a));
292
311
template <class T >
293
312
void vInvSqrt (const int n, const T* a, T* r) {
@@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) {
311
330
binary::vTanh<T>(), const_cast <T*>(a), r, 1 , n, n, n);
312
331
}
313
332
333
+ template void vExp (const int n, const float * a, float * r);
334
+ template void vExp (const int n, const double * a, double * r);
335
+ template void vLog (const int n, const float * a, float * r);
336
+ template void vLog (const int n, const double * a, double * r);
337
+ template void vPow (const int n, const float * a, const float b, float * r);
338
+ template void vPow (const int n, const double * a, const double b, double * r);
339
+ template void vAdd (const int n, const float * a, const float * b, float * r);
340
+ template void vAdd (const int n, const double * a, const double * b, double * r);
314
341
template void vInvSqrt (const int n, const double * a, double * r);
315
342
template void vInvSqrt (const int n, const float * a, float * r);
316
343
template void vLog1p (const int n, const float * a, float * r);
317
344
template void vLog1p (const int n, const double * a, double * r);
318
345
template void vTanh (const int n, const float * a, float * r);
319
346
template void vTanh (const int n, const double * a, double * r);
320
-
347
+ # endif
321
348
} // namespace paddle
0 commit comments