Skip to content

Commit 6804c96

Browse files
committed
Enable dist API for DensityDist
1 parent c6d565d commit 6804c96

File tree

2 files changed

+166
-118
lines changed

2 files changed

+166
-118
lines changed

pymc/distributions/distribution.py

Lines changed: 149 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from functools import singledispatch
2222
from typing import Callable, Optional, Sequence, Tuple, Union
2323

24-
import aesara
2524
import numpy as np
2625

2726
from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
@@ -492,12 +491,124 @@ class DensityDist(Distribution):
492491
Creates a Distribution and registers the supplied log density function to be used
493492
for inference. It is also possible to supply a `random` method in order to be able
494493
to sample from the prior or posterior predictive distributions.
494+
495+
496+
Parameters
497+
----------
498+
name : str
499+
dist_params : Tuple
500+
A sequence of the distribution's parameters. These will be converted into
501+
Aesara tensors internally. These parameters could be other ``TensorVariable``
502+
instances, optionally created via ``RandomVariable`` ``Op``s.
503+
class_name : str
504+
Name for the RandomVariable class which will wrap the DensityDist methods.
505+
When not specified, it will be given the name of the variable.
506+
507+
.. warning:: New DensityDists created with the same class_name will override the
508+
methods dispatched onto the previous classes. If using DensityDists with
509+
different methods across separate models, be sure to use distinct
510+
class_names.
511+
512+
logp : Optional[Callable]
513+
A callable that calculates the log density of some given observed ``value``
514+
conditioned on certain distribution parameter values. It must have the
515+
following signature: ``logp(value, *dist_params)``, where ``value`` is
516+
an Aesara tensor that represents the observed value, and ``dist_params``
517+
are the tensors that hold the values of the distribution parameters.
518+
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
519+
error will be raised when trying to compute the distribution's logp.
520+
logcdf : Optional[Callable]
521+
A callable that calculates the log cumulative probability of some given observed
522+
``value`` conditioned on certain distribution parameter values. It must have the
523+
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
524+
an Aesara tensor that represents the observed value, and ``dist_params``
525+
are the tensors that hold the values of the distribution parameters.
526+
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
527+
error will be raised when trying to compute the distribution's logcdf.
528+
random : Optional[Callable]
529+
A callable that can be used to generate random draws from the distribution.
530+
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
531+
The distribution parameters are passed as positional arguments in the
532+
same order as they are supplied when the ``DensityDist`` is constructed.
533+
The keyword arguments are ``rng``, which will provide the random variable's
534+
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
535+
the desired size of the random draw. If ``None``, a ``NotImplemented``
536+
error will be raised when trying to draw random samples from the distribution's
537+
prior or posterior predictive.
538+
moment : Optional[Callable]
539+
A callable that can be used to compute the moments of the distribution.
540+
It must have the following signature: ``moment(rv, size, *rv_inputs)``.
541+
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
542+
as the first argument ``rv``. ``size`` is the random variable's size implied
543+
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
544+
``rv_inputs`` is the sequence of the distribution parameters, in the same order
545+
as they were supplied when the DensityDist was created. If ``None``, a default
546+
``moment`` function will be assigned that will always return 0, or an array
547+
of zeros.
548+
ndim_supp : int
549+
The number of dimensions in the support of the distribution. Defaults to assuming
550+
a scalar distribution, i.e. ``ndim_supp = 0``.
551+
ndims_params : Optional[Sequence[int]]
552+
The list of the number of dimensions in the support of each of the distribution's
553+
parameters. If ``None``, it is assumed that all parameters are scalars, hence
554+
the number of dimensions of their support will be 0.
555+
dtype : str
556+
The dtype of the distribution. All draws and observations passed into the distribution
557+
will be cast to this dtype.
558+
kwargs :
559+
Extra keyword arguments are passed to the parent class's ``__new__`` method.
560+
561+
Examples
562+
--------
563+
.. code-block:: python
564+
565+
def logp(value, mu):
566+
return -(value - mu)**2
567+
568+
with pm.Model():
569+
mu = pm.Normal('mu',0,1)
570+
pm.DensityDist(
571+
'density_dist',
572+
mu,
573+
logp=logp,
574+
observed=np.random.randn(100),
575+
)
576+
idata = pm.sample(100)
577+
578+
.. code-block:: python
579+
580+
def logp(value, mu):
581+
return -(value - mu)**2
582+
583+
def random(mu, rng=None, size=None):
584+
return rng.normal(loc=mu, scale=1, size=size)
585+
586+
with pm.Model():
587+
mu = pm.Normal('mu', 0 , 1)
588+
dens = pm.DensityDist(
589+
'density_dist',
590+
mu,
591+
logp=logp,
592+
random=random,
593+
observed=np.random.randn(100, 3),
594+
size=(100, 3),
595+
)
596+
prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
597+
assert prior.shape == (1, 10, 100, 3)
598+
495599
"""
496600

497-
def __new__(
601+
rv_type = DensityDistRV
602+
603+
def __new__(cls, name, *args, **kwargs):
604+
kwargs.setdefault("class_name", name)
605+
return super().__new__(cls, name, *args, **kwargs)
606+
607+
@classmethod
608+
def dist(
498609
cls,
499-
name: str,
500610
*dist_params,
611+
class_name: str,
501612
logp: Optional[Callable] = None,
502613
logcdf: Optional[Callable] = None,
503614
random: Optional[Callable] = None,
@@ -507,102 +618,6 @@ def __new__(
507618
dtype: str = "floatX",
508619
**kwargs,
509620
):
510-
"""
511-
Parameters
512-
----------
513-
name : str
514-
dist_params : Tuple
515-
A sequence of the distribution's parameter. These will be converted into
516-
Aesara tensors internally. These parameters could be other ``TensorVariable``
517-
instances created from , optionally created via ``RandomVariable`` ``Op``s.
518-
logp : Optional[Callable]
519-
A callable that calculates the log density of some given observed ``value``
520-
conditioned on certain distribution parameter values. It must have the
521-
following signature: ``logp(value, *dist_params)``, where ``value`` is
522-
an Aesara tensor that represents the observed value, and ``dist_params``
523-
are the tensors that hold the values of the distribution parameters.
524-
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
525-
error will be raised when trying to compute the distribution's logp.
526-
logcdf : Optional[Callable]
527-
A callable that calculates the log cummulative probability of some given observed
528-
``value`` conditioned on certain distribution parameter values. It must have the
529-
following signature: ``logcdf(value, *dist_params)``, where ``value`` is
530-
an Aesara tensor that represents the observed value, and ``dist_params``
531-
are the tensors that hold the values of the distribution parameters.
532-
This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
533-
error will be raised when trying to compute the distribution's logcdf.
534-
random : Optional[Callable]
535-
A callable that can be used to generate random draws from the distribution.
536-
It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
537-
The distribution parameters are passed as positional arguments in the
538-
same order as they are supplied when the ``DensityDist`` is constructed.
539-
The keyword arguments are ``rnd``, which will provide the random variable's
540-
associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
541-
the desired size of the random draw. If ``None``, a ``NotImplemented``
542-
error will be raised when trying to draw random samples from the distribution's
543-
prior or posterior predictive.
544-
moment : Optional[Callable]
545-
A callable that can be used to compute the moments of the distribution.
546-
It must have the following signature: ``moment(rv, size, *rv_inputs)``.
547-
The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
548-
as the first argument ``rv``. ``size`` is the random variable's size implied
549-
by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
550-
``rv_inputs`` is the sequence of the distribution parameters, in the same order
551-
as they were supplied when the DensityDist was created. If ``None``, a default
552-
``moment`` function will be assigned that will always return 0, or an array
553-
of zeros.
554-
ndim_supp : int
555-
The number of dimensions in the support of the distribution. Defaults to assuming
556-
a scalar distribution, i.e. ``ndim_supp = 0``.
557-
ndims_params : Optional[Sequence[int]]
558-
The list of number of dimensions in the support of each of the distribution's
559-
parameters. If ``None``, it is assumed that all parameters are scalars, hence
560-
the number of dimensions of their support will be 0.
561-
dtype : str
562-
The dtype of the distribution. All draws and observations passed into the distribution
563-
will be casted onto this dtype.
564-
kwargs :
565-
Extra keyword arguments are passed to the parent's class ``__new__`` method.
566-
567-
Examples
568-
--------
569-
.. code-block:: python
570-
571-
def logp(value, mu):
572-
return -(value - mu)**2
573-
574-
with pm.Model():
575-
mu = pm.Normal('mu',0,1)
576-
pm.DensityDist(
577-
'density_dist',
578-
mu,
579-
logp=logp,
580-
observed=np.random.randn(100),
581-
)
582-
idata = pm.sample(100)
583-
584-
.. code-block:: python
585-
586-
def logp(value, mu):
587-
return -(value - mu)**2
588-
589-
def random(mu, rng=None, size=None):
590-
return rng.normal(loc=mu, scale=1, size=size)
591-
592-
with pm.Model():
593-
mu = pm.Normal('mu', 0 , 1)
594-
dens = pm.DensityDist(
595-
'density_dist',
596-
mu,
597-
logp=logp,
598-
random=random,
599-
observed=np.random.randn(100, 3),
600-
size=(100, 3),
601-
)
602-
prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
603-
assert prior.shape == (1, 10, 100, 3)
604-
605-
"""
606621

607622
if dist_params is None:
608623
dist_params = []
@@ -614,34 +629,61 @@ def random(mu, rng=None, size=None):
614629
"to the API documentation for more information on how to use the "
615630
"new DensityDist API."
616631
)
617-
dist_params = [as_tensor_variable(param) for param in dist_params]
632+
dist_params = [as_tensor_variable(param) for param in dist_params]
618633

619634
# Assume scalar ndims_params
620635
if ndims_params is None:
621636
ndims_params = [0] * len(dist_params)
622637

623638
if logp is None:
624-
logp = default_not_implemented(name, "logp")
639+
logp = default_not_implemented(class_name, "logp")
625640

626641
if logcdf is None:
627-
logcdf = default_not_implemented(name, "logcdf")
642+
logcdf = default_not_implemented(class_name, "logcdf")
628643

629644
if moment is None:
630645
moment = functools.partial(
631646
default_moment,
632-
rv_name=name,
647+
rv_name=class_name,
633648
has_fallback=random is not None,
634649
ndim_supp=ndim_supp,
635650
)
636651

637652
if random is None:
638-
random = default_not_implemented(name, "random")
653+
random = default_not_implemented(class_name, "random")
654+
655+
return super().dist(
656+
dist_params,
657+
class_name=class_name,
658+
logp=logp,
659+
logcdf=logcdf,
660+
random=random,
661+
moment=moment,
662+
ndim_supp=ndim_supp,
663+
ndims_params=ndims_params,
664+
dtype=dtype,
665+
**kwargs,
666+
)
639667

668+
@classmethod
669+
def rv_op(
670+
cls,
671+
*dist_params,
672+
class_name: str,
673+
logp: Optional[Callable],
674+
logcdf: Optional[Callable],
675+
random: Optional[Callable],
676+
moment: Optional[Callable],
677+
ndim_supp: int,
678+
ndims_params: Optional[Sequence[int]],
679+
dtype: str,
680+
**kwargs,
681+
):
640682
rv_op = type(
641-
f"DensityDist_{name}",
683+
f"DensityDist_{class_name}",
642684
(DensityDistRV,),
643685
dict(
644-
name=f"DensityDist_{name}",
686+
name=f"DensityDist_{class_name}",
645687
inplace=False,
646688
ndim_supp=ndim_supp,
647689
ndims_params=ndims_params,
@@ -669,18 +711,7 @@ def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
669711
def density_dist_get_moment(op, rv, rng, size, dtype, *dist_params):
670712
return moment(rv, size, *dist_params)
671713

672-
cls.rv_op = rv_op
673-
return super().__new__(cls, name, *dist_params, **kwargs)
674-
675-
@classmethod
676-
def dist(cls, *args, **kwargs):
677-
output = super().dist(args, **kwargs)
678-
if cls.rv_op.dtype == "floatX":
679-
dtype = aesara.config.floatX
680-
else:
681-
dtype = cls.rv_op.dtype
682-
ndim_supp = cls.rv_op.ndim_supp
683-
return output
714+
return rv_op(*dist_params, **kwargs)
684715

685716

686717
def default_not_implemented(rv_name, method_name):

pymc/tests/distributions/test_distribution.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import numpy.random as npr
2020
import numpy.testing as npt
2121
import pytest
22+
import scipy.stats as st
2223

2324
from aeppl.abstract import get_measurable_outputs
2425
from aesara.tensor import TensorVariable
@@ -304,6 +305,22 @@ def _random(mu, rng=None, size=None):
304305
):
305306
evaled_moment = moment(a).eval({mu: mu_val})
306307

308+
def test_dist(self):
309+
mu = 1
310+
x = pm.DensityDist.dist(
311+
mu,
312+
class_name="test",
313+
logp=lambda value, mu: pm.logp(pm.Normal.dist(mu), value),
314+
random=lambda mu, rng=None, size=None: rng.normal(loc=mu, scale=1, size=size),
315+
shape=(3,),
316+
)
317+
318+
test_value = pm.draw(x, random_seed=1)
319+
assert np.all(test_value == pm.draw(x, random_seed=1))
320+
321+
x_logp = pm.logp(x, test_value)
322+
assert np.allclose(x_logp.eval(), st.norm(1).logpdf(test_value))
323+
307324

308325
class TestSymbolicRandomVarible:
309326
def test_inline(self):

0 commit comments

Comments
 (0)