Skip to content

Commit 902eeb6

Browse files
authored
Implement ChiSquare via Gamma (#490)
1 parent 9653ade commit 902eeb6

File tree

4 files changed

+16
-52
lines changed

4 files changed

+16
-52
lines changed

doc/library/tensor/random/basic.rst

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,6 @@ PyTensor can produce :class:`RandomVariable`\s that draw samples from many diffe
8282
.. autoclass:: pytensor.tensor.random.basic.CategoricalRV
8383
:members: __call__
8484

85-
.. autoclass:: pytensor.tensor.random.basic.ChiSquareRV
86-
:members: __call__
87-
8885
.. autoclass:: pytensor.tensor.random.basic.DirichletRV
8986
:members: __call__
9087

pytensor/link/numba/dispatch/random.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,6 @@ def {sized_fn_name}({random_fn_input_names}):
195195
@numba_funcify.register(aer.NormalRV)
196196
@numba_funcify.register(aer.LogNormalRV)
197197
@numba_funcify.register(aer.GammaRV)
198-
@numba_funcify.register(aer.ChiSquareRV)
199198
@numba_funcify.register(aer.ParetoRV)
200199
@numba_funcify.register(aer.GumbelRV)
201200
@numba_funcify.register(aer.ExponentialRV)

pytensor/tensor/random/basic.py

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -487,56 +487,37 @@ def gamma(shape, rate=None, scale=None, **kwargs):
487487
return _gamma(shape, scale, **kwargs)
488488

489489

490-
class ChiSquareRV(RandomVariable):
491-
r"""A chi square continuous random variable.
490+
def chisquare(df, size=None, **kwargs):
491+
r"""Draw samples from a chisquare distribution.
492492
493493
The probability density function for `chisquare` in terms of the number of degrees of
494494
freedom :math:`k` is:
495495
496496
.. math::
497-
498497
f(x; k) = \frac{(1/2)^{k/2}}{\Gamma(k/2)} x^{k/2-1} e^{-x/2}
499-
500498
for :math:`k > 2`. :math:`\Gamma` is the gamma function:
501499
502500
.. math::
503-
504501
\Gamma(x) = \int_0^{\infty} t^{x-1} e^{-t} \mathrm{d}t
505502
506-
507503
This variable is obtained by summing the squares :math:`k` independent, standard normally
508504
distributed random variables.
509505
510-
"""
511-
name = "chisquare"
512-
ndim_supp = 0
513-
ndims_params = [0]
514-
dtype = "floatX"
515-
_print_name = ("ChiSquare", "\\operatorname{ChiSquare}")
516-
517-
def __call__(self, df, size=None, **kwargs):
518-
r"""Draw samples from a chisquare distribution.
519-
520-
Signature
521-
---------
522-
523-
`() -> ()`
524-
525-
Parameters
526-
----------
527-
df
528-
The number :math:`k` of degrees of freedom. Must be positive.
529-
size
530-
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
531-
independent, identically distributed random variables are
532-
returned. Default is `None` in which case a single random variable
533-
is returned.
534-
535-
"""
536-
return super().__call__(df, size=size, **kwargs)
537-
506+
Signature
507+
---------
508+
`() -> ()`
538509
539-
chisquare = ChiSquareRV()
510+
Parameters
511+
----------
512+
df
513+
The number :math:`k` of degrees of freedom. Must be positive.
514+
size
515+
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
516+
independent, identically distributed random variables are
517+
returned. Default is `None` in which case a single random variable
518+
is returned.
519+
"""
520+
return gamma(shape=df / 2.0, scale=2.0, size=size, **kwargs)
540521

541522

542523
class ParetoRV(ScipyRandomVariable):

pytensor/tensor/random/rewriting/jax.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from pytensor.tensor.elemwise import DimShuffle
88
from pytensor.tensor.random.basic import (
99
BetaBinomialRV,
10-
ChiSquareRV,
1110
GenGammaRV,
1211
GeometricRV,
1312
HalfNormalRV,
@@ -104,13 +103,6 @@ def inverse_gamma_from_gamma(fgraph, node):
104103
return [next_rng, reciprocal(g)]
105104

106105

107-
@node_rewriter([ChiSquareRV])
108-
def chi_square_from_gamma(fgraph, node):
109-
*other_inputs, df = node.inputs
110-
next_rng, g = _gamma.make_node(*other_inputs, df / 2, 2).outputs
111-
return [next_rng, g]
112-
113-
114106
@node_rewriter([GenGammaRV])
115107
def generalized_gamma_from_gamma(fgraph, node):
116108
*other_inputs, alpha, p, lambd = node.inputs
@@ -171,11 +163,6 @@ def beta_binomial_from_beta_binomial(fgraph, node):
171163
in2out(inverse_gamma_from_gamma),
172164
"jax",
173165
)
174-
random_vars_opt.register(
175-
"chi_square_from_gamma",
176-
in2out(chi_square_from_gamma),
177-
"jax",
178-
)
179166
random_vars_opt.register(
180167
"generalized_gamma_from_gamma",
181168
in2out(generalized_gamma_from_gamma),

0 commit comments

Comments
 (0)