Skip to content

Commit 88476b4

Browse files
Abhinav-KhotricardoV94
authored andcommitted
Add support for negative values in psi
1 parent 7799bd0 commit 88476b4

File tree

2 files changed

+39
-53
lines changed

2 files changed

+39
-53
lines changed

pytensor/scalar/math.py

Lines changed: 37 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -378,17 +378,6 @@ def L_op(self, inputs, outputs, grads):
378378

379379
def c_support_code(self, **kwargs):
380380
return """
381-
// For GPU support
382-
#ifdef WITHIN_KERNEL
383-
#define DEVICE WITHIN_KERNEL
384-
#else
385-
#define DEVICE
386-
#endif
387-
388-
#ifndef ga_double
389-
#define ga_double double
390-
#endif
391-
392381
#ifndef M_PI
393382
#define M_PI 3.14159265358979323846
394383
#endif
@@ -397,51 +386,48 @@ def c_support_code(self, **kwargs):
397386
#define _PSIFUNCDEFINED
398387
DEVICE double _psi(ga_double x) {
399388
400-
/*taken from
401-
Bernardo, J. M. (1976). Algorithm AS 103:
402-
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
403-
http://www.uv.es/~bernardo/1976AppStatist.pdf
404-
*/
405-
406-
ga_double y, R, psi_ = 0;
407-
ga_double S = 1.0e-5;
408-
ga_double C = 8.5;
409-
ga_double S3 = 8.333333333e-2;
410-
ga_double S4 = 8.333333333e-3;
411-
ga_double S5 = 3.968253968e-3;
412-
ga_double D1 = -0.5772156649;
413-
414-
if (x <= 0) {
415-
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
416-
if (x == floor(x)) {
417-
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
418-
}
389+
/*taken from
390+
Bernardo, J. M. (1976). Algorithm AS 103:
391+
Psi (Digamma) Function. Applied Statistics. 25 (3), 315-317.
392+
http://www.uv.es/~bernardo/1976AppStatist.pdf
393+
*/
419394
420-
// Use reflection formula
421-
ga_double pi_x = M_PI * x;
422-
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
423-
return _psi(1.0 - x) + M_PI * cot_pi_x;
424-
}
395+
double y, R, psi_ = 0;
396+
double S = 1.0e-5;
397+
double C = 8.5;
398+
double S3 = 8.333333333e-2;
399+
double S4 = 8.333333333e-3;
400+
double S5 = 3.968253968e-3;
401+
double D1 = -0.5772156649;
425402
426-
y = x;
403+
if (x <= 0) {
404+
// the digamma function approaches infinity from one side and -infinity from the other, around negative integers and zero
405+
if (x == floor(x)) {
406+
return INFINITY; // note that scipy returns -INF for 0 and NaN for negative integers
407+
}
408+
409+
// Use reflection formula
410+
ga_double pi_x = M_PI * x;
411+
ga_double cot_pi_x = cos(pi_x) / sin(pi_x);
412+
return _psi(1.0 - x) - M_PI * cot_pi_x;
413+
}
427414
428-
if (y <= 0)
429-
return psi_;
415+
y = x;
430416
431-
if (y <= S)
432-
return D1 - 1.0/y;
417+
if (y <= S)
418+
return D1 - 1.0/y;
433419
434-
while (y < C) {
435-
psi_ = psi_ - 1.0 / y;
436-
y = y + 1;
437-
}
420+
while (y < C) {
421+
psi_ = psi_ - 1.0 / y;
422+
y = y + 1;
423+
}
438424
439-
R = 1.0 / y;
440-
psi_ = psi_ + log(y) - .5 * R ;
441-
R= R*R;
442-
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
425+
R = 1.0 / y;
426+
psi_ = psi_ + log(y) - .5 * R ;
427+
R= R*R;
428+
psi_ = psi_ - R * (S3 - R * (S4 - R * S5));
443429
444-
return psi_;
430+
return psi_;
445431
}
446432
#endif
447433
"""
@@ -450,8 +436,8 @@ def c_code(self, node, name, inp, out, sub):
450436
(x,) = inp
451437
(z,) = out
452438
if node.inputs[0].type in float_types:
453-
return f"""{z} =
454-
_psi({x});"""
439+
dtype = "npy_" + node.outputs[0].dtype
440+
return f"{z} = ({dtype}) _psi({x});"
455441
raise NotImplementedError("only floating point is implemented")
456442

457443

tests/scalar/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def test_scalarloop_grad_mixed_dtypes(op, scalar_loop_grads):
155155

156156
@pytest.mark.parametrize(
157157
"linker",
158-
["py", "c"],
158+
["py", "cvm"],
159159
)
160160
def test_psi(linker):
161161
x = float64("x")
@@ -164,7 +164,7 @@ def test_psi(linker):
164164
fn = function([x], out, mode=Mode(linker=linker, optimizer="fast_run"))
165165
fn.dprint()
166166

167-
x_test = np.float64(0.5)
167+
x_test = np.float64(0.7)
168168

169169
np.testing.assert_allclose(
170170
fn(x_test),

0 commit comments

Comments
 (0)