@@ -378,17 +378,6 @@ def L_op(self, inputs, outputs, grads):
378
378
379
379
def c_support_code (self , ** kwargs ):
380
380
return """
381
- // For GPU support
382
- #ifdef WITHIN_KERNEL
383
- #define DEVICE WITHIN_KERNEL
384
- #else
385
- #define DEVICE
386
- #endif
387
-
388
- #ifndef ga_double
389
- #define ga_double double
390
- #endif
391
-
392
381
#ifndef M_PI
393
382
#define M_PI 3.14159265358979323846
394
383
#endif
@@ -397,51 +386,48 @@ def c_support_code(self, **kwargs):
397
386
#define _PSIFUNCDEFINED
398
387
DEVICE double _psi(ga_double x) {
399
388
400
- /*taken from
401
- Bernardo, J. M. (1976). Algorithm AS 103:
402
- Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
403
- http://www.uv.es/~bernardo/1976AppStatist.pdf
404
- */
405
-
406
- ga_double y, R, psi_ = 0;
407
- ga_double S = 1.0e-5;
408
- ga_double C = 8.5;
409
- ga_double S3 = 8.333333333e-2;
410
- ga_double S4 = 8.333333333e-3;
411
- ga_double S5 = 3.968253968e-3;
412
- ga_double D1 = -0.5772156649;
413
-
414
- if (x <= 0) {
415
- // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416
- if (x == floor(x)) {
417
- return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418
- }
389
+ /*taken from
390
+ Bernardo, J. M. (1976). Algorithm AS 103:
391
+ Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
392
+ http://www.uv.es/~bernardo/1976AppStatist.pdf
393
+ */
419
394
420
- // Use reflection formula
421
- ga_double pi_x = M_PI * x;
422
- ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423
- return _psi(1.0 - x) + M_PI * cot_pi_x;
424
- }
395
+ double y, R, psi_ = 0;
396
+ double S = 1.0e-5;
397
+ double C = 8.5;
398
+ double S3 = 8.333333333e-2;
399
+ double S4 = 8.333333333e-3;
400
+ double S5 = 3.968253968e-3;
401
+ double D1 = -0.5772156649;
425
402
426
- y = x;
403
+ if (x <= 0) {
404
+ // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
405
+ if (x == floor(x)) {
406
+ return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
407
+ }
408
+
409
+ // Use reflection formula
410
+ ga_double pi_x = M_PI * x;
411
+ ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
412
+ return _psi(1.0 - x) - M_PI * cot_pi_x;
413
+ }
427
414
428
- if (y <= 0)
429
- return psi_;
415
+ y = x;
430
416
431
- if (y <= S)
432
- return D1 - 1.0/y;
417
+ if (y <= S)
418
+ return D1 - 1.0/y;
433
419
434
- while (y < C) {
435
- psi_ = psi_ - 1.0 / y;
436
- y = y + 1;
437
- }
420
+ while (y < C) {
421
+ psi_ = psi_ - 1.0 / y;
422
+ y = y + 1;
423
+ }
438
424
439
- R = 1.0 / y;
440
- psi_ = psi_ + log(y) - .5 * R ;
441
- R= R*R;
442
- psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
425
+ R = 1.0 / y;
426
+ psi_ = psi_ + log(y) - .5 * R ;
427
+ R= R*R;
428
+ psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
443
429
444
- return psi_;
430
+ return psi_;
445
431
}
446
432
#endif
447
433
"""
@@ -450,8 +436,8 @@ def c_code(self, node, name, inp, out, sub):
450
436
(x ,) = inp
451
437
(z ,) = out
452
438
if node .inputs [0 ].type in float_types :
453
- return f""" { z } =
454
- _psi({ x } );"" "
439
+ dtype = "npy_" + node . outputs [ 0 ]. dtype
440
+ return f" { z } = ( { dtype } ) _psi({ x } );"
455
441
raise NotImplementedError ("only floating point is implemented" )
456
442
457
443
0 commit comments