|
1 | 1 | import numpy as np
|
2 | 2 |
|
3 |
| -from pymc import Bernoulli, Censored, Mixture |
| 3 | +from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT |
4 | 4 | from pymc.aesaraf import floatX
|
5 | 5 | from pymc.distributions import (
|
6 | 6 | Dirichlet,
|
@@ -130,12 +130,12 @@ def setup_class(self):
|
130 | 130 | r"$\text{beta} \sim \operatorname{N}(0,~10)$",
|
131 | 131 | r"$\text{Z} \sim \operatorname{N}(f(),~f())$",
|
132 | 132 | r"$\text{nb_with_p_n} \sim \operatorname{NB}(10,~\text{nbp})$",
|
133 |
| - r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\text{\$\operatorname{DiracDelta}(0)\$},~\text{\$\operatorname{Pois}(5)\$})$", |
| 133 | + r"$\text{zip} \sim \operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))$", |
134 | 134 | r"$\text{w} \sim \operatorname{Dir}(\text{<constant>})$",
|
135 | 135 | (
|
136 | 136 | r"$\text{nested_mix} \sim \operatorname{MarginalMixture}(\text{w},"
|
137 |
| - r"~\text{\$\operatorname{MarginalMixture}(f(),~\text{\\$\operatorname{DiracDelta}(0)\\$},~\text{\\$\operatorname{Pois}(5)\\$})\$}," |
138 |
| - r"~\text{\$\operatorname{Censored}(\text{\\$\operatorname{Bern}(0.5)\\$},~-1,~1)\$})$" |
| 137 | + r"~\operatorname{MarginalMixture}(f(),~\operatorname{DiracDelta}(0),~\operatorname{Pois}(5))," |
| 138 | + r"~\operatorname{Censored}(\operatorname{Bern}(0.5),~-1,~1))$" |
139 | 139 | ),
|
140 | 140 | r"$\text{Y_obs} \sim \operatorname{N}(\text{mu},~\text{sigma})$",
|
141 | 141 | r"$\text{pot} \sim \operatorname{Potential}(f(\text{beta},~\text{alpha}))$",
|
@@ -178,3 +178,43 @@ def test_str_repr(self):
|
178 | 178 | assert segment in model_text
|
179 | 179 | else:
|
180 | 180 | assert text in model_text
|
| 181 | + |
| 182 | + |
| 183 | +def test_model_latex_repr_three_levels_model(): |
| 184 | + with Model() as censored_model: |
| 185 | + mu = Normal("mu", 0.0, 5.0) |
| 186 | + sigma = HalfCauchy("sigma", 2.5) |
| 187 | + normal_dist = Normal.dist(mu=mu, sigma=sigma) |
| 188 | + censored_normal = Censored( |
| 189 | + "censored_normal", normal_dist, lower=-2.0, upper=2.0, observed=[1, 0, 0.5] |
| 190 | + ) |
| 191 | + |
| 192 | + latex_repr = censored_model.str_repr(formatting="latex") |
| 193 | + expected = [ |
| 194 | + "$$", |
| 195 | + "\\begin{array}{rcl}", |
| 196 | + "\\text{mu} &\\sim & \\operatorname{N}(0,~5)\\\\\\text{sigma} &\\sim & " |
| 197 | + "\\operatorname{C^{+}}(0,~2.5)\\\\\\text{censored_normal} &\\sim & " |
| 198 | + "\\operatorname{Censored}(\\operatorname{N}(\\text{mu},~\\text{sigma}),~-2,~2)", |
| 199 | + "\\end{array}", |
| 200 | + "$$", |
| 201 | + ] |
| 202 | + assert [line.strip() for line in latex_repr.split("\n")] == expected |
| 203 | + |
| 204 | + |
| 205 | +def test_model_latex_repr_mixture_model(): |
| 206 | + with Model() as mix_model: |
| 207 | + w = Dirichlet("w", [1, 1]) |
| 208 | + mix = Mixture("mix", w=w, comp_dists=[Normal.dist(0.0, 5.0), StudentT.dist(7.0)]) |
| 209 | + |
| 210 | + latex_repr = mix_model.str_repr(formatting="latex") |
| 211 | + expected = [ |
| 212 | + "$$", |
| 213 | + "\\begin{array}{rcl}", |
| 214 | + "\\text{w} &\\sim & " |
| 215 | + "\\operatorname{Dir}(\\text{<constant>})\\\\\\text{mix} &\\sim & " |
| 216 | + "\\operatorname{MarginalMixture}(\\text{w},~\\operatorname{N}(0,~5),~\\operatorname{StudentT}(7,~0,~1))", |
| 217 | + "\\end{array}", |
| 218 | + "$$", |
| 219 | + ] |
| 220 | + assert [line.strip() for line in latex_repr.split("\n")] == expected |
0 commit comments