# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

-from functools import total_ordering
-import warnings
+from multipledispatch import dispatch

from jax import lax
import jax.numpy as jnp
)
from numpyro.distributions.util import scale_and_mask, sum_rightmost

-_KL_REGISTRY = (
-    {}
-)  # Source of truth mapping a few general (type, type) pairs to functions.
-_KL_MEMOIZE = (
-    {}
-)  # Memoized version mapping many specific (type, type) pairs to functions.
-
-
-def register_kl(type_p, type_q):
-    if not isinstance(type_p, type) and issubclass(type_p, Distribution):
-        raise TypeError(
-            "Expected type_p to be a Distribution subclass but got {}".format(type_p)
-        )
-    if not isinstance(type_q, type) and issubclass(type_q, Distribution):
-        raise TypeError(
-            "Expected type_q to be a Distribution subclass but got {}".format(type_q)
-        )
-
-    def decorator(fun):
-        _KL_REGISTRY[type_p, type_q] = fun
-        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
-        return fun
-
-    return decorator
-
-
-@total_ordering
-class _Match(object):
-    __slots__ = ["types"]
-
-    def __init__(self, *types):
-        self.types = types
-
-    def __eq__(self, other):
-        return self.types == other.types
-
-    def __le__(self, other):
-        for x, y in zip(self.types, other.types):
-            if not issubclass(x, y):
-                return False
-            if x is not y:
-                break
-        return True
-
-
-def _dispatch_kl(type_p, type_q):
-    """
-    Find the most specific approximate match, assuming single inheritance.
-    """
-    matches = [
-        (super_p, super_q)
-        for super_p, super_q in _KL_REGISTRY
-        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
-    ]
-    if not matches:
-        return NotImplemented
-    # Check that the left- and right- lexicographic orders agree.
-    left_p, left_q = min(_Match(*m) for m in matches).types
-    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
-    left_fun = _KL_REGISTRY[left_p, left_q]
-    right_fun = _KL_REGISTRY[right_p, right_q]
-    if left_fun is not right_fun:
-        warnings.warn(
-            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
-                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
-            ),
-            RuntimeWarning,
-        )
-    return left_fun
-

def kl_divergence(p, q):
    r"""
    Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
    """
-    try:
-        fun = _KL_MEMOIZE[type(p), type(q)]
-    except KeyError:
-        fun = _dispatch_kl(type(p), type(q))
-        _KL_MEMOIZE[type(p), type(q)] = fun
-    if fun is NotImplemented:
-        raise NotImplementedError
-    return fun(p, q)
+    raise NotImplementedError


################################################################################
# KL Divergence Implementations
################################################################################


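# Note: each implementation below is registered via multipledispatch, which
# picks the function to call from the runtime types of (p, q).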
-@register_kl(Distribution, ExpandedDistribution)
-def _kl_dist_expanded(p, q):
+@dispatch(Distribution, ExpandedDistribution)
+def kl_divergence(p, q):
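    # Compute the KL against the expanded distribution's base distribution,
    # then broadcast the result to the combined batch shape.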
    kl = kl_divergence(p, q.base_dist)
    shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
    return jnp.broadcast_to(kl, shape)


-@register_kl(ExpandedDistribution, Distribution)
-def _kl_expanded(p, q):
+@dispatch(ExpandedDistribution, Distribution)
+def kl_divergence(p, q):
    kl = kl_divergence(p.base_dist, q)
    shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
    return jnp.broadcast_to(kl, shape)


-@register_kl(ExpandedDistribution, ExpandedDistribution)
-def _kl_expanded_expanded(p, q):
+@dispatch(ExpandedDistribution, ExpandedDistribution)
+def kl_divergence(p, q):
    kl = kl_divergence(p.base_dist, q.base_dist)
    shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
    return jnp.broadcast_to(kl, shape)


-@register_kl(Delta, Distribution)
-def _kl_delta(p, q):
+@dispatch(Delta, Distribution)
+def kl_divergence(p, q):
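    # For a point mass at p.v, the divergence is computed as the negative
    # log-density of q evaluated at that point.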
    return -q.log_prob(p.v)


-@register_kl(Independent, Independent)
-def _kl_independent_independent(p, q):
+@dispatch(Delta, ExpandedDistribution)
+def kl_divergence(p, q):
+    return -q.log_prob(p.v)
+
+
+@dispatch(Independent, Independent)
+def kl_divergence(p, q):
    shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
    p_ndims = p.reinterpreted_batch_ndims - shared_ndims
    q_ndims = q.reinterpreted_batch_ndims - shared_ndims
@@ -171,8 +98,8 @@ def _kl_independent_independent(p, q):
    return kl


-@register_kl(MaskedDistribution, MaskedDistribution)
-def _kl_masked_masked(p, q):
+@dispatch(MaskedDistribution, MaskedDistribution)
+def kl_divergence(p, q):
    if p._mask is False or q._mask is False:
        mask = False
    elif p._mask is True:
@@ -192,15 +119,15 @@ def _kl_masked_masked(p, q):
    return scale_and_mask(kl, mask=mask)


-@register_kl(Normal, Normal)
-def _kl_normal_normal(p, q):
+@dispatch(Normal, Normal)
+def kl_divergence(p, q):
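    # Closed form for two Gaussians:
    #   KL(N(mu_p, s_p^2) || N(mu_q, s_q^2))
    #     = log(s_q / s_p) + (s_p^2 + (mu_p - mu_q)^2) / (2 * s_q^2) - 1 / 2,
    # written below in terms of var_ratio = (s_p / s_q)^2.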
    var_ratio = jnp.square(p.scale / q.scale)
    t1 = jnp.square((p.loc - q.loc) / q.scale)
    return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))


-@register_kl(Dirichlet, Dirichlet)
-def _kl_dirichlet_dirichlet(p, q):
+@dispatch(Dirichlet, Dirichlet)
+def kl_divergence(p, q):
    # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
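    # With alpha = p.concentration, beta = q.concentration, and alpha_0, beta_0
    # their sums over the last axis, the closed form from the reference above is
    #   KL = lgamma(alpha_0) - lgamma(beta_0)
    #        + sum_i (lgamma(beta_i) - lgamma(alpha_i))
    #        + sum_i (alpha_i - beta_i) * (digamma(alpha_i) - digamma(alpha_0)).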
    sum_p_concentration = p.concentration.sum(-1)
    sum_q_concentration = q.concentration.sum(-1)