torch_struct/distributions.py: 62 changes (32 additions, 30 deletions)
@@ -61,11 +61,7 @@ def log_prob(self, value):
 
         d = value.dim()
         batch_dims = range(d - len(self.event_shape))
-        v = self._struct().score(
-            self.log_potentials,
-            value.type_as(self.log_potentials),
-            batch_dims=batch_dims,
-        )
+        v = self._struct().score(self.log_potentials, value.type_as(self.log_potentials), batch_dims=batch_dims,)
 
         return v - self.partition
 
@@ -91,9 +87,7 @@ def cross_entropy(self, other):
             cross entropy (*batch_shape*)
         """
 
-        return self._struct(CrossEntropySemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(CrossEntropySemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     def kl(self, other):
         """
@@ -105,9 +99,7 @@
         Returns:
             cross entropy (*batch_shape*)
         """
-        return self._struct(KLDivergenceSemiring).sum(
-            [self.log_potentials, other.log_potentials], self.lengths
-        )
+        return self._struct(KLDivergenceSemiring).sum([self.log_potentials, other.log_potentials], self.lengths)
 
     @lazy_property
     def max(self):
@@ -140,9 +132,7 @@ def kmax(self, k):
             kmax (*k x batch_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).sum(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).sum(self.log_potentials, self.lengths, _raw=True)
 
     def topk(self, k):
         r"""
@@ -155,9 +145,7 @@ def topk(self, k):
             kmax (*k x batch_shape x event_shape*)
         """
         with torch.enable_grad():
-            return self._struct(KMaxSemiring(k)).marginals(
-                self.log_potentials, self.lengths, _raw=True
-            )
+            return self._struct(KMaxSemiring(k)).marginals(self.log_potentials, self.lengths, _raw=True)
 
     @lazy_property
     def mode(self):
@@ -186,9 +174,7 @@ def count(self):
 
     def gumbel_crf(self, temperature=1.0):
         with torch.enable_grad():
-            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(
-                self.log_potentials, self.lengths
-            )
+            st_gumbel = self._struct(GumbelCRFSemiring(temperature)).marginals(self.log_potentials, self.lengths)
         return st_gumbel
 
     # @constraints.dependent_property
@@ -219,9 +205,7 @@ def sample(self, sample_shape=torch.Size()):
         samples = []
         for k in range(nsamples):
             if k % 10 == 0:
-                sample = self._struct(MultiSampledSemiring).marginals(
-                    self.log_potentials, lengths=self.lengths
-                )
+                sample = self._struct(MultiSampledSemiring).marginals(self.log_potentials, lengths=self.lengths)
                 sample = sample.detach()
             tmp_sample = MultiSampledSemiring.to_discrete(sample, (k % 10) + 1)
             samples.append(tmp_sample)
@@ -301,9 +285,7 @@ def __init__(self, log_potentials, local=False, lengths=None, max_gap=None):
         super().__init__(log_potentials, lengths)
 
     def _struct(self, sr=None):
-        return self.struct(
-            sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap
-        )
+        return self.struct(sr if sr is not None else LogSemiring, self.local, max_gap=self.max_gap)
 
 
 class HMM(StructDistribution):
@@ -440,9 +422,7 @@ def __init__(self, log_potentials, lengths=None):
         event_shape = log_potentials[0].shape[1:]
         self.log_potentials = log_potentials
         self.lengths = lengths
-        super(StructDistribution, self).__init__(
-            batch_shape=batch_shape, event_shape=event_shape
-        )
+        super(StructDistribution, self).__init__(batch_shape=batch_shape, event_shape=event_shape)
 
 
 class NonProjectiveDependencyCRF(StructDistribution):
@@ -504,4 +484,26 @@ def argmax(self):
 
     @lazy_property
     def entropy(self):
-        pass
+        r"""
+        Compute entropy efficiently using the arc-factorization property.
+
+        Algorithm derivation:
+
+        .. math::
+
+            \begin{align}
+            H[p] &= E_{p(T)}[-\log p(T)]\\
+            &= -E_{p(T)}\big[ \log \frac{1}{Z} \prod\limits_{(i,j) \in T} \exp\{\phi_{i,j}\} \big]\\
+            &= -E_{p(T)}\big[ \sum\limits_{(i,j) \in T} \phi_{i,j} - \log Z \big]\\
+            &= \log Z - E_{p(T)}\big[\sum\limits_{(i,j) \in A} 1\{(i,j) \in T\} \phi_{i,j}\big]\\
+            &= \log Z - \sum\limits_{(i,j) \in A} p\big((i,j) \in T\big) \phi_{i,j}
+            \end{align}
+
+        Returns:
+            entropy (*batch_shape*)
+        """
+        logZ = self.partition
Review comment (Collaborator): @justinchiu points out that instead of grouping this with the non-projective model, it should just be the default entropy function for all of the models, materializing the marginals (nothing here is non-projective specific).

Reply: Yes, it's a general property of all exponential family / sum-product models. See also this Twitter discussion: https://twitter.com/RanZmigrod/status/1300832956701970434?s=20
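
A minimal sketch of what that shared default could look like, assuming only the StructDistribution interface used elsewhere in this file (partition, marginals, log_potentials); this illustrates the suggestion, and is not the merged API:

    @lazy_property
    def entropy(self):
        # General exponential-family identity: H = log Z - <marginals, potentials>.
        # Materializing the marginals makes this generic across models, at the cost
        # of not exploiting any model-specific entropy semiring.
        p = self.marginals
        phi = self.log_potentials
        return self.partition - (p * phi).reshape(phi.shape[0], -1).sum(-1)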

+        p = self.marginals
+        phi = self.log_potentials
+        H = logZ - (p * phi).reshape(phi.shape[0], -1).sum(-1)
+        return H
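
As a sanity check of the derivation in the docstring, the following self-contained sketch enumerates the three dependency trees over a two-word sentence and compares the entropy from direct enumeration against the arc-factored formula. The arc-score tensor phi and the hand-listed trees are illustrative assumptions, independent of the library:

    import torch

    # Toy check of H = log Z - sum_{(i,j) in A} p((i,j) in T) * phi_{i,j}.
    # Arcs are (head, dep) pairs, with index 0 acting as the root.
    phi = torch.randn(3, 3)  # phi[h, d] = score of arc h -> d
    trees = [
        [(0, 1), (0, 2)],  # both words attach to the root
        [(0, 1), (1, 2)],  # word 2 attaches to word 1
        [(0, 2), (2, 1)],  # word 1 attaches to word 2
    ]

    scores = torch.stack([sum(phi[h, d] for h, d in t) for t in trees])
    log_Z = torch.logsumexp(scores, 0)
    p_tree = (scores - log_Z).exp()

    # Entropy by direct enumeration: H = -sum_T p(T) log p(T).
    H_enum = -(p_tree * (scores - log_Z)).sum()

    # Entropy via arc marginals: p((h, d) in T) sums over trees containing the arc.
    H_arc = log_Z.clone()
    for h in range(3):
        for d in range(3):
            marginal = sum(p for p, t in zip(p_tree, trees) if (h, d) in t)
            H_arc = H_arc - marginal * phi[h, d]

    assert torch.allclose(H_enum, H_arc)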