@@ -378,17 +378,6 @@ def L_op(self, inputs, outputs, grads):
378378
379379 def c_support_code (self , ** kwargs ):
380380 return """
381- // For GPU support
382- #ifdef WITHIN_KERNEL
383- #define DEVICE WITHIN_KERNEL
384- #else
385- #define DEVICE
386- #endif
387-
388- #ifndef ga_double
389- #define ga_double double
390- #endif
391-
392381 #ifndef M_PI
393382 #define M_PI 3.14159265358979323846
394383 #endif
@@ -397,51 +386,48 @@ def c_support_code(self, **kwargs):
397386 #define _PSIFUNCDEFINED
398387 DEVICE double _psi(ga_double x) {
399388
400- /*taken from
401- Bernardo, J. M. (1976). Algorithm AS 103:
402- Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
403- http://www.uv.es/~bernardo/1976AppStatist.pdf
404- */
405-
406- ga_double y, R, psi_ = 0;
407- ga_double S = 1.0e-5;
408- ga_double C = 8.5;
409- ga_double S3 = 8.333333333e-2;
410- ga_double S4 = 8.333333333e-3;
411- ga_double S5 = 3.968253968e-3;
412- ga_double D1 = -0.5772156649;
413-
414- if (x <= 0) {
415- // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416- if (x == floor(x)) {
417- return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418- }
389+ /*taken from
390+ Bernardo, J. M. (1976). Algorithm AS 103:
391+ Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
392+ http://www.uv.es/~bernardo/1976AppStatist.pdf
393+ */
419394
420- // Use reflection formula
421- ga_double pi_x = M_PI * x;
422- ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423- return _psi(1.0 - x) + M_PI * cot_pi_x;
424- }
395+ double y, R, psi_ = 0;
396+ double S = 1.0e-5;
397+ double C = 8.5;
398+ double S3 = 8.333333333e-2;
399+ double S4 = 8.333333333e-3;
400+ double S5 = 3.968253968e-3;
401+ double D1 = -0.5772156649;
425402
426- y = x;
403+ if (x <= 0) {
404+ // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
405+ if (x == floor(x)) {
406+ return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
407+ }
408+
409+ // Use reflection formula
410+ ga_double pi_x = M_PI * x;
411+ ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
412+ return _psi(1.0 - x) - M_PI * cot_pi_x;
413+ }
427414
428- if (y <= 0)
429- return psi_;
415+ y = x;
430416
431- if (y <= S)
432- return D1 - 1.0/y;
417+ if (y <= S)
418+ return D1 - 1.0/y;
433419
434- while (y < C) {
435- psi_ = psi_ - 1.0 / y;
436- y = y + 1;
437- }
420+ while (y < C) {
421+ psi_ = psi_ - 1.0 / y;
422+ y = y + 1;
423+ }
438424
439- R = 1.0 / y;
440- psi_ = psi_ + log(y) - .5 * R ;
441- R= R*R;
442- psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
425+ R = 1.0 / y;
426+ psi_ = psi_ + log(y) - .5 * R ;
427+ R= R*R;
428+ psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
443429
444- return psi_;
430+ return psi_;
445431 }
446432 #endif
447433 """
@@ -450,8 +436,8 @@ def c_code(self, node, name, inp, out, sub):
450436 (x ,) = inp
451437 (z ,) = out
452438 if node .inputs [0 ].type in float_types :
453- return f""" { z } =
454- _psi({ x } );"" "
439+ dtype = "npy_" + node . outputs [ 0 ]. dtype
440+ return f" { z } = ( { dtype } ) _psi({ x } );"
455441 raise NotImplementedError ("only floating point is implemented" )
456442
457443
0 commit comments