
Commit 545db3f

Use multipledispatch for kl_registry (#1252)
1 parent 7c8dd4e commit 545db3f
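
For context, the multipledispatch package picks an implementation based on the runtime types of all arguments, preferring the most specific registered signature. That is exactly the job of the hand-rolled registry this commit deletes. A minimal, self-contained sketch of the mechanism (the classes and function here are illustrative, not part of the commit):

from multipledispatch import dispatch

class A:
    pass

class B(A):
    pass

@dispatch(A, A)
def f(x, y):
    return "generic"

@dispatch(B, A)
def f(x, y):  # same name on purpose: @dispatch registers an overload
    return "specific"

print(f(A(), A()))  # "generic"
print(f(B(), B()))  # "specific", since (B, A) is more specific than (A, A)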

File tree

docs/requirements.txt
numpyro/distributions/kl.py
setup.cfg
setup.py

4 files changed, +30 -97 lines

docs/requirements.txt

Lines changed: 3 additions & 1 deletion
@@ -5,8 +5,10 @@ ipython
 jax>=0.2.11
 jaxlib>=0.1.62
 jaxns>=0.0.7
-optax>=0.0.6
+multipledispatch
 nbsphinx==0.8.6
+numpy
+optax>=0.0.6
 readthedocs-sphinx-search==0.1.0
 sphinx==4.0.3
 sphinx-gallery

numpyro/distributions/kl.py

Lines changed: 23 additions & 96 deletions
@@ -25,8 +25,7 @@
 # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 # POSSIBILITY OF SUCH DAMAGE.
 
-from functools import total_ordering
-import warnings
+from multipledispatch import dispatch
 
 from jax import lax
 import jax.numpy as jnp
@@ -42,124 +41,52 @@
 )
 from numpyro.distributions.util import scale_and_mask, sum_rightmost
 
-_KL_REGISTRY = (
-    {}
-)  # Source of truth mapping a few general (type, type) pairs to functions.
-_KL_MEMOIZE = (
-    {}
-)  # Memoized version mapping many specific (type, type) pairs to functions.
-
-
-def register_kl(type_p, type_q):
-    if not isinstance(type_p, type) and issubclass(type_p, Distribution):
-        raise TypeError(
-            "Expected type_p to be a Distribution subclass but got {}".format(type_p)
-        )
-    if not isinstance(type_q, type) and issubclass(type_q, Distribution):
-        raise TypeError(
-            "Expected type_q to be a Distribution subclass but got {}".format(type_q)
-        )
-
-    def decorator(fun):
-        _KL_REGISTRY[type_p, type_q] = fun
-        _KL_MEMOIZE.clear()  # reset since lookup order may have changed
-        return fun
-
-    return decorator
-
-
-@total_ordering
-class _Match(object):
-    __slots__ = ["types"]
-
-    def __init__(self, *types):
-        self.types = types
-
-    def __eq__(self, other):
-        return self.types == other.types
-
-    def __le__(self, other):
-        for x, y in zip(self.types, other.types):
-            if not issubclass(x, y):
-                return False
-            if x is not y:
-                break
-        return True
-
-
-def _dispatch_kl(type_p, type_q):
-    """
-    Find the most specific approximate match, assuming single inheritance.
-    """
-    matches = [
-        (super_p, super_q)
-        for super_p, super_q in _KL_REGISTRY
-        if issubclass(type_p, super_p) and issubclass(type_q, super_q)
-    ]
-    if not matches:
-        return NotImplemented
-    # Check that the left- and right- lexicographic orders agree.
-    left_p, left_q = min(_Match(*m) for m in matches).types
-    right_q, right_p = min(_Match(*reversed(m)) for m in matches).types
-    left_fun = _KL_REGISTRY[left_p, left_q]
-    right_fun = _KL_REGISTRY[right_p, right_q]
-    if left_fun is not right_fun:
-        warnings.warn(
-            "Ambiguous kl_divergence({}, {}). Please register_kl({}, {})".format(
-                type_p.__name__, type_q.__name__, left_p.__name__, right_q.__name__
-            ),
-            RuntimeWarning,
-        )
-    return left_fun
-
 
 def kl_divergence(p, q):
     r"""
     Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions.
     """
-    try:
-        fun = _KL_MEMOIZE[type(p), type(q)]
-    except KeyError:
-        fun = _dispatch_kl(type(p), type(q))
-        _KL_MEMOIZE[type(p), type(q)] = fun
-    if fun is NotImplemented:
-        raise NotImplementedError
-    return fun(p, q)
+    raise NotImplementedError
 
 
 ################################################################################
 # KL Divergence Implementations
 ################################################################################
 
 
-@register_kl(Distribution, ExpandedDistribution)
-def _kl_dist_expanded(p, q):
+@dispatch(Distribution, ExpandedDistribution)
+def kl_divergence(p, q):
     kl = kl_divergence(p, q.base_dist)
     shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
     return jnp.broadcast_to(kl, shape)
 
 
-@register_kl(ExpandedDistribution, Distribution)
-def _kl_expanded(p, q):
+@dispatch(ExpandedDistribution, Distribution)
def kl_divergence(p, q):
     kl = kl_divergence(p.base_dist, q)
     shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
     return jnp.broadcast_to(kl, shape)
 
 
-@register_kl(ExpandedDistribution, ExpandedDistribution)
-def _kl_expanded_expanded(p, q):
+@dispatch(ExpandedDistribution, ExpandedDistribution)
+def kl_divergence(p, q):
     kl = kl_divergence(p.base_dist, q.base_dist)
     shape = lax.broadcast_shapes(p.batch_shape, q.batch_shape)
     return jnp.broadcast_to(kl, shape)
 
 
-@register_kl(Delta, Distribution)
-def _kl_delta(p, q):
+@dispatch(Delta, Distribution)
+def kl_divergence(p, q):
     return -q.log_prob(p.v)
 
 
-@register_kl(Independent, Independent)
-def _kl_independent_independent(p, q):
+@dispatch(Delta, ExpandedDistribution)
+def kl_divergence(p, q):
+    return -q.log_prob(p.v)
+
+
+@dispatch(Independent, Independent)
+def kl_divergence(p, q):
     shared_ndims = min(p.reinterpreted_batch_ndims, q.reinterpreted_batch_ndims)
     p_ndims = p.reinterpreted_batch_ndims - shared_ndims
     q_ndims = q.reinterpreted_batch_ndims - shared_ndims
@@ -171,8 +98,8 @@ def _kl_independent_independent(p, q):
     return kl
 
 
-@register_kl(MaskedDistribution, MaskedDistribution)
-def _kl_masked_masked(p, q):
+@dispatch(MaskedDistribution, MaskedDistribution)
+def kl_divergence(p, q):
     if p._mask is False or q._mask is False:
         mask = False
     elif p._mask is True:
@@ -192,15 +119,15 @@ def _kl_masked_masked(p, q):
     return scale_and_mask(kl, mask=mask)
 
 
-@register_kl(Normal, Normal)
-def _kl_normal_normal(p, q):
+@dispatch(Normal, Normal)
+def kl_divergence(p, q):
     var_ratio = jnp.square(p.scale / q.scale)
     t1 = jnp.square((p.loc - q.loc) / q.scale)
     return 0.5 * (var_ratio + t1 - 1 - jnp.log(var_ratio))
 
 
-@register_kl(Dirichlet, Dirichlet)
-def _kl_dirichlet_dirichlet(p, q):
+@dispatch(Dirichlet, Dirichlet)
+def kl_divergence(p, q):
     # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
     sum_p_concentration = p.concentration.sum(-1)
     sum_q_concentration = q.concentration.sum(-1)
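
After the refactor, KL resolution goes through multipledispatch instead of the removed _dispatch_kl/_KL_MEMOIZE machinery, but the public call is unchanged. A quick usage sketch (parameter values and shapes are arbitrary, chosen only for illustration):

import jax.numpy as jnp

from numpyro.distributions import Normal
from numpyro.distributions.kl import kl_divergence

p = Normal(0.0, 1.0)
q = Normal(jnp.ones(3), 2.0)

# Dispatches to the (Normal, Normal) overload above; broadcasting the
# scalar p against q's batch shape yields a result of shape (3,).
kl = kl_divergence(p, q)

Note that every overload deliberately reuses the name kl_divergence, so flake8 would flag each redefinition as F811; the setup.cfg change below adds a per-file ignore for exactly that.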

setup.cfg

Lines changed: 2 additions & 0 deletions
@@ -2,6 +2,8 @@
 max-line-length = 120
 exclude = docs/src, build, dist, .ipynb_checkpoints
 ignore = W503,E203
+per-file-ignores =
+    numpyro/distributions/kl.py:F811
 
 [isort]
 profile = black
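
The new per-file ignore is needed because the dispatch style above defines kl_divergence many times in a single module, and flake8 reports every such redefinition as F811. A reduced illustration of what would otherwise fail lint (g is a made-up name):

from multipledispatch import dispatch

@dispatch(int)
def g(x):
    return x + 1

@dispatch(str)
def g(x):  # flake8 flags this redefinition as F811 without the ignore
    return x.upper()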

setup.py

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@
     install_requires=[
         f"jax{_jax_version_constraints}",
         f"jaxlib{_jaxlib_version_constraints}",
+        "multipledispatch",
+        "numpy",
         "tqdm",
     ],
     extras_require={
