Skip to content

Commit 417bc8a

Browse files
committed
[BUG] Plot posterior function was a bit broken
1 parent f168be3 commit 417bc8a

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

examples/2-basic_geology/1-thickness_problem.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""
1+
"""
22
2.1 - Only Pyro
33
===============
44
@@ -148,7 +148,8 @@ def model(y_obs_list_):
148148
iteration = 99
149149
p.plot_posterior(
150150
prior_var=['$\\mu_{top}$', '$\\mu_{bottom}$'],
151-
like_var=['$\mu_{thickness}$', '$\sigma_{thickness}$'],
151+
like_var=['$\\mu_{top}$', '$\\mu_{bottom}$'],
152+
# like_var=['$\\mu_{thickness}$', r"y_{top}"],
152153
obs='y_{thickness}',
153154
iteration=iteration,
154155
marginal_kwargs={

examples/DEP/ch5_2_introduction_pymc3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ def plot_geo_setting_well():
193193

194194
p.create_figure(figsize=(9, 5), joyplot=True, marginal=True, likelihood=True, n_samples=11)
195195

196-
p.plot_posterior(['$\mu$', '$\sigma$'], ['$\mu$', '$\sigma$'], '$y$',
196+
p.plot_posterior(
197+
['$\mu$', '$\sigma$'], ['$\mu$', '$\sigma$'], '$y$',
197198
marginal_kwargs={'plot_trace': False, 'credible_interval': .93, 'kind': 'kde'})
198199
plt.show()

gempy_probability/plot_posterior.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -521,12 +521,12 @@ def plot_joy(self, var_names: tuple = None, obs: Union[str, float] = None,
521521

522522
obs = data.observed_data[obs] if type(obs) is str else obs
523523

524-
# data = convert_to_dataset(data, group="posterior")
525524
coords = {}
526-
# var_names = _var_names(var_names, data)
527-
528525
plotters_in_posterior = list(xarray_var_iter(
529-
data=get_coords(data, coords),
526+
data=get_coords(
527+
convert_to_dataset(data, group="posterior"),
528+
coords
529+
),
530530
var_names=var_names,
531531
combined=True
532532
) )
@@ -538,8 +538,8 @@ def plot_joy(self, var_names: tuple = None, obs: Union[str, float] = None,
538538

539539
plotters = plotters_in_posterior + plotters_in_posterior_predictive
540540

541-
x = plotters[0][3].flatten()
542-
y = plotters[1][3].flatten()
541+
x = plotters[0][-1].flatten()
542+
y = plotters[1][-1].flatten()
543543

544544
n_data = x.shape[0]
545545
# This is the special case if n_samples is smaller than the number of bells to plot

0 commit comments

Comments
 (0)