Skip to content

Commit 50240a0

Browse files
committed
make nicer overlap plots
Signed-off-by: Nathaniel <[email protected]>
1 parent 0842735 commit 50240a0

File tree

2 files changed

+137
-49
lines changed

2 files changed

+137
-49
lines changed

causalpy/pymc_experiments.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,25 +1656,57 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
16561656
if method is None:
16571657
method = self.weighting_scheme
16581658

1659-
def plot_weights(bins, top0, top1, ax):
1659+
def plot_weights(bins, top0, top1, ax, color="population"):
1660+
colors_dict = {
1661+
"population": ["red", "blue", 0.9],
1662+
"pseudo_population": ["purple", "purple", 0.1],
1663+
}
1664+
16601665
ax.axhline(0, c="gray", linewidth=1)
16611666
bars0 = ax.bar(
1662-
bins[:-1] + 0.025, top0, width=0.04, facecolor="red", alpha=0.3
1667+
bins[:-1] + 0.025,
1668+
top0,
1669+
width=0.04,
1670+
facecolor=colors_dict[color][0],
1671+
alpha=colors_dict[color][2],
16631672
)
16641673
bars1 = ax.bar(
1665-
bins[:-1] + 0.025, -top1, width=0.04, facecolor="blue", alpha=0.3
1674+
bins[:-1] + 0.025,
1675+
-top1,
1676+
width=0.04,
1677+
facecolor=colors_dict[color][1],
1678+
alpha=colors_dict[color][2],
16661679
)
16671680

16681681
for bars in (bars0, bars1):
16691682
for bar in bars:
16701683
bar.set_edgecolor("black")
16711684

1672-
def make_hists(idata, i, axs):
1685+
def make_hists(idata, i, axs, method=method):
16731686
p_i = az.extract(idata)["p"][:, i].values
1687+
if method == "raw":
1688+
weight0 = 1 / (1 - p_i[self.t.flatten() == 0])
1689+
weight1 = 1 / (p_i[self.t.flatten() == 1])
1690+
elif method == "overlap":
1691+
t = self.t.flatten()
1692+
weight1 = (1 - p_i[t == 1]) * t[t == 1]
1693+
weight0 = p_i[t == 0] * (1 - t[t == 0])
1694+
else:
1695+
t = self.t.flatten()
1696+
p_of_t = np.mean(t)
1697+
weight1 = p_of_t / p_i[t == 1]
1698+
weight0 = (1 - p_of_t) / (1 - p_i[t == 0])
16741699
bins = np.arange(0.025, 0.99, 0.005)
16751700
top0, _ = np.histogram(p_i[self.t.flatten() == 0], bins=bins)
16761701
top1, _ = np.histogram(p_i[self.t.flatten() == 1], bins=bins)
16771702
plot_weights(bins, top0, top1, axs[0])
1703+
top0, _ = np.histogram(
1704+
p_i[self.t.flatten() == 0], bins=bins, weights=weight0
1705+
)
1706+
top1, _ = np.histogram(
1707+
p_i[self.t.flatten() == 1], bins=bins, weights=weight1
1708+
)
1709+
plot_weights(bins, top0, top1, axs[0], color="pseudo_population")
16781710

16791711
mosaic = """AAAAAA
16801712
BBBBCC"""
@@ -1690,6 +1722,7 @@ def make_hists(idata, i, axs):
16901722
axs[0].set_title(
16911723
"Draws from the Posterior \n Propensity Scores Distribution", fontsize=20
16921724
)
1725+
axs[0].legend()
16931726

16941727
[make_hists(idata, i, axs) for i in range(prop_draws)]
16951728
ate_df = pd.DataFrame(

docs/source/notebooks/inv_prop_pymc.ipynb

Lines changed: 100 additions & 45 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)