Skip to content

Commit b1c417b

Browse files
committed
addressing Juan's comments
1 parent c98cfd9 commit b1c417b

File tree

8 files changed

+256
-277
lines changed

8 files changed

+256
-277
lines changed

causalpy/plot_utils.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
from typing import Any, Dict, Optional, Union
1+
from typing import Any, Dict, Optional, Tuple, Union
22

33
import arviz as az
44
import matplotlib.pyplot as plt
55
import numpy as np
66
import pandas as pd
77
import xarray as xr
88
from matplotlib.collections import PolyCollection
9+
from matplotlib.lines import Line2D
910

1011

1112
def plot_xY(
@@ -14,9 +15,8 @@ def plot_xY(
1415
ax: plt.Axes,
1516
plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
1617
hdi_prob: float = 0.94,
17-
label: Optional[str] = "",
18-
include_label: bool = True,
19-
):
18+
label: Union[str, None] = None,
19+
) -> Tuple[Line2D, PolyCollection]:
2020
"""Utility function to plot HDI intervals."""
2121

2222
if plot_hdi_kwargs is None:
@@ -27,33 +27,23 @@ def plot_xY(
2727
Y.mean(dim=["chain", "draw"]),
2828
ls="-",
2929
**plot_hdi_kwargs,
30-
label=f"{label}" if include_label else None,
30+
label=f"{label}",
3131
)
3232
ax_hdi = az.plot_hdi(
3333
x,
3434
Y,
3535
hdi_prob=hdi_prob,
3636
fill_kwargs={
3737
"alpha": 0.25,
38-
"label": " ", # f"{hdi_prob*100}% HDI" if include_label else None,
38+
"label": " ",
3939
},
4040
smooth=False,
4141
ax=ax,
4242
**plot_hdi_kwargs,
4343
)
44-
# Return handle to patch.
45-
# We get a list of the childen of the axis
46-
# Filter for just the PolyCollection objects
47-
# Take the last one
44+
# Return handle to patch. We get a list of the childen of the axis. Filter for just
45+
# the PolyCollection objects. Take the last one.
4846
h_patch = list(
4947
filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
5048
)[-1]
51-
52-
# if include_label:
53-
# handles, labels = ax.get_legend_handles_labels()
54-
# ax.legend(
55-
# handles=[(h1, h2) for h1, h2 in zip(handles[::2], handles[1::2])],
56-
# # labels=[l1 + " + " + l2 for l1, l2 in zip(labels[::2], labels[1::2])],
57-
# labels=[l1 for l1 in labels[::2]],
58-
# )
59-
return h_line, h_patch
49+
return (h_line, h_patch)

causalpy/pymc_experiments.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,6 @@ def plot(self):
122122
self.datapre.index,
123123
self.pre_pred["posterior_predictive"].mu,
124124
ax=ax[0],
125-
include_label=False,
126125
plot_hdi_kwargs={"color": "C0"},
127126
)
128127
handles = [(h_line, h_patch)]
@@ -137,7 +136,6 @@ def plot(self):
137136
self.datapost.index,
138137
self.post_pred["posterior_predictive"].mu,
139138
ax=ax[0],
140-
include_label=False,
141139
# label="Synthetic control",
142140
plot_hdi_kwargs={"color": "C1"},
143141
)
@@ -152,9 +150,8 @@ def plot(self):
152150
self.post_pred, group="posterior_predictive", var_names="mu"
153151
).mean("sample"),
154152
y2=np.squeeze(self.post_y),
155-
color="C2",
153+
color="C0",
156154
alpha=0.25,
157-
# label="Causal impact",
158155
)
159156
handles.append(h)
160157
labels.append("Causal impact")
@@ -171,21 +168,19 @@ def plot(self):
171168
self.datapre.index,
172169
self.pre_impact,
173170
ax=ax[1],
174-
include_label=False,
175171
plot_hdi_kwargs={"color": "C0"},
176172
)
177173
plot_xY(
178174
self.datapost.index,
179175
self.post_impact,
180176
ax=ax[1],
181-
include_label=False,
182177
plot_hdi_kwargs={"color": "C1"},
183178
)
184179
ax[1].axhline(y=0, c="k")
185180
ax[1].fill_between(
186181
self.datapost.index,
187182
y1=self.post_impact.mean(["chain", "draw"]),
188-
color="C2",
183+
color="C0",
189184
alpha=0.25,
190185
label="Causal impact",
191186
)
@@ -197,7 +192,6 @@ def plot(self):
197192
self.datapost.index,
198193
self.post_impact_cumulative,
199194
ax=ax[2],
200-
include_label=False,
201195
plot_hdi_kwargs={"color": "C1"},
202196
)
203197
ax[2].axhline(y=0, c="k")
@@ -209,7 +203,6 @@ def plot(self):
209203
ls="-",
210204
lw=3,
211205
color="r",
212-
# label="Treatment time",
213206
)
214207

215208
ax[0].legend(
@@ -434,7 +427,7 @@ def plot(self):
434427
widths=0.2,
435428
)
436429
for pc in parts["bodies"]:
437-
pc.set_facecolor("C2")
430+
pc.set_facecolor("C0")
438431
pc.set_edgecolor("None")
439432
pc.set_alpha(0.5)
440433
else:

docs/notebooks/did_pymc_banks.ipynb

Lines changed: 22 additions & 24 deletions
Large diffs are not rendered by default.

docs/notebooks/geolift1.ipynb

Lines changed: 13 additions & 15 deletions
Large diffs are not rendered by default.

docs/notebooks/its_covid.ipynb

Lines changed: 4 additions & 4 deletions
Large diffs are not rendered by default.

docs/notebooks/sc2_pymc.ipynb

Lines changed: 33 additions & 33 deletions
Large diffs are not rendered by default.

docs/notebooks/sc_pymc.ipynb

Lines changed: 27 additions & 27 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)