File tree Expand file tree Collapse file tree 2 files changed +47
-3
lines changed Expand file tree Collapse file tree 2 files changed +47
-3
lines changed Original file line number Diff line number Diff line change @@ -389,14 +389,19 @@ def c_support_code(self, **kwargs):
389389 #define ga_double double
390390 #endif
391391
392+ #ifndef M_PI
393+ #define M_PI 3.14159265358979323846
394+ #endif
395+
392396 #ifndef _PSIFUNCDEFINED
393397 #define _PSIFUNCDEFINED
394398 DEVICE double _psi(ga_double x) {
395399
396400 /*taken from
397401 Bernardo, J. M. (1976). Algorithm AS 103:
398402 Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
399- http://www.uv.es/~bernardo/1976AppStatist.pdf */
403+ http://www.uv.es/~bernardo/1976AppStatist.pdf
404+ */
400405
401406 ga_double y, R, psi_ = 0;
402407 ga_double S = 1.0e-5;
@@ -406,10 +411,22 @@ def c_support_code(self, **kwargs):
406411 ga_double S5 = 3.968253968e-3;
407412 ga_double D1 = -0.5772156649;
408413
414+ if (x <= 0) {
415+ // the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416+ if (x == floor(x)) {
417+ return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418+ }
419+
420+ // Use reflection formula
421+ ga_double pi_x = M_PI * x;
422+ ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423+ return _psi(1.0 - x) + M_PI * cot_pi_x;
424+ }
425+
409426 y = x;
410427
411- if (y <= 0.0 )
412- return psi_;
428+ if (y <= 0)
429+ return psi_;
413430
414431 if (y <= S)
415432 return D1 - 1.0/y;
Original file line number Diff line number Diff line change 22
33import numpy as np
44import pytest
5+ import scipy
56import scipy .special as sp
67
78import pytensor .tensor as pt
1920 gammal ,
2021 gammau ,
2122 hyp2f1 ,
23+ psi ,
2224)
2325from tests .link .test_link import make_function
2426
@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
149151 (var .owner and isinstance (var .owner .op , ScalarLoop ))
150152 for var in ancestors (grad )
151153 )
154+
155+
156+ @pytest .mark .parametrize (
157+ "linker" ,
158+ ["py" , "c" ],
159+ )
160+ def test_psi (linker ):
161+ x = float64 ("x" )
162+ out = psi (x )
163+
164+ fn = function ([x ], out , mode = Mode (linker = linker , optimizer = "fast_run" ))
165+ fn .dprint ()
166+
167+ x_test = np .float64 (0.5 )
168+
169+ np .testing .assert_allclose (
170+ fn (x_test ),
171+ scipy .special .psi (x_test ),
172+ strict = True ,
173+ )
174+ np .testing .assert_allclose (
175+ fn (- x_test ),
176+ scipy .special .psi (- x_test ),
177+ strict = True ,
178+ )
You can’t perform that action at this time.
0 commit comments