64 changes: 45 additions & 19 deletions torch_struct/distributions.py
@@ -10,9 +10,6 @@
 from .semirings import (
     LogSemiring,
     MaxSemiring,
-    EntropySemiring,
-    CrossEntropySemiring,
-    KLDivergenceSemiring,
     MultiSampledSemiring,
     KMaxSemiring,
     StdSemiring,
@@ -72,14 +69,25 @@ def log_prob(self, value):
 
     @lazy_property
     def entropy(self):
-        """
-        Compute entropy for distribution :math:`H[z]`.
+        r"""
+        Compute entropy for distribution :math:`H[p]`.
 
+        Algorithm derivation:
+        .. math::
+            H[p] &= E_{p(z)}[-\log p(z)]\\
+                 &= -E_{p(z)}\big[\log \frac{1}{Z} \prod\limits_{c \in \mathcal{C}} \exp\{\phi_c(z_c)\}\big]\\
+                 &= -E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_c(z_c) - \log Z\big]\\
+                 &= \log Z - E_{p(z)}\big[\sum\limits_{c \in \mathcal{C}} \phi_c(z_c)\big]\\
+                 &= \log Z - \sum\limits_{c \in \mathcal{C}} \sum\limits_{z_c} p(z_c) \phi_c(z_c)
+
         Returns:
             entropy (*batch_shape*)
         """
-
-        return self._struct(EntropySemiring).sum(self.log_potentials, self.lengths)
+        logZ = self.partition
+        p = self.marginals
+        phi = self.log_potentials
+        Hp = logZ - (p * phi).reshape(p.shape[0], -1).sum(-1)
+        return Hp
 
     def cross_entropy(self, other):
         """
@@ -91,10 +99,11 @@ def cross_entropy(self, other):
         Returns:
             cross entropy (*batch_shape*)
         """
-
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        logZ = other.partition
+        p = self.marginals
+        phi_q = other.log_potentials
+        Hq = logZ - (p * phi_q).reshape(p.shape[0], -1).sum(-1)
+        return Hq
 
     def kl(self, other):
         """
@@ -104,11 +113,15 @@ def kl(self, other):
             other : Comparison distribution
 
         Returns:
-            cross entropy (*batch_shape*)
+            kl divergence (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        logZp = self.partition
+        logZq = other.partition
+        p = self.marginals
+        phi_p = self.log_potentials
+        phi_q = other.log_potentials
+        KLpq = (p * (phi_p - phi_q)).reshape(p.shape[0], -1).sum(-1) - logZp + logZq
+        return KLpq
 
     @lazy_property
     def max(self):
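Because `kl` is now computed directly from marginals, the textbook identities make a cheap smoke test: KL(p || p) = 0, KL(p || q) >= 0, and H(p, q) = H(p) + KL(p || q). A minimal sketch:

```python
import torch
from torch_struct import LinearChainCRF

p = LinearChainCRF(torch.randn(2, 5, 3, 3))
q = LinearChainCRF(torch.randn(2, 5, 3, 3))

assert torch.allclose(p.kl(p), torch.zeros(2), atol=1e-5)
assert (p.kl(q) >= -1e-5).all()
assert torch.allclose(p.cross_entropy(q), p.entropy + p.kl(q), atol=1e-4)
```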
@@ -472,6 +485,23 @@ def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
         super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
         self.multiroot = multiroot
 
+    def log_prob(self, value):
+        """
+        Compute log probability over values :math:`p(z)`.
+
+        Parameters:
+            value (tensor): One-hot events (*sample_shape x batch_shape x event_shape*)
+
+        Returns:
+            log_probs (*sample_shape x batch_shape*)
+        """
+        s = value.shape
+        # Assumes `value` has no 1s outside of each example's length.
+        value_total_log_potentials = (
+            (value * self.log_potentials.expand(s)).reshape(*s[:-2], -1).sum(-1)
+        )
+        return value_total_log_potentials - self.partition
+
     @lazy_property
     def marginals(self):
         """
@@ -502,7 +532,3 @@ def argmax(self):
         (Currently not implemented)
         """
         pass
-
-    @lazy_property
-    def entropy(self):
-        pass
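With the stub removed, `NonProjectiveDependencyCRF` now inherits the generic marginal-based `entropy` defined above, since its `partition` and `marginals` are already implemented. A quick smoke test:

```python
import torch
from torch_struct import NonProjectiveDependencyCRF

dist = NonProjectiveDependencyCRF(torch.randn(2, 5, 5))
print(dist.entropy.shape)  # -> torch.Size([2]) instead of returning None
```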