21
21
from functools import singledispatch
22
22
from typing import Callable , Optional , Sequence , Tuple , Union
23
23
24
- import aesara
25
24
import numpy as np
26
25
27
26
from aeppl .abstract import MeasurableVariable , _get_measurable_outputs
@@ -492,12 +491,124 @@ class DensityDist(Distribution):
492
491
Creates a Distribution and registers the supplied log density function to be used
493
492
for inference. It is also possible to supply a `random` method in order to be able
494
493
to sample from the prior or posterior predictive distributions.
494
+
495
+
496
+ Parameters
497
+ ----------
498
+ name : str
499
+ dist_params : Tuple
500
+ A sequence of the distribution's parameter. These will be converted into
501
+ Aesara tensors internally. These parameters could be other ``TensorVariable``
502
+ instances created from , optionally created via ``RandomVariable`` ``Op``s.
503
+ class_name : str
504
+ Name for the RandomVariable class which will wrap the DensityDist methods.
505
+ When not specified, it will be given the name of the variable.
506
+
507
+ .. warning:: New DensityDists created with the same class_name will override the
508
+ methods dispatched onto the previous classes. If using DensityDists with
509
+ different methods across separate models, be sure to use distinct
510
+ class_names.
511
+
512
+ logp : Optional[Callable]
513
+ A callable that calculates the log density of some given observed ``value``
514
+ conditioned on certain distribution parameter values. It must have the
515
+ following signature: ``logp(value, *dist_params)``, where ``value`` is
516
+ an Aesara tensor that represents the observed value, and ``dist_params``
517
+ are the tensors that hold the values of the distribution parameters.
518
+ This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
519
+ error will be raised when trying to compute the distribution's logp.
520
+ logcdf : Optional[Callable]
521
+ A callable that calculates the log cummulative probability of some given observed
522
+ ``value`` conditioned on certain distribution parameter values. It must have the
523
+ following signature: ``logcdf(value, *dist_params)``, where ``value`` is
524
+ an Aesara tensor that represents the observed value, and ``dist_params``
525
+ are the tensors that hold the values of the distribution parameters.
526
+ This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
527
+ error will be raised when trying to compute the distribution's logcdf.
528
+ random : Optional[Callable]
529
+ A callable that can be used to generate random draws from the distribution.
530
+ It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
531
+ The distribution parameters are passed as positional arguments in the
532
+ same order as they are supplied when the ``DensityDist`` is constructed.
533
+ The keyword arguments are ``rnd``, which will provide the random variable's
534
+ associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
535
+ the desired size of the random draw. If ``None``, a ``NotImplemented``
536
+ error will be raised when trying to draw random samples from the distribution's
537
+ prior or posterior predictive.
538
+ moment : Optional[Callable]
539
+ A callable that can be used to compute the moments of the distribution.
540
+ It must have the following signature: ``moment(rv, size, *rv_inputs)``.
541
+ The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
542
+ as the first argument ``rv``. ``size`` is the random variable's size implied
543
+ by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
544
+ ``rv_inputs`` is the sequence of the distribution parameters, in the same order
545
+ as they were supplied when the DensityDist was created. If ``None``, a default
546
+ ``moment`` function will be assigned that will always return 0, or an array
547
+ of zeros.
548
+ ndim_supp : int
549
+ The number of dimensions in the support of the distribution. Defaults to assuming
550
+ a scalar distribution, i.e. ``ndim_supp = 0``.
551
+ ndims_params : Optional[Sequence[int]]
552
+ The list of number of dimensions in the support of each of the distribution's
553
+ parameters. If ``None``, it is assumed that all parameters are scalars, hence
554
+ the number of dimensions of their support will be 0.
555
+ dtype : str
556
+ The dtype of the distribution. All draws and observations passed into the distribution
557
+ will be casted onto this dtype.
558
+ kwargs :
559
+ Extra keyword arguments are passed to the parent's class ``__new__`` method.
560
+
561
+ Examples
562
+ --------
563
+ .. code-block:: python
564
+
565
+ def logp(value, mu):
566
+ return -(value - mu)**2
567
+
568
+ with pm.Model():
569
+ mu = pm.Normal('mu',0,1)
570
+ pm.DensityDist(
571
+ 'density_dist',
572
+ mu,
573
+ logp=logp,
574
+ observed=np.random.randn(100),
575
+ )
576
+ idata = pm.sample(100)
577
+
578
+ .. code-block:: python
579
+
580
+ def logp(value, mu):
581
+ return -(value - mu)**2
582
+
583
+ def random(mu, rng=None, size=None):
584
+ return rng.normal(loc=mu, scale=1, size=size)
585
+
586
+ with pm.Model():
587
+ mu = pm.Normal('mu', 0 , 1)
588
+ dens = pm.DensityDist(
589
+ 'density_dist',
590
+ mu,
591
+ logp=logp,
592
+ random=random,
593
+ observed=np.random.randn(100, 3),
594
+ size=(100, 3),
595
+ )
596
+ prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
597
+ assert prior.shape == (1, 10, 100, 3)
598
+
495
599
"""
496
600
497
- def __new__ (
601
+ rv_type = DensityDistRV
602
+
603
+ def __new__ (cls , name , * args , ** kwargs ):
604
+ kwargs .setdefault ("class_name" , name )
605
+ return super ().__new__ (cls , name , * args , ** kwargs )
606
+
607
+ @classmethod
608
+ def dist (
498
609
cls ,
499
- name : str ,
500
610
* dist_params ,
611
+ class_name : str ,
501
612
logp : Optional [Callable ] = None ,
502
613
logcdf : Optional [Callable ] = None ,
503
614
random : Optional [Callable ] = None ,
@@ -507,102 +618,6 @@ def __new__(
507
618
dtype : str = "floatX" ,
508
619
** kwargs ,
509
620
):
510
- """
511
- Parameters
512
- ----------
513
- name : str
514
- dist_params : Tuple
515
- A sequence of the distribution's parameter. These will be converted into
516
- Aesara tensors internally. These parameters could be other ``TensorVariable``
517
- instances created from , optionally created via ``RandomVariable`` ``Op``s.
518
- logp : Optional[Callable]
519
- A callable that calculates the log density of some given observed ``value``
520
- conditioned on certain distribution parameter values. It must have the
521
- following signature: ``logp(value, *dist_params)``, where ``value`` is
522
- an Aesara tensor that represents the observed value, and ``dist_params``
523
- are the tensors that hold the values of the distribution parameters.
524
- This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
525
- error will be raised when trying to compute the distribution's logp.
526
- logcdf : Optional[Callable]
527
- A callable that calculates the log cummulative probability of some given observed
528
- ``value`` conditioned on certain distribution parameter values. It must have the
529
- following signature: ``logcdf(value, *dist_params)``, where ``value`` is
530
- an Aesara tensor that represents the observed value, and ``dist_params``
531
- are the tensors that hold the values of the distribution parameters.
532
- This function must return an Aesara tensor. If ``None``, a ``NotImplemented``
533
- error will be raised when trying to compute the distribution's logcdf.
534
- random : Optional[Callable]
535
- A callable that can be used to generate random draws from the distribution.
536
- It must have the following signature: ``random(*dist_params, rng=None, size=None)``.
537
- The distribution parameters are passed as positional arguments in the
538
- same order as they are supplied when the ``DensityDist`` is constructed.
539
- The keyword arguments are ``rnd``, which will provide the random variable's
540
- associated :py:class:`~numpy.random.Generator`, and ``size``, that will represent
541
- the desired size of the random draw. If ``None``, a ``NotImplemented``
542
- error will be raised when trying to draw random samples from the distribution's
543
- prior or posterior predictive.
544
- moment : Optional[Callable]
545
- A callable that can be used to compute the moments of the distribution.
546
- It must have the following signature: ``moment(rv, size, *rv_inputs)``.
547
- The distribution's :class:`~aesara.tensor.random.op.RandomVariable` is passed
548
- as the first argument ``rv``. ``size`` is the random variable's size implied
549
- by the ``dims``, ``size`` and parameters supplied to the distribution. Finally,
550
- ``rv_inputs`` is the sequence of the distribution parameters, in the same order
551
- as they were supplied when the DensityDist was created. If ``None``, a default
552
- ``moment`` function will be assigned that will always return 0, or an array
553
- of zeros.
554
- ndim_supp : int
555
- The number of dimensions in the support of the distribution. Defaults to assuming
556
- a scalar distribution, i.e. ``ndim_supp = 0``.
557
- ndims_params : Optional[Sequence[int]]
558
- The list of number of dimensions in the support of each of the distribution's
559
- parameters. If ``None``, it is assumed that all parameters are scalars, hence
560
- the number of dimensions of their support will be 0.
561
- dtype : str
562
- The dtype of the distribution. All draws and observations passed into the distribution
563
- will be casted onto this dtype.
564
- kwargs :
565
- Extra keyword arguments are passed to the parent's class ``__new__`` method.
566
-
567
- Examples
568
- --------
569
- .. code-block:: python
570
-
571
- def logp(value, mu):
572
- return -(value - mu)**2
573
-
574
- with pm.Model():
575
- mu = pm.Normal('mu',0,1)
576
- pm.DensityDist(
577
- 'density_dist',
578
- mu,
579
- logp=logp,
580
- observed=np.random.randn(100),
581
- )
582
- idata = pm.sample(100)
583
-
584
- .. code-block:: python
585
-
586
- def logp(value, mu):
587
- return -(value - mu)**2
588
-
589
- def random(mu, rng=None, size=None):
590
- return rng.normal(loc=mu, scale=1, size=size)
591
-
592
- with pm.Model():
593
- mu = pm.Normal('mu', 0 , 1)
594
- dens = pm.DensityDist(
595
- 'density_dist',
596
- mu,
597
- logp=logp,
598
- random=random,
599
- observed=np.random.randn(100, 3),
600
- size=(100, 3),
601
- )
602
- prior = pm.sample_prior_predictive(10).prior_predictive['density_dist']
603
- assert prior.shape == (1, 10, 100, 3)
604
-
605
- """
606
621
607
622
if dist_params is None :
608
623
dist_params = []
@@ -614,34 +629,61 @@ def random(mu, rng=None, size=None):
614
629
"to the API documentation for more information on how to use the "
615
630
"new DensityDist API."
616
631
)
617
- dist_params = [as_tensor_variable (param ) for param in dist_params ]
632
+ dist_params = [as_tensor_variable (param ) for param in dist_params ]
618
633
619
634
# Assume scalar ndims_params
620
635
if ndims_params is None :
621
636
ndims_params = [0 ] * len (dist_params )
622
637
623
638
if logp is None :
624
- logp = default_not_implemented (name , "logp" )
639
+ logp = default_not_implemented (class_name , "logp" )
625
640
626
641
if logcdf is None :
627
- logcdf = default_not_implemented (name , "logcdf" )
642
+ logcdf = default_not_implemented (class_name , "logcdf" )
628
643
629
644
if moment is None :
630
645
moment = functools .partial (
631
646
default_moment ,
632
- rv_name = name ,
647
+ rv_name = class_name ,
633
648
has_fallback = random is not None ,
634
649
ndim_supp = ndim_supp ,
635
650
)
636
651
637
652
if random is None :
638
- random = default_not_implemented (name , "random" )
653
+ random = default_not_implemented (class_name , "random" )
654
+
655
+ return super ().dist (
656
+ dist_params ,
657
+ class_name = class_name ,
658
+ logp = logp ,
659
+ logcdf = logcdf ,
660
+ random = random ,
661
+ moment = moment ,
662
+ ndim_supp = ndim_supp ,
663
+ ndims_params = ndims_params ,
664
+ dtype = dtype ,
665
+ ** kwargs ,
666
+ )
639
667
668
+ @classmethod
669
+ def rv_op (
670
+ cls ,
671
+ * dist_params ,
672
+ class_name : str ,
673
+ logp : Optional [Callable ],
674
+ logcdf : Optional [Callable ],
675
+ random : Optional [Callable ],
676
+ moment : Optional [Callable ],
677
+ ndim_supp : int ,
678
+ ndims_params : Optional [Sequence [int ]],
679
+ dtype : str ,
680
+ ** kwargs ,
681
+ ):
640
682
rv_op = type (
641
- f"DensityDist_{ name } " ,
683
+ f"DensityDist_{ class_name } " ,
642
684
(DensityDistRV ,),
643
685
dict (
644
- name = f"DensityDist_{ name } " ,
686
+ name = f"DensityDist_{ class_name } " ,
645
687
inplace = False ,
646
688
ndim_supp = ndim_supp ,
647
689
ndims_params = ndims_params ,
@@ -669,18 +711,7 @@ def density_dist_logcdf(op, var, rvs_to_values, *dist_params, **kwargs):
669
711
def density_dist_get_moment (op , rv , rng , size , dtype , * dist_params ):
670
712
return moment (rv , size , * dist_params )
671
713
672
- cls .rv_op = rv_op
673
- return super ().__new__ (cls , name , * dist_params , ** kwargs )
674
-
675
- @classmethod
676
- def dist (cls , * args , ** kwargs ):
677
- output = super ().dist (args , ** kwargs )
678
- if cls .rv_op .dtype == "floatX" :
679
- dtype = aesara .config .floatX
680
- else :
681
- dtype = cls .rv_op .dtype
682
- ndim_supp = cls .rv_op .ndim_supp
683
- return output
714
+ return rv_op (* dist_params , ** kwargs )
684
715
685
716
686
717
def default_not_implemented (rv_name , method_name ):
0 commit comments