File tree Expand file tree Collapse file tree 2 files changed +47
-3
lines changed Expand file tree Collapse file tree 2 files changed +47
-3
lines changed Original file line number Diff line number Diff line change @@ -389,14 +389,19 @@ def c_support_code(self, **kwargs):
389
389
#define ga_double double
390
390
#endif
391
391
392
+ #ifndef M_PI
393
+ #define M_PI 3.14159265358979323846
394
+ #endif
395
+
392
396
#ifndef _PSIFUNCDEFINED
393
397
#define _PSIFUNCDEFINED
394
398
DEVICE double _psi(ga_double x) {
395
399
396
400
/*taken from
397
401
Bernardo, J. M. (1976). Algorithm AS 103:
398
402
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
399
- http://www.uv.es/~bernardo/1976AppStatist.pdf */
403
+ http://www.uv.es/~bernardo/1976AppStatist.pdf
404
+ */
400
405
401
406
ga_double y, R, psi_ = 0;
402
407
ga_double S = 1.0e-5;
@@ -406,10 +411,22 @@ def c_support_code(self, **kwargs):
406
411
ga_double S5 = 3.968253968e-3;
407
412
ga_double D1 = -0.5772156649;
408
413
414
+ if (x <= 0) {
415
+ // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416
+ if (x == floor(x)) {
417
+ return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418
+ }
419
+
420
+ // Use reflection formula
421
+ ga_double pi_x = M_PI * x;
422
+ ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423
+ return _psi(1.0 - x) + M_PI * cot_pi_x;
424
+ }
425
+
409
426
y = x;
410
427
411
- if (y <= 0.0 )
412
- return psi_;
428
+ if (y <= 0)
429
+ return psi_;
413
430
414
431
if (y <= S)
415
432
return D1 - 1.0/y;
Original file line number Diff line number Diff line change 2
2
3
3
import numpy as np
4
4
import pytest
5
+ import scipy
5
6
import scipy .special as sp
6
7
7
8
import pytensor .tensor as pt
19
20
gammal ,
20
21
gammau ,
21
22
hyp2f1 ,
23
+ psi ,
22
24
)
23
25
from tests .link .test_link import make_function
24
26
@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
149
151
(var .owner and isinstance (var .owner .op , ScalarLoop ))
150
152
for var in ancestors (grad )
151
153
)
154
+
155
+
156
+ @pytest .mark .parametrize (
157
+ "linker" ,
158
+ ["py" , "c" ],
159
+ )
160
+ def test_psi (linker ):
161
+ x = float64 ("x" )
162
+ out = psi (x )
163
+
164
+ fn = function ([x ], out , mode = Mode (linker = linker , optimizer = "fast_run" ))
165
+ fn .dprint ()
166
+
167
+ x_test = np .float64 (0.5 )
168
+
169
+ np .testing .assert_allclose (
170
+ fn (x_test ),
171
+ scipy .special .psi (x_test ),
172
+ strict = True ,
173
+ )
174
+ np .testing .assert_allclose (
175
+ fn (- x_test ),
176
+ scipy .special .psi (- x_test ),
177
+ strict = True ,
178
+ )
You can’t perform that action at this time.
0 commit comments