Skip to content

Commit 7799bd0

Browse files
Abhinav-KhotricardoV94
authored andcommitted
add negative value support for the digamma function
1 parent 7584614 commit 7799bd0

File tree

2 files changed

+47
-3
lines changed

2 files changed

+47
-3
lines changed

pytensor/scalar/math.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -389,14 +389,19 @@ def c_support_code(self, **kwargs):
389389
#define ga_double double
390390
#endif
391391
392+
#ifndef M_PI
393+
#define M_PI 3.14159265358979323846
394+
#endif
395+
392396
#ifndef _PSIFUNCDEFINED
393397
#define _PSIFUNCDEFINED
394398
DEVICE double _psi(ga_double x) {
395399
396400
/*taken from
397401
Bernardo, J. M. (1976). Algorithm AS 103:
398402
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
399-
http://www.uv.es/~bernardo/1976AppStatist.pdf */
403+
http://www.uv.es/~bernardo/1976AppStatist.pdf
404+
*/
400405
401406
ga_double y, R, psi_ = 0;
402407
ga_double S = 1.0e-5;
@@ -406,10 +411,22 @@ def c_support_code(self, **kwargs):
406411
ga_double S5 = 3.968253968e-3;
407412
ga_double D1 = -0.5772156649;
408413
414+
if (x <= 0) {
415+
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416+
if (x == floor(x)) {
417+
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418+
}
419+
420+
// Use reflection formula
421+
ga_double pi_x = M_PI * x;
422+
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423+
return _psi(1.0 - x) + M_PI * cot_pi_x;
424+
}
425+
409426
y = x;
410427
411-
if (y <= 0.0)
412-
return psi_;
428+
if (y <= 0)
429+
return psi_;
413430
414431
if (y <= S)
415432
return D1 - 1.0/y;

tests/scalar/test_math.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
import pytest
5+
import scipy
56
import scipy.special as sp
67

78
import pytensor.tensor as pt
@@ -19,6 +20,7 @@
1920
gammal,
2021
gammau,
2122
hyp2f1,
23+
psi,
2224
)
2325
from tests.link.test_link import make_function
2426

@@ -149,3 +151,28 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
149151
(var.owner and isinstance(var.owner.op, ScalarLoop))
150152
for var in ancestors(grad)
151153
)
154+
155+
156+
@pytest.mark.parametrize(
157+
"linker",
158+
["py", "c"],
159+
)
160+
def test_psi(linker):
161+
x = float64("x")
162+
out = psi(x)
163+
164+
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
165+
fn.dprint()
166+
167+
x_test = np.float64(0.5)
168+
169+
np.testing.assert_allclose(
170+
fn(x_test),
171+
scipy.special.psi(x_test),
172+
strict=True,
173+
)
174+
np.testing.assert_allclose(
175+
fn(-x_test),
176+
scipy.special.psi(-x_test),
177+
strict=True,
178+
)

0 commit comments

Comments
 (0)