Skip to content

Commit fcb23f9

Browse files
ColCarrolltwiecki
authored andcommitted
Latex model (#2450)
* Add __latex__ alias, latex support for models * Guard for missing latex * Add LaTeX for deterministics, test for more complicated model
1 parent 66199ba commit fcb23f9

File tree

3 files changed

+94
-26
lines changed

3 files changed

+94
-26
lines changed

pymc3/distributions/distribution.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,11 @@ def getattr_value(self, val):
8888
return val
8989

9090
def _repr_latex_(self, name=None, dist=None):
91+
"""Magic method name for IPython to use for LaTeX formatting."""
9192
return None
9293

94+
__latex__ = _repr_latex_
95+
9396

9497
def TensorType(dtype, shape, broadcastable=None):
9598
if broadcastable is None:

pymc3/model.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import collections
2+
import functools
3+
import itertools
24
import threading
35
import six
46

@@ -892,6 +894,14 @@ def flatten(self, vars=None, order=None, inputvar=None):
892894
flat_view = FlatView(inputvar, replacements, view)
893895
return flat_view
894896

897+
def _repr_latex_(self, name=None, dist=None):
898+
tex_vars = []
899+
for rv in itertools.chain(self.unobserved_RVs, self.observed_RVs):
900+
tex_vars.append(rv.__latex__())
901+
return u'$${}$$'.format('\\\\'.join([tex.strip('$') for tex in tex_vars if tex is not None]))
902+
903+
__latex__ = _repr_latex_
904+
895905

896906
def fn(outs, mode=None, model=None, *args, **kwargs):
897907
"""Compiles a Theano function which returns the values of `outs` and
@@ -1073,6 +1083,8 @@ def _repr_latex_(self, name=None, dist=None):
10731083
dist = self.distribution
10741084
return self.distribution._repr_latex_(name=name, dist=dist)
10751085

1086+
__latex__ = _repr_latex_
1087+
10761088
@property
10771089
def init_value(self):
10781090
"""Convenience attribute to return tag.test_value"""
@@ -1176,6 +1188,8 @@ def _repr_latex_(self, name=None, dist=None):
11761188
dist = self.distribution
11771189
return self.distribution._repr_latex_(name=name, dist=dist)
11781190

1191+
__latex__ = _repr_latex_
1192+
11791193
@property
11801194
def init_value(self):
11811195
"""Convenience attribute to return tag.test_value"""
@@ -1212,6 +1226,26 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
12121226
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)
12131227

12141228

1229+
def _walk_up_rv(rv):
1230+
"""Walk up theano graph to get inputs for deterministic RV."""
1231+
all_rvs = []
1232+
parents = list(itertools.chain(*[j.inputs for j in rv.get_parents()]))
1233+
if parents:
1234+
for parent in parents:
1235+
all_rvs.extend(_walk_up_rv(parent))
1236+
else:
1237+
if rv.name:
1238+
all_rvs.append(rv.name)
1239+
else:
1240+
all_rvs.append(r'\text{Constant}')
1241+
return all_rvs
1242+
1243+
1244+
def _latex_repr_rv(rv):
1245+
"""Make latex string for a Deterministic variable"""
1246+
return r'${} \sim \text{{Deterministic}}({})$'.format(rv.name, r', '.join(_walk_up_rv(rv)))
1247+
1248+
12151249
def Deterministic(name, var, model=None):
12161250
"""Create a named deterministic variable
12171251
@@ -1228,6 +1262,8 @@ def Deterministic(name, var, model=None):
12281262
var.name = model.name_for(name)
12291263
model.deterministics.append(var)
12301264
model.add_random_variable(var)
1265+
var._repr_latex_ = functools.partial(_latex_repr_rv, var)
1266+
var.__latex__ = var._repr_latex_
12311267
return var
12321268

12331269

@@ -1301,6 +1337,8 @@ def _repr_latex_(self, name=None, dist=None):
13011337
dist = self.distribution
13021338
return self.distribution._repr_latex_(name=name, dist=dist)
13031339

1340+
__latex__ = _repr_latex_
1341+
13041342
@property
13051343
def init_value(self):
13061344
"""Convenience attribute to return tag.test_value"""

pymc3/tests/test_distributions.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,17 @@
44

55
from .helpers import SeededTest, select_by_precision
66
from ..vartypes import continuous_types
7-
from ..model import Model, Point, Potential
7+
from ..model import Model, Point, Potential, Deterministic
88
from ..blocking import DictToVarBijection, DictToArrayBijection, ArrayOrdering
99
from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
10-
MvStudentT, MvNormal, ZeroInflatedPoisson, GaussianRandomWalk,
10+
MvStudentT, MvNormal, ZeroInflatedPoisson,
1111
ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
12-
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto, NormalMixture,
12+
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
1313
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
16-
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, Gumbel,
17-
Interpolated, ZeroInflatedBinomial, HalfFlat)
16+
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
17+
Gumbel, Interpolated, ZeroInflatedBinomial, HalfFlat)
1818
from ..distributions import continuous
1919
from pymc3.theanof import floatX
2020
from numpy import array, inf, log, exp
@@ -929,27 +929,54 @@ def test_bound():
929929
assert rand >= 5 and rand <= 8
930930

931931

932-
def test_repr_latex_():
933-
with Model():
934-
x0 = Binomial('Discrete', p=.5, n=10)
935-
x1 = Normal('Continuous', mu=0., sd=1.)
936-
x2 = GaussianRandomWalk('Timeseries', mu=x1, sd=1., shape=2)
937-
x3 = MvStudentT('Multivariate', nu=5, mu=x2, Sigma=np.diag(np.ones(2)), shape=2)
938-
x4 = NormalMixture('Mixture', w=np.array([.5, .5]), mu=x3, sd=x0)
939-
940-
assert x0._repr_latex_() == '$Discrete \\sim \\text{Binomial}' \
941-
'(\\mathit{n}=10, \\mathit{p}=0.5)$'
942-
assert x1._repr_latex_() == '$Continuous \\sim \\text{Normal}' \
943-
'(\\mathit{mu}=0.0, \\mathit{sd}=1.0)$'
944-
assert x2._repr_latex_() == '$Timeseries \\sim \\text' \
945-
'{GaussianRandomWalk}(\\mathit{mu}=Continuous, ' \
946-
'\\mathit{sd}=1.0)$'
947-
assert x3._repr_latex_() == '$Multivariate \\sim \\text{MvStudentT}' \
948-
'(\\mathit{nu}=5, \\mathit{mu}=Timeseries, ' \
949-
'\\mathit{cov}=array)$'
950-
assert x4._repr_latex_() == '$Mixture \\sim \\text{NormalMixture}' \
951-
'(\\mathit{w}=array, \\mathit{mu}=Multivariate, ' \
952-
'\\mathit{sigma}=f(Discrete))$'
932+
class TestLatex(object):
933+
934+
def setup_class(self):
935+
# True parameter values
936+
alpha, sigma = 1, 1
937+
beta = [1, 2.5]
938+
939+
# Size of dataset
940+
size = 100
941+
942+
# Predictor variable
943+
X = np.random.normal(size=(size, 2)).dot(np.array([[1, 0], [0, 0.2]]))
944+
945+
# Simulate outcome variable
946+
Y = alpha + X.dot(beta) + np.random.randn(size)*sigma
947+
with Model() as self.model:
948+
# Priors for unknown model parameters
949+
alpha = Normal('alpha', mu=0, sd=10)
950+
b = Normal('beta', mu=0, sd=10, shape=(2,), observed=beta)
951+
sigma = HalfNormal('sigma', sd=1)
952+
953+
# Expected value of outcome
954+
mu = Deterministic('mu', alpha + tt.dot(X, b))
955+
956+
# Likelihood (sampling distribution) of observations
957+
Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y)
958+
self.distributions = [alpha, sigma, mu, b, Y_obs]
959+
self.expected = (
960+
'$alpha \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$',
961+
'$sigma \\sim \\text{HalfNormal}(\\mathit{sd}=1.0)$',
962+
'$mu \\sim \\text{Deterministic}(alpha, \\text{Constant}, beta)$',
963+
'$beta \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$',
964+
'$Y_obs \\sim \\text{Normal}(\\mathit{mu}=mu, \\mathit{sd}=f(sigma))$'
965+
)
966+
967+
def test__repr_latex_(self):
968+
for distribution, tex in zip(self.distributions, self.expected):
969+
assert distribution._repr_latex_() == tex
970+
971+
model_tex = self.model._repr_latex_()
972+
973+
for tex in self.expected: # make sure each variable is in the model
974+
assert tex.strip('$') in model_tex
975+
976+
def test___latex__(self):
977+
for distribution, tex in zip(self.distributions, self.expected):
978+
assert distribution._repr_latex_() == distribution.__latex__()
979+
assert self.model._repr_latex_() == self.model.__latex__()
953980

954981

955982
def test_discrete_trafo():

0 commit comments

Comments
 (0)