Skip to content

Commit ea721e4

Browse files
authored
Fix latex repr of symbolic distributions (#6231)
* Fix latex repr of nested variables * Do not put \operatorname inside \text * Remove inner $ symbols from latex representations * remove them from the string borders since they always represent the math environment * fix tests to check the correct behavior * Add test for a full model that used to be rendered wrong * Use "\\operatorname{" instead of "operatorname" to determine if the latex command is used * Update tests to new distributions' notation
1 parent 9105d74 commit ea721e4

File tree

2 files changed

+64
-11
lines changed

2 files changed

+64
-11
lines changed

pymc/printing.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def str_for_dist(
5858

5959
if "latex" in formatting:
6060
if print_name is not None:
61-
print_name = r"\text{" + _latex_escape(dist.name) + "}"
61+
print_name = r"\text{" + _latex_escape(dist.name.strip("$")) + "}"
6262

6363
op_name = (
6464
dist.owner.op._print_name[1]
@@ -67,9 +67,11 @@ def str_for_dist(
6767
)
6868
if include_params:
6969
if print_name:
70-
return r"${} \sim {}({})$".format(print_name, op_name, ",~".join(dist_args))
70+
return r"${} \sim {}({})$".format(
71+
print_name, op_name, ",~".join([d.strip("$") for d in dist_args])
72+
)
7173
else:
72-
return r"${}({})$".format(op_name, ",~".join(dist_args))
74+
return r"${}({})$".format(op_name, ",~".join([d.strip("$") for d in dist_args]))
7375

7476
else:
7577
if print_name:
@@ -138,7 +140,7 @@ def str_for_potential_or_deterministic(
138140
LaTeX or plain, optionally with distribution parameter values included."""
139141
print_name = var.name if var.name is not None else "<unnamed>"
140142
if "latex" in formatting:
141-
print_name = r"\text{" + _latex_escape(print_name) + "}"
143+
print_name = r"\text{" + _latex_escape(print_name.strip("$")) + "}"
142144
if include_params:
143145
return rf"${print_name} \sim \operatorname{{{dist_name}}}({_str_for_expression(var, formatting=formatting)})$"
144146
else:
@@ -182,7 +184,7 @@ def _str_for_input_rv(var: Variable, formatting: str) -> str:
182184
else str_for_dist(var, formatting=formatting, include_params=True)
183185
)
184186
if "latex" in formatting:
185-
return r"\text{" + _latex_escape(_str) + "}"
187+
return _latex_text_format(_latex_escape(_str.strip("$")))
186188
else:
187189
return _str
188190

@@ -215,9 +217,20 @@ def _expand(x):
215217
names = [x.name for x in parents]
216218

217219
if "latex" in formatting:
218-
return r"f(" + ",~".join([r"\text{" + _latex_escape(n) + "}" for n in names]) + ")"
220+
return (
221+
r"f("
222+
+ ",~".join([_latex_text_format(_latex_escape(n.strip("$"))) for n in names])
223+
+ ")"
224+
)
225+
else:
226+
return r"f(" + ", ".join([n.strip("$") for n in names]) + ")"
227+
228+
229+
def _latex_text_format(text: str) -> str:
230+
if r"\operatorname{" in text:
231+
return text
219232
else:
220-
return r"f(" + ", ".join(names) + ")"
233+
return r"\text{" + text + "}"
221234

222235

223236
def _latex_escape(text: str) -> str:

pymc/tests/test_printing.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from pymc import Bernoulli, Censored, Mixture
3+
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
44
from pymc.aesaraf import floatX
55
from pymc.distributions import (
66
Dirichlet,
@@ -130,12 +130,12 @@ def setup_class(self):
130130
r"$\text{beta} \sim \operatorname{N}(0,~10)$",
131131
r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
132132
r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
133-
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$",
133+
r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$",
134134
r"$\text{w} \sim \operatorname{Dir}(\text{<constant>})$",
135135
(
136136
r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
137-
r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$},"
138-
r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$"
137+
r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5)),"
138+
r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$"
139139
),
140140
r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
141141
r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
@@ -178,3 +178,43 @@ def test_str_repr(self):
178178
assert segment in model_text
179179
else:
180180
assert text in model_text
181+
182+
183+
def test_model_latex_repr_three_levels_model():
184+
with Model() as censored_model:
185+
mu = Normal("mu", 0.0, 5.0)
186+
sigma = HalfCauchy("sigma", 2.5)
187+
normal_dist = Normal.dist(mu=mu, sigma=sigma)
188+
censored_normal = Censored(
189+
"censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5]
190+
)
191+
192+
latex_repr = censored_model.str_repr(formatting="latex")
193+
expected = [
194+
"$$",
195+
"\\begin{array}{rcl}",
196+
"\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & "
197+
"\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & "
198+
"\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)",
199+
"\\end{array}",
200+
"$$",
201+
]
202+
assert [line.strip() for line in latex_repr.split("\n")] == expected
203+
204+
205+
def test_model_latex_repr_mixture_model():
206+
with Model() as mix_model:
207+
w = Dirichlet("w", [1, 1])
208+
mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)])
209+
210+
latex_repr = mix_model.str_repr(formatting="latex")
211+
expected = [
212+
"$$",
213+
"\\begin{array}{rcl}",
214+
"\\text{w} &\\sim & "
215+
"\\operatorname{Dir}(\\text{<constant>})\\\\\\text{mix} &\\sim & "
216+
"\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))",
217+
"\\end{array}",
218+
"$$",
219+
]
220+
assert [line.strip() for line in latex_repr.split("\n")] == expected

0 commit comments

Comments
 (0)