|
13 | 13 | import scipy.stats
|
14 | 14 |
|
15 | 15 | from pytensor.configdefaults import config
|
16 |
| -from pytensor.gradient import grad_not_implemented |
| 16 | +from pytensor.gradient import grad_not_implemented, grad_undefined |
17 | 17 | from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp
|
18 | 18 | from pytensor.scalar.basic import abs as scalar_abs
|
19 | 19 | from pytensor.scalar.basic import (
|
@@ -473,8 +473,12 @@ def st_impl(x):
|
473 | 473 | def impl(self, x):
|
474 | 474 | return TriGamma.st_impl(x)
|
475 | 475 |
|
476 |
| - def grad(self, inputs, outputs_gradients): |
477 |
| - raise NotImplementedError() |
| 476 | + def L_op(self, inputs, outputs, outputs_gradients): |
| 477 | + (x,) = inputs |
| 478 | + (g_out,) = outputs_gradients |
| 479 | + if x in complex_types: |
| 480 | + raise NotImplementedError("gradient not implemented for complex types") |
| 481 | + return [g_out * polygamma(2, x)] |
478 | 482 |
|
479 | 483 | def c_support_code(self, **kwargs):
|
480 | 484 | # The implementation has been copied from
|
@@ -541,7 +545,52 @@ def c_code(self, node, name, inp, out, sub):
|
541 | 545 | raise NotImplementedError("only floating point is implemented")
|
542 | 546 |
|
543 | 547 |
|
544 |
# Scipy polygamma does not support complex inputs: https://github.com/scipy/scipy/issues/7410
tri_gamma = TriGamma(upgrade_to_float_no_complex, name="tri_gamma")
| 550 | + |
| 551 | + |
class PolyGamma(BinaryScalarOp):
    """Polygamma function of integer order ``n`` evaluated at ``x``.

    ``polygamma(n, x)`` corresponds to the (n+1)th derivative of the
    log-gamma function.

    TODO: Because the first input is discrete and the output is continuous,
    the default elemwise inplace won't work, as it always tries to store the results in the first input.
    """

    # Dispatch to the NumPy/SciPy implementation: 2 inputs, 1 output.
    nfunc_spec = ("scipy.special.polygamma", 2, 1)

    @staticmethod
    def output_types_preference(n_type, x_type):
        # The order ``n`` must be an integer; reject continuous orders early.
        if n_type not in discrete_types:
            raise TypeError(
                f"Polygamma order parameter must be discrete, got {n_type} dtype"
            )
        # Scipy doesn't support it
        # (complex inputs: https://github.com/scipy/scipy/issues/7410)
        return upgrade_to_float_no_complex(x_type)

    @staticmethod
    def st_impl(n, x):
        return scipy.special.polygamma(n, x)

    def impl(self, n, x):
        return PolyGamma.st_impl(n, x)

    def L_op(self, inputs, outputs, output_gradients):
        """Gradient w.r.t. ``x`` only; the discrete order ``n`` has no gradient."""
        (n, x) = inputs
        (g_out,) = output_gradients
        # BUG FIX: test the variable's *type* for membership. The original
        # ``x in complex_types`` compared a Variable against Type instances
        # and could never be true, so the guard was dead code.
        if x.type in complex_types:
            raise NotImplementedError("gradient not implemented for complex types")
        return [
            # No meaningful gradient w.r.t. the integer order.
            grad_undefined(self, 0, n),
            # d/dx polygamma(n, x) = polygamma(n + 1, x).
            g_out * self(n + 1, x),
        ]

    def c_code(self, *args, **kwargs):
        # No C implementation; callers fall back to the Python ``impl``.
        raise NotImplementedError()
| 591 | + |
| 592 | + |
| 593 | +polygamma = PolyGamma(name="polygamma") |
545 | 594 |
|
546 | 595 |
|
547 | 596 | class Chi2SF(BinaryScalarOp):
|
|
0 commit comments