Skip to content

Commit 1b67b26

Browse files
authored
Escape latex, fix Flat distribution (#2485)
1 parent ed288ed commit 1b67b26

File tree

4 files changed

+42
-17
lines changed

4 files changed

+42
-17
lines changed

pymc3/distributions/continuous.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,7 @@ def logp(self, value):
201201
return tt.zeros_like(value)
202202

203203
def _repr_latex_(self, name=None, dist=None):
204-
if dist is None:
205-
dist = self
206-
return r'${} \sim \text{Flat}()$'
204+
return r'${} \sim \text{Flat}()$'.format(name)
207205

208206

209207
class HalfFlat(PositiveContinuous):
@@ -220,9 +218,7 @@ def logp(self, value):
220218
return bound(tt.zeros_like(value), value > 0)
221219

222220
def _repr_latex_(self, name=None, dist=None):
223-
if dist is None:
224-
dist = self
225-
return r'${} \sim \text{{HalfFlat}()$'
221+
return r'${} \sim \text{{HalfFlat}()$'.format(name)
226222

227223

228224
class Normal(Continuous):

pymc3/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .theanof import gradient, hessian, inputvars, generator
1818
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
1919
from .blocking import DictToArrayBijection, ArrayOrdering
20-
from .util import get_transformed_name
20+
from .util import get_transformed_name, escape_latex
2121

2222
__all__ = [
2323
'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext',
@@ -1081,7 +1081,7 @@ def _repr_latex_(self, name=None, dist=None):
10811081
name = self.name
10821082
if dist is None:
10831083
dist = self.distribution
1084-
return self.distribution._repr_latex_(name=name, dist=dist)
1084+
return self.distribution._repr_latex_(name=escape_latex(name), dist=dist)
10851085

10861086
__latex__ = _repr_latex_
10871087

@@ -1186,7 +1186,7 @@ def _repr_latex_(self, name=None, dist=None):
11861186
name = self.name
11871187
if dist is None:
11881188
dist = self.distribution
1189-
return self.distribution._repr_latex_(name=name, dist=dist)
1189+
return self.distribution._repr_latex_(name=escape_latex(name), dist=dist)
11901190

11911191
__latex__ = _repr_latex_
11921192

@@ -1335,7 +1335,7 @@ def _repr_latex_(self, name=None, dist=None):
13351335
name = self.name
13361336
if dist is None:
13371337
dist = self.distribution
1338-
return self.distribution._repr_latex_(name=name, dist=dist)
1338+
return self.distribution._repr_latex_(name=escape_latex(name), dist=dist)
13391339

13401340
__latex__ = _repr_latex_
13411341

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -957,11 +957,11 @@ def setup_class(self):
957957
Y_obs = Normal('Y_obs', mu=mu, sd=sigma, observed=Y)
958958
self.distributions = [alpha, sigma, mu, b, Y_obs]
959959
self.expected = (
960-
'$alpha \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$',
961-
'$sigma \\sim \\text{HalfNormal}(\\mathit{sd}=1.0)$',
962-
'$mu \\sim \\text{Deterministic}(alpha, \\text{Constant}, beta)$',
963-
'$beta \\sim \\text{Normal}(\\mathit{mu}=0, \\mathit{sd}=10.0)$',
964-
'$Y_obs \\sim \\text{Normal}(\\mathit{mu}=mu, \\mathit{sd}=f(sigma))$'
960+
r'$alpha \sim \text{Normal}(\mathit{mu}=0, \mathit{sd}=10.0)$',
961+
r'$sigma \sim \text{HalfNormal}(\mathit{sd}=1.0)$',
962+
r'$mu \sim \text{Deterministic}(alpha, \text{Constant}, beta)$',
963+
r'$beta \sim \text{Normal}(\mathit{mu}=0, \mathit{sd}=10.0)$',
964+
r'$Y\_obs \sim \text{Normal}(\mathit{mu}=mu, \mathit{sd}=f(sigma))$'
965965
)
966966

967967
def test__repr_latex_(self):

pymc3/util.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,33 @@
1+
import re
2+
13
from numpy import asscalar
24

5+
LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
6+
7+
8+
def escape_latex(strng):
9+
"""Consistently escape LaTeX special characters for _repr_latex_ in IPython
10+
11+
Implementation taken from the IPython magic `format_latex`
12+
13+
Example
14+
-------
15+
escape_latex('disease_rate') # 'disease\_rate'
16+
17+
Parameters
18+
----------
19+
strng : str
20+
string to escape LaTeX characters
21+
22+
Returns
23+
-------
24+
str
25+
A string with LaTeX escaped
26+
"""
27+
if strng is None:
28+
return u'None'
29+
return LATEX_ESCAPE_RE.sub(r'\\\1', strng)
30+
331

432
def get_transformed_name(name, transform):
533
"""
@@ -14,7 +42,7 @@ def get_transformed_name(name, transform):
1442
1543
Returns
1644
-------
17-
str
45+
str
1846
A string to use for the transformed variable
1947
"""
2048
return "{}_{}__".format(name, transform.name)
@@ -88,14 +116,15 @@ def get_variable_name(variable):
88116
try:
89117
names = [get_variable_name(item)
90118
for item in variable.get_parents()[0].inputs]
119+
# do not escape_latex these, since it is not idempotent
91120
return 'f(%s)' % ','.join([n for n in names if isinstance(n, str)])
92121
except IndexError:
93122
pass
94123
value = variable.eval()
95124
if not value.shape:
96125
return asscalar(value)
97126
return 'array'
98-
return name
127+
return escape_latex(name)
99128

100129

101130
def update_start_vals(a, b, model):

0 commit comments

Comments
 (0)